OpenDAS / LLaMA-Factory · Commits

Commit 8293100a, authored Jan 16, 2025 by luopl

    update to 0.9.2.dev0

parent 2778a3d0 · Changes: 124

Showing 20 changed files with 1411 additions and 530 deletions (+1411 / -530)
src/llamafactory/api/chat.py                          +1    -1
src/llamafactory/chat/hf_engine.py                    +21   -4
src/llamafactory/chat/vllm_engine.py                  +20   -16
src/llamafactory/cli.py                               +6    -2
src/llamafactory/data/collator.py                     +60   -1
src/llamafactory/data/data_utils.py                   +2    -2
src/llamafactory/data/formatter.py                    +12   -18
src/llamafactory/data/loader.py                       +14   -10
src/llamafactory/data/mm_plugin.py                    +316  -130
src/llamafactory/data/preprocess.py                   +2    -2
src/llamafactory/data/processors/pretrain.py          +5    -0
src/llamafactory/data/processors/unsupervised.py      +2    -0
src/llamafactory/data/template.py                     +254  -121
src/llamafactory/data/tool_utils.py                   +189  -20
src/llamafactory/eval/evaluator.py                    +1    -1
src/llamafactory/extras/constants.py                  +449  -177
src/llamafactory/extras/env.py                        +1    -1
src/llamafactory/extras/logging.py                    +4    -4
src/llamafactory/extras/misc.py                       +44   -20
src/llamafactory/extras/packages.py                   +8    -0
src/llamafactory/api/chat.py

@@ -168,7 +168,7 @@ async def create_chat_completion_response(
     if isinstance(result, list):
         tool_calls = []
         for tool in result:
-            function = Function(name=tool[0], arguments=tool[1])
+            function = Function(name=tool.name, arguments=tool.arguments)
             tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))

         response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
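This change pairs with the new return type of `Template.extract_tool` (see template.py and tool_utils.py in this commit): tool calls now arrive as objects with named fields instead of bare `(name, arguments)` tuples. A minimal sketch of the idea, assuming `FunctionCall` is a NamedTuple with `name` and `arguments` fields (its exact definition lives in tool_utils.py, which is not shown in full here):

    import uuid
    from typing import NamedTuple

    class FunctionCall(NamedTuple):  # assumed shape, mirroring tool_utils.py
        name: str
        arguments: str  # JSON-encoded argument string

    result = [FunctionCall(name="get_weather", arguments='{"city": "Paris"}')]
    tool_calls = [
        {"id": f"call_{uuid.uuid4().hex}", "name": tool.name, "arguments": tool.arguments}
        for tool in result  # named attributes replace tool[0]/tool[1] indexing
    ]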
src/llamafactory/chat/hf_engine.py

@@ -63,7 +63,7 @@ class HuggingfaceEngine(BaseEngine):
         try:
             asyncio.get_event_loop()
         except RuntimeError:
-            logger.warning_once("There is no current event loop, creating a new one.")
+            logger.warning_rank0_once("There is no current event loop, creating a new one.")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)

@@ -133,7 +133,7 @@ class HuggingfaceEngine(BaseEngine):
                 if repetition_penalty is not None else generating_args["repetition_penalty"],
                 length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
-                eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
+                eos_token_id=template.get_stop_token_ids(tokenizer),
                 pad_token_id=tokenizer.pad_token_id,
             )
         )
@@ -168,11 +168,21 @@ class HuggingfaceEngine(BaseEngine):
         for key, value in mm_inputs.items():
             if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                 value = torch.stack(value)  # assume they have same sizes
+            elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
+                value = torch.stack([torch.stack(v) for v in value])
             elif not isinstance(value, torch.Tensor):
                 value = torch.tensor(value)

             if torch.is_floating_point(value):  # cast data dtype for paligemma
                 value = value.to(model.dtype)

             gen_kwargs[key] = value.to(model.device)

+        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
+            gen_kwargs["input_ids"] = inputs
+            del gen_kwargs["image_sizes"]
+            gen_kwargs["tokenizer"] = tokenizer

         return gen_kwargs, prompt_length

     @staticmethod
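The new minicpmv branch collates a list of lists of tensors into one batched tensor by stacking twice. A standalone sketch of that collation step (shapes are illustrative):

    import torch

    # two samples, each holding two image-slice tensors of identical shape
    value = [[torch.zeros(3, 4), torch.ones(3, 4)], [torch.ones(3, 4), torch.zeros(3, 4)]]
    stacked = torch.stack([torch.stack(v) for v in value])  # inner stack per sample, outer per batch
    print(stacked.shape)  # torch.Size([2, 2, 3, 4])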
@@ -204,8 +214,13 @@ class HuggingfaceEngine(BaseEngine):
             input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
+        if isinstance(generate_output, tuple):
+            generate_output = generate_output[1][0]  # post-process the minicpm_o output

         response_ids = generate_output[:, prompt_length:]
-        response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        response = tokenizer.batch_decode(
+            response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
+        )
         results = []
         for i in range(len(response)):
             eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
@@ -249,7 +264,9 @@ class HuggingfaceEngine(BaseEngine):
             videos,
             input_kwargs,
         )
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(
+            tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
+        )
         gen_kwargs["streamer"] = streamer
         thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
         thread.start()
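The streaming path runs `model.generate` on a daemon thread and pulls decoded text from the streamer on the calling thread. A minimal sketch of the same pattern, assuming any Hugging Face causal LM and tokenizer (the checkpoint name is a placeholder):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    inputs = tokenizer("Hello", return_tensors="pt")

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(**inputs, max_new_tokens=16, streamer=streamer)
    Thread(target=model.generate, kwargs=gen_kwargs, daemon=True).start()
    for new_text in streamer:  # yields decoded chunks as generation proceeds
        print(new_text, end="")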
src/llamafactory/chat/vllm_engine.py

@@ -19,7 +19,7 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import get_device_count
 from ..extras.packages import is_pillow_available, is_vllm_available
 from ..model import load_config, load_tokenizer
@@ -67,11 +67,12 @@ class VllmEngine(BaseEngine):
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
+        self.template.mm_plugin.expand_mm_tokens = False  # for vllm generate
         self.generating_args = generating_args.to_dict()

         engine_args = {
             "model": model_args.model_name_or_path,
-            "trust_remote_code": True,
+            "trust_remote_code": model_args.trust_remote_code,
             "download_dir": model_args.cache_dir,
             "dtype": model_args.infer_dtype,
             "max_model_len": model_args.vllm_maxlen,
@@ -83,6 +84,9 @@ class VllmEngine(BaseEngine):
             "enable_lora": model_args.adapter_name_or_path is not None,
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
+        if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
+            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}

         if isinstance(model_args.vllm_config, dict):
             engine_args.update(model_args.vllm_config)
@@ -108,19 +112,21 @@ class VllmEngine(BaseEngine):
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"
+        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
         if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
             if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                 messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

-            if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin":  # temporary solution
-                image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
-            else:
-                image_str = self.template.mm_plugin.image_token or ""
-
-            paired_messages = [
-                {"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
-                for message in messages
-            ] + [{"role": "assistant", "content": ""}]
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+
+        messages = self.template.mm_plugin.process_messages(
+            messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
+        )
+        paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)
@@ -162,13 +168,13 @@ class VllmEngine(BaseEngine):
             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
             top_k=top_k if top_k is not None else self.generating_args["top_k"],
             stop=stop,
-            stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+            stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
             max_tokens=max_tokens,
-            skip_special_tokens=True,
+            skip_special_tokens=self.generating_args["skip_special_tokens"],
         )

         if images is not None:  # add image features
-            image_data = []
+            multi_modal_data = {"image": []}
             for image in images:
                 if not isinstance(image, (str, ImageObject)):
                     raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
@@ -176,9 +182,7 @@ class VllmEngine(BaseEngine):
                 if isinstance(image, str):
                     image = Image.open(image).convert("RGB")

-                image_data.append(image)
-
-            multi_modal_data = {"image": image_data}
+                multi_modal_data["image"].append(image)
         else:
             multi_modal_data = None
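The image-loading loop accepts either file paths or PIL images and normalizes everything to RGB before handing it to vLLM. A self-contained sketch of that normalization, with a generated image standing in for real inputs:

    from PIL import Image

    def normalize_images(images):
        multi_modal_data = {"image": []}
        for image in images:
            if isinstance(image, str):  # a path is opened and converted
                image = Image.open(image).convert("RGB")
            multi_modal_data["image"].append(image)
        return multi_modal_data

    sample = Image.new("RGB", (32, 32), (0, 0, 0))  # stands in for a user-supplied image
    print(normalize_images([sample]))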
src/llamafactory/cli.py

@@ -24,7 +24,7 @@ from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
 from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.misc import get_device_count
+from .extras.misc import get_device_count, use_ray
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
@@ -87,7 +87,7 @@ def main():
         export_model()
     elif command == Command.TRAIN:
         force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
-        if force_torchrun or get_device_count() > 1:
+        if force_torchrun or (get_device_count() > 1 and not use_ray()):
             master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
             master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
             logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
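The launch decision now combines the explicit override flag with the detected device count, so Ray-managed runs are no longer wrapped in torchrun. A small sketch of the same env-flag pattern (`get_device_count` and `use_ray` are stubbed here purely for illustration; the real helpers live in extras/misc.py):

    import os

    def get_device_count() -> int:  # stub: the real helper queries the available accelerators
        return 2

    def use_ray() -> bool:  # stub: the real helper inspects the training arguments
        return False

    force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
    if force_torchrun or (get_device_count() > 1 and not use_ray()):
        print("launching via torchrun")  # multi-device, not delegated to Ray
    else:
        print("launching in-process")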
@@ -120,3 +120,7 @@ def main():
         print(USAGE)
     else:
         raise NotImplementedError(f"Unknown command: {command}.")
+
+
+if __name__ == "__main__":
+    main()
src/llamafactory/data/collator.py

@@ -19,8 +19,16 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

 import torch
+import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq

+from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.packages import is_pillow_available
+
+
+if is_pillow_available():
+    from PIL import Image
+

 if TYPE_CHECKING:
     from transformers import ProcessorMixin
@@ -72,12 +80,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     r"""
     Data collator that supports VLMs.

-    Features should contain input_ids, attention_mask, labels and images.
+    Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
    """

     template: Optional["Template"] = None
     processor: Optional["ProcessorMixin"] = None

+    def __post_init__(self):
+        if self.template is None:
+            raise ValueError("Template is required for MultiModalDataCollator.")
+
     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
         batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
         for feature in features:
@@ -89,6 +101,29 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_vidlens.append(len(videos))
             batch_input_ids.append(feature["input_ids"])

+        if (
+            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+        ):  # avoid process hanging in zero3/fsdp case
+            fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
+            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
+            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
+            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+                fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
+            )
+            if self.tokenizer.padding_side == "right":
+                features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
+                features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
+                features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
+            else:
+                features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
+                features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
+                features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
+
+            batch_images = fake_images
+            batch_imglens[0] = 1
+            batch_input_ids[0] = features[0]["input_ids"]

         mm_inputs = self.template.mm_plugin.get_mm_inputs(
             batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
         )
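The dummy sample works around the hang the comment mentions: under ZeRO-3/FSDP, a rank whose batch has no images would skip the sharded vision weights while other ranks wait on them. Feeding one blank image whose tokens are masked out of attention and loss keeps every rank on the same code path at no training cost. A sketch of the masking half, assuming IGNORE_INDEX = -100 as in the usual HF convention (the image-token ids below are hypothetical):

    from PIL import Image

    IGNORE_INDEX = -100  # loss is not computed on these positions
    fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]  # blank white placeholder
    fake_input_ids = [151646, 151647]  # hypothetical image-token ids for illustration

    feature = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]}
    # left-padding branch: prepend the fake span, masked out of attention and loss
    feature["input_ids"] = fake_input_ids + feature["input_ids"]
    feature["attention_mask"] = [0] * len(fake_input_ids) + feature["attention_mask"]
    feature["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + feature["labels"]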
@@ -98,10 +133,30 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 feature["token_type_ids"] = token_type_ids[i]

         features: Dict[str, "torch.Tensor"] = super().__call__(features)

         if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
             features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
                 input_ids=features["input_ids"],
                 image_grid_thw=mm_inputs.get("image_grid_thw", None),
                 video_grid_thw=mm_inputs.get("video_grid_thw", None),
                 attention_mask=features["attention_mask"],
             )

+        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
+            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
+            seq_len = features["input_ids"].size(1)
+            orig_len = cross_attention_mask.size(1)
+            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))

         features.update(mm_inputs)
         if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
             features = features.data  # use default_collate() instead of BatchEncoding.to()

+        if "image_bound" in features:  # for minicpmv inputs
+            bsz, seq_length = features["input_ids"].shape
+            features["position_ids"] = torch.arange(seq_length).long().repeat(bsz, 1)
+            return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}

         return features
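F.pad reads its pad tuple from the last dimension backwards in (before, after) pairs, so (0, 0, 0, 0, 0, seq_len - orig_len) leaves the last two dims alone and pads only the end of dim 1, the sequence axis. A quick check of that behavior:

    import torch
    import torch.nn.functional as F

    mask = torch.ones(2, 5, 1, 4)  # (batch, seq, max_num_images, max_num_tiles)
    padded = F.pad(mask, (0, 0, 0, 0, 0, 3))  # grow seq from 5 to 8 with zeros
    print(padded.shape)  # torch.Size([2, 8, 1, 4])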
@@ -120,6 +175,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
         if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
             features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

+        for key, value in features.items():  # cast data dtype for paligemma
+            if torch.is_tensor(value) and torch.is_floating_point(value):
+                features[key] = value.to(self.compute_dtype)
+
         return features
src/llamafactory/data/data_utils.py

@@ -56,12 +56,12 @@ def merge_dataset(
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
-            logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")
+            logger.warning_rank0_once("The samples between different datasets will not be mixed in streaming mode.")

         return concatenate_datasets(all_datasets)
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
-            logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")
+            logger.warning_rank0_once("We recommend using `mix_strategy=concat` in non-streaming mode.")

         return interleave_datasets(
             datasets=all_datasets,
src/llamafactory/data/formatter.py

@@ -16,16 +16,12 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import List, Optional, Union

 from typing_extensions import override

 from .data_utils import SLOTS
-from .tool_utils import get_tool_utils
-
-
-if TYPE_CHECKING:
-    from .tool_utils import FunctionCall
+from .tool_utils import FunctionCall, get_tool_utils


 @dataclass
@@ -98,33 +94,31 @@ class StringFormatter(Formatter):
 @dataclass
 class FunctionFormatter(Formatter):
     def __post_init__(self):
-        self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
+        self.tool_utils = get_tool_utils(self.tool_format)

     @override
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
-        functions: List[Tuple[str, str]] = []
+        functions: List["FunctionCall"] = []
         try:
             tool_calls = json.loads(content)
             if not isinstance(tool_calls, list):  # parallel function call
                 tool_calls = [tool_calls]

             for tool_call in tool_calls:
-                functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+                functions.append(
+                    FunctionCall(tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))
+                )
         except json.JSONDecodeError:
             raise RuntimeError(f"Invalid JSON format in function message: {str([content])}")  # flat string

         elements = []
-        for name, arguments in functions:
-            for slot in self.slots:
-                if isinstance(slot, str):
-                    slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
-                    elements.append(slot)
-                elif isinstance(slot, (dict, set)):
-                    elements.append(slot)
-                else:
-                    raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")
+        for slot in self.slots:
+            if slot == "{{content}}":
+                elements += self.tool_utils.function_formatter(functions)
+            else:
+                elements.append(slot)

         return elements
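The formatter now delegates string rendering to the tool-specific `function_formatter` instead of doing `{{name}}`/`{{arguments}}` substitution itself, which lets each tool format (default, glm4, ...) control its own layout. A hedged sketch of just the parsing step `apply` performs on a function message:

    import json
    from typing import NamedTuple

    class FunctionCall(NamedTuple):  # assumed shape, mirroring tool_utils.py
        name: str
        arguments: str

    content = '{"name": "get_weather", "arguments": {"city": "Paris"}}'
    tool_calls = json.loads(content)
    if not isinstance(tool_calls, list):  # parallel function call
        tool_calls = [tool_calls]

    functions = [
        FunctionCall(c["name"], json.dumps(c["arguments"], ensure_ascii=False)) for c in tool_calls
    ]
    print(functions)  # [FunctionCall(name='get_weather', arguments='{"city": "Paris"}')]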
src/llamafactory/data/loader.py

@@ -18,11 +18,10 @@ from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
 import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
-from transformers.utils.versions import require_version

 from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
-from ..extras.misc import has_tokenized_data
+from ..extras.misc import check_version, has_tokenized_data
 from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
@@ -84,7 +83,7 @@ def _load_single_dataset(
         raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

     if dataset_attr.load_from == "ms_hub":
-        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        check_version("modelscope>=1.11.0", mandatory=True)
         from modelscope import MsDataset  # type: ignore
         from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore

@@ -103,7 +102,7 @@ def _load_single_dataset(
         dataset = dataset.to_hf_dataset()
     elif dataset_attr.load_from == "om_hub":
-        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        check_version("openmind>=0.8.0", mandatory=True)
         from openmind import OmDataset  # type: ignore
         from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore
@@ -128,7 +127,8 @@ def _load_single_dataset(
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
             streaming=data_args.streaming,
-            trust_remote_code=True,
+            num_proc=data_args.preprocessing_num_workers,
+            trust_remote_code=model_args.trust_remote_code,
         )

     if dataset_attr.num_samples is not None and not data_args.streaming:
@@ -238,15 +238,19 @@ def get_dataset(
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
             logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
-            dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
+            tokenized_data: Union["Dataset", "DatasetDict"] = load_from_disk(data_args.tokenized_path)
             logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")

             dataset_module: Dict[str, "Dataset"] = {}
-            if "train" in dataset_dict:
-                dataset_module["train_dataset"] = dataset_dict["train"]
-
-            if "validation" in dataset_dict:
-                dataset_module["eval_dataset"] = dataset_dict["validation"]
+            if isinstance(tokenized_data, DatasetDict):
+                if "train" in tokenized_data:
+                    dataset_module["train_dataset"] = tokenized_data["train"]
+
+                if "validation" in tokenized_data:
+                    dataset_module["eval_dataset"] = tokenized_data["validation"]
+            else:  # Dataset
+                dataset_module["train_dataset"] = tokenized_data

             if data_args.streaming:
                 dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}
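The new branch exists because `load_from_disk` returns a plain `Dataset` when the saved directory holds a single split and a `DatasetDict` when it holds several; the old code assumed the dict form. A small sketch of the dispatch using the `datasets` library (the path is a placeholder):

    from datasets import Dataset, DatasetDict, load_from_disk

    tokenized_data = load_from_disk("/path/to/tokenized")  # Dataset or DatasetDict
    dataset_module = {}
    if isinstance(tokenized_data, DatasetDict):
        if "train" in tokenized_data:
            dataset_module["train_dataset"] = tokenized_data["train"]
        if "validation" in tokenized_data:
            dataset_module["eval_dataset"] = tokenized_data["validation"]
    else:  # a bare Dataset is treated as the training split
        dataset_module["train_dataset"] = tokenized_data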
src/llamafactory/data/mm_plugin.py

 import math
+import re
 from copy import deepcopy
 from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
@@ -62,6 +63,7 @@ class BasePlugin:
     def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
         self.image_token = image_token
         self.video_token = video_token
+        self.expand_mm_tokens = True

     def _validate_input(
         self,
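`expand_mm_tokens` is the switch the vLLM engine flips off above: during training each image placeholder must be expanded to the full run of image tokens the vision tower will emit, whereas vLLM performs that expansion itself, so the plugin then leaves a single placeholder token. A sketch of the effect on a prompt (the seqlen value is illustrative):

    IMAGE_PLACEHOLDER = "<image>"

    def expand(content: str, image_token: str, image_seqlen: int, expand_mm_tokens: bool) -> str:
        seqlen = image_seqlen if expand_mm_tokens else 1
        return content.replace(IMAGE_PLACEHOLDER, image_token * seqlen, 1)

    print(expand("<image>Describe this.", "<img>", 3, True))   # <img><img><img>Describe this.
    print(expand("<image>Describe this.", "<img>", 3, False))  # <img>Describe this.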
@@ -72,10 +74,14 @@ class BasePlugin:
         Validates if this model accepts the input modalities.
         """
         if len(images) != 0 and self.image_token is None:
-            raise ValueError("This model does not support image input.")
+            raise ValueError(
+                "This model does not support image input. Please check whether the correct `template` is used."
+            )

         if len(videos) != 0 and self.video_token is None:
-            raise ValueError("This model does not support video input.")
+            raise ValueError(
+                "This model does not support video input. Please check whether the correct `template` is used."
+            )

     def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
         r"""
@@ -241,7 +247,7 @@ class BasePlugin:
             videos: a list of video inputs, shape (num_videos,)
             imglens: number of images in each sample, shape (batch_size,)
             vidlens: number of videos in each sample, shape (batch_size,)
-            batch_ids: input ids of samples, shape (batch_size, seq_len)
+            batch_ids: token ids of input samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
         self._validate_input(images, videos)
@@ -259,13 +265,13 @@ class LlavaPlugin(BasePlugin):
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
         num_image_tokens = 0
-        image_seqlen = getattr(processor, "image_seqlen")
+        image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                num_image_tokens += 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                num_image_tokens += 1

             message["content"] = content.replace("{{image}}", self.image_token)
@@ -310,14 +316,16 @@ class LlavaNextPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                image_size = next(image_sizes)
-                orig_height, orig_width = image_size
-                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if getattr(processor, "vision_feature_select_strategy") == "default":
-                    image_seqlen -= 1
+                if self.expand_mm_tokens:
+                    orig_height, orig_width = next(image_sizes)
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if getattr(processor, "vision_feature_select_strategy") == "default":
+                        image_seqlen -= 1
+                else:
+                    image_seqlen = 1

-                num_image_tokens += 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                num_image_tokens += 1

             message["content"] = content.replace("{{image}}", self.image_token)
@@ -359,14 +367,16 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                image_size = next(image_sizes)
-                orig_height, orig_width = image_size
-                image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                if getattr(processor, "vision_feature_select_strategy") == "default":
-                    image_seqlen -= 1
+                if self.expand_mm_tokens:
+                    orig_height, orig_width = next(image_sizes)
+                    image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
+                    if getattr(processor, "vision_feature_select_strategy") == "default":
+                        image_seqlen -= 1
+                else:
+                    image_seqlen = 1

-                num_image_tokens += 1
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                num_image_tokens += 1

             message["content"] = content.replace("{{image}}", self.image_token)
@@ -376,6 +386,7 @@ class LlavaNextVideoPlugin(BasePlugin):
             num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
             video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
+            video_seqlen = video_seqlen if self.expand_mm_tokens else 1

         for message in messages:
             content = message["content"]
             while VIDEO_PLACEHOLDER in content:
@@ -406,7 +417,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, processor)


-class PaliGemmaPlugin(BasePlugin):
+class MiniCPMVPlugin(BasePlugin):
     @override
     def process_messages(
         self,
@@ -417,12 +428,241 @@ class PaliGemmaPlugin(BasePlugin):
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
         num_image_tokens = 0
+        num_video_tokens = 0
         messages = deepcopy(messages)
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        mm_inputs = {}
+        if len(images) != 0 and len(videos) != 0:
+            raise ValueError("MiniCPM-V model does not support input images and videos at the same time.")
+
+        if len(videos) != 0:
+            max_slice_nums = 2
+            use_image_id = False
+            mm_inputs = self._get_mm_inputs([], videos, processor)
+        else:
+            max_slice_nums = image_processor.max_slice_nums
+            use_image_id = image_processor.use_image_id
+
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
                 num_image_tokens += 1
+
+            while VIDEO_PLACEHOLDER in content:
+                video_seqlen = len(mm_inputs["pixel_values"][num_video_tokens]) if self.expand_mm_tokens else 1
+                content = content.replace(VIDEO_PLACEHOLDER, "{{image}}" * video_seqlen, 1)
+                num_video_tokens += 1
+
+            message["content"] = content.replace("{{image}}", "(<image>./</image>)")
+
+        if num_image_tokens > 0:
+            mm_inputs = self._get_mm_inputs(images, [], processor)
+
+        if mm_inputs:
+            pattern = "(<image>./</image>)"
+            image_sizes = mm_inputs["image_sizes"]
+            for index, message in enumerate(messages):
+                text = message["content"]
+                image_tags = re.findall(pattern, text)
+                text_chunks = text.split(pattern)
+                final_text = ""
+                for i in range(len(image_tags)):
+                    final_text = (
+                        final_text
+                        + text_chunks[i]
+                        + image_processor.get_slice_image_placeholder(
+                            image_sizes[0][i], i, max_slice_nums, use_image_id
+                        )
+                    )
+
+                final_text += text_chunks[-1]
+                messages[index]["content"] = final_text
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+        **kwargs,
+    ) -> Dict[str, "torch.Tensor"]:
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_resolution=getattr(processor, "image_resolution", 512 * 512),
+            )
+            if "valid_image_nums_ls" in kwargs:
+                valid_image_nums_ls = kwargs["valid_image_nums_ls"]
+                new_images = []
+                idx = 0
+                for valid_image_nums in valid_image_nums_ls:
+                    new_images.append(images[idx : idx + valid_image_nums])
+                    idx += valid_image_nums
+
+                images = new_images
+
+            image_inputs = image_processor(
+                images, do_pad=True, max_slice_nums=image_processor.max_slice_nums, return_tensors="pt"
+            )
+            mm_inputs.update(image_inputs)
+
+        if len(videos) != 0:
+            videos = self._regularize_videos(
+                videos,
+                image_resolution=getattr(processor, "video_resolution", 128 * 128),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 64),
+            )
+            video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
+            mm_inputs.update(video_inputs)
+
+        return mm_inputs
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        image_bounds_list = []
+        valid_image_nums_ls = []
+        for input_ids in batch_ids:
+            input_ids_ = torch.tensor(input_ids)
+            start_cond = (input_ids_ == processor.tokenizer.im_start_id) | (
+                input_ids_ == processor.tokenizer.slice_start_id
+            )
+            end_cond = (input_ids_ == processor.tokenizer.im_end_id) | (input_ids_ == processor.tokenizer.slice_end_id)
+            image_start_tokens = torch.where(start_cond)[0]
+            image_start_tokens += 1
+            image_end_tokens = torch.where(end_cond)[0]
+            valid_image_nums = max(len(image_start_tokens), len(image_end_tokens))
+            valid_image_nums_ls.append(valid_image_nums)
+            image_bounds = torch.hstack(
+                [
+                    image_start_tokens[:valid_image_nums].unsqueeze(-1),
+                    image_end_tokens[:valid_image_nums].unsqueeze(-1),
+                ]
+            )
+            image_bounds_list.append(image_bounds)
+
+        mm_inputs = self._get_mm_inputs(images, videos, processor, valid_image_nums_ls=valid_image_nums_ls)
+        mm_inputs.update({"image_bound": image_bounds_list})
+        return mm_inputs
+
+
+class MllamaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
+            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+        **kwargs,
+    ) -> Dict[str, "torch.Tensor"]:
+        r"""
+        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
+
+        Returns:
+            pixel_values: tensor with shape
+                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
+                          For example, (2, 1, 4, 3, 560, 560).
+            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
+            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
+            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
+        """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        imglens: List[int] = kwargs["imglens"]
+        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
+        batch_images = []
+        for image_length in imglens:
+            batch_images.append(images[:image_length])
+            images = images[image_length:]
+
+        return image_processor(batch_images, return_tensors="pt")
+
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        mm_inputs = self._get_mm_inputs(images, videos, processor, imglens=imglens)
+        num_tiles = mm_inputs.pop("num_tiles")
+        image_token_id = getattr(processor, "image_token_id")
+        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+        cross_attention_token_mask = [
+            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+        ]
+        mm_inputs["cross_attention_mask"] = torch.from_numpy(
+            convert_sparse_cross_attention_mask_to_dense(
+                cross_attention_token_mask,
+                num_tiles=num_tiles,
+                max_num_tiles=max_image_tiles,
+                length=max(len(input_ids) for input_ids in batch_ids),
+            )
+        )  # shape: (batch_size, length, max_num_images, max_num_tiles)
+        return mm_inputs
+
+
+class PaliGemmaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+                num_image_tokens += 1
+
+            message["content"] = content.replace("{{image}}", "")
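The `image_bound` computation pairs each image-start marker with its matching end marker so MiniCPM-V can splice vision embeddings into the right spans of the sequence. A self-contained sketch with hypothetical marker ids standing in for the tokenizer's im_start_id/im_end_id:

    import torch

    IM_START, IM_END = 101, 102  # hypothetical marker token ids
    input_ids_ = torch.tensor([7, IM_START, 5, 5, IM_END, 9, IM_START, 5, IM_END])

    image_start_tokens = torch.where(input_ids_ == IM_START)[0] + 1  # first token inside the span
    image_end_tokens = torch.where(input_ids_ == IM_END)[0]
    valid = max(len(image_start_tokens), len(image_end_tokens))
    image_bounds = torch.hstack(
        [image_start_tokens[:valid].unsqueeze(-1), image_end_tokens[:valid].unsqueeze(-1)]
    )
    print(image_bounds)  # tensor([[2, 4], [7, 8]]) -> one (start, end) row per image span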
@@ -443,7 +683,7 @@ class PaliGemmaPlugin(BasePlugin):
     ) -> Tuple[List[int], Optional[List[int]]]:
         self._validate_input(images, videos)
         num_images = len(images)
-        image_seqlen = num_images * getattr(processor, "image_seqlen")
+        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
         input_ids = [image_token_id] * image_seqlen + input_ids
         if labels is not None:
@@ -493,14 +733,18 @@ class PixtralPlugin(BasePlugin):
                 if image_input_sizes is None:
                     raise ValueError("Cannot get image input sizes.")

-                image_size = image_input_sizes[0][num_image_tokens]
-                height, width = image_size
-                num_height_tokens = height // patch_size
-                num_width_tokens = width // patch_size
-                replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
-                replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
-                replace_tokens[-1] = image_end_token
-                replace_str = "".join(replace_tokens)
+                if self.expand_mm_tokens:
+                    image_size = image_input_sizes[0][num_image_tokens]
+                    height, width = image_size
+                    num_height_tokens = height // patch_size
+                    num_width_tokens = width // patch_size
+                    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
+                    replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
+                    replace_tokens[-1] = image_end_token
+                    replace_str = "".join(replace_tokens)
+                else:
+                    replace_str = image_token

                 content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                 num_image_tokens += 1
@@ -549,10 +793,27 @@ class Qwen2vlPlugin(BasePlugin):
         return image

     @override
     def _get_video_sample_frames(self, video_stream: "Stream", **kwargs) -> int:
         sample_frames = super()._get_video_sample_frames(video_stream, **kwargs)
         sample_frames = sample_frames // 2 * 2
         return sample_frames

+    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
+        results = []
+        for video in videos:
+            container = av.open(video, "r")
+            video_stream = next(stream for stream in container.streams if stream.type == "video")
+            total_frames = video_stream.frames
+            sample_frames = self._get_video_sample_frames(video_stream, **kwargs)
+            sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
+            frames: List["ImageObject"] = []
+            container.seek(0)
+            for frame_idx, frame in enumerate(container.decode(video_stream)):
+                if frame_idx in sample_indices:
+                    frames.append(frame.to_image())
+
+            if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
+                frames.append(frames[-1])
+
+            frames = self._regularize_images(frames, **kwargs)
+            results.append(frames)
+
+        return results
+
     @override
     def process_messages(
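The sampler picks evenly spaced frame indices across the clip and pads to an even count because Qwen2-VL consumes frames in pairs. A sketch of just the index selection and padding logic:

    import numpy as np

    total_frames, sample_frames = 31, 5
    sample_indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
    print(sample_indices.tolist())  # [0, 7, 15, 22, 30]

    frames = [f"frame_{i}" for i in sample_indices]
    if len(frames) % 2 != 0:  # qwen2-vl requires an even number of frames
        frames.append(frames[-1])  # duplicate the last frame
    print(len(frames))  # 6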
@@ -577,12 +838,9 @@ class Qwen2vlPlugin(BasePlugin):
                 if num_image_tokens >= len(image_grid_thw):
                     raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

+                image_seqlen = image_grid_thw[num_image_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    IMAGE_PLACEHOLDER,
-                    "<|vision_start|>{}<|vision_end|>".format(
-                        self.image_token * (image_grid_thw[num_image_tokens].prod() // merge_length)
-                    ),
-                    1,
+                    IMAGE_PLACEHOLDER, f"<|vision_start|>{self.image_token * image_seqlen}<|vision_end|>", 1
                 )
                 num_image_tokens += 1
@@ -590,12 +848,9 @@ class Qwen2vlPlugin(BasePlugin):
                 if num_video_tokens >= len(video_grid_thw):
                     raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

+                video_seqlen = video_grid_thw[num_video_tokens].prod() // merge_length if self.expand_mm_tokens else 1
                 content = content.replace(
-                    VIDEO_PLACEHOLDER,
-                    "<|vision_start|>{}<|vision_end|>".format(
-                        self.video_token * (video_grid_thw[num_video_tokens].prod() // merge_length)
-                    ),
-                    1,
+                    VIDEO_PLACEHOLDER, f"<|vision_start|>{self.video_token * video_seqlen}<|vision_end|>", 1
                 )
                 num_video_tokens += 1
@@ -640,29 +895,32 @@ class VideoLlavaPlugin(BasePlugin):
         has_images = "pixel_values_images" in mm_inputs
         has_videos = "pixel_values_videos" in mm_inputs
         if has_images or has_videos:
-            if has_images:
-                height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
-                num_frames = 1
-
-            if has_videos:
-                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                height, width = get_image_size(pixel_values_video[0])
-                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
-
-            image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
-            video_seqlen = image_seqlen * num_frames
-            if getattr(processor, "vision_feature_select_strategy") == "default":
-                image_seqlen -= 1
+            if self.expand_mm_tokens:
+                if has_images:
+                    height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
+                    num_frames = 1
+
+                if has_videos:
+                    pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                    height, width = get_image_size(pixel_values_video[0])
+                    num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+
+                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+                video_seqlen = image_seqlen * num_frames
+                if getattr(processor, "vision_feature_select_strategy") == "default":
+                    image_seqlen -= 1
+            else:
+                image_seqlen, video_seqlen = 1, 1

             for message in messages:
                 content = message["content"]
                 while IMAGE_PLACEHOLDER in content:
-                    num_image_tokens += 1
                     content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                    num_image_tokens += 1

                 while VIDEO_PLACEHOLDER in content:
-                    num_video_tokens += 1
                     content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+                    num_video_tokens += 1

                 content = content.replace("{{image}}", self.image_token)
                 message["content"] = content.replace("{{video}}", self.video_token)
@@ -689,89 +947,17 @@ class VideoLlavaPlugin(BasePlugin):
         return self._get_mm_inputs(images, videos, processor)


-class MllamaPlugin(BasePlugin):
-    @override
-    def process_messages(
-        self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
-        self._validate_input(images, videos)
-        num_image_tokens = 0
-        messages = deepcopy(messages)
-        for message in messages:
-            content = message["content"]
-            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
-            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
-
-        if len(images) != num_image_tokens:
-            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
-
-        return messages
-
-    @override
-    def _get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        processor: "ProcessorMixin",
-    ) -> Dict[str, "torch.Tensor"]:
-        r"""
-        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
-
-        Returns:
-            pixel_values: tensor with shape
-                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
-                          For example, (2, 1, 4, 3, 560, 560).
-            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
-            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
-            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
-        """
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
-        return image_processor([[image] for image in images], return_tensors="pt")
-
-    def get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(images, videos)
-        if len(images) != len(batch_ids):
-            raise ValueError("Mllama only supports one image per sample.")
-
-        mm_inputs = self._get_mm_inputs(images, videos, processor)
-        num_tiles = mm_inputs.pop("num_tiles")
-        image_token_id = getattr(processor, "image_token_id")
-        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
-        cross_attention_token_mask = [
-            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
-        ]
-        mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
-            cross_attention_token_mask,
-            num_tiles=num_tiles,
-            max_num_tiles=max_image_tiles,
-            length=max(len(input_ids) for input_ids in batch_ids),
-        )
-        return mm_inputs
-
-
 PLUGINS = {
     "base": BasePlugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
+    "minicpm_v": MiniCPMVPlugin,
+    "mllama": MllamaPlugin,
     "paligemma": PaliGemmaPlugin,
     "pixtral": PixtralPlugin,
     "qwen2_vl": Qwen2vlPlugin,
     "video_llava": VideoLlavaPlugin,
-    "mllama": MllamaPlugin,
 }
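The PLUGINS mapping is the registry behind `get_mm_plugin`, which templates use to bind the right multimodal pre-processing to a model family. A hedged sketch of the lookup, assuming the factory simply indexes this dict and instantiates the plugin with the token arguments (the signature is inferred from the constructors shown above, not confirmed by this diff):

    from typing import Optional

    class BasePlugin:  # stand-in for the real class hierarchy
        def __init__(self, image_token: Optional[str], video_token: Optional[str]) -> None:
            self.image_token, self.video_token = image_token, video_token
            self.expand_mm_tokens = True

    PLUGINS = {"base": BasePlugin}

    def get_mm_plugin(name: str, image_token: Optional[str] = None, video_token: Optional[str] = None):
        if name not in PLUGINS:
            raise ValueError(f"Multimodal plugin `{name}` not found.")
        return PLUGINS[name](image_token, video_token)

    plugin = get_mm_plugin(name="base")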
src/llamafactory/data/preprocess.py

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple
 from .processors.feedback import preprocess_feedback_dataset
 from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
-from .processors.pretrain import preprocess_pretrain_dataset
+from .processors.pretrain import preprocess_pretrain_dataset, print_pretrain_dataset_example
 from .processors.supervised import (
     preprocess_packed_supervised_dataset,
     preprocess_supervised_dataset,

@@ -47,7 +47,7 @@ def get_preprocess_and_print_func(
             tokenizer=tokenizer,
             data_args=data_args,
         )
-        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
+        print_function = partial(print_pretrain_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not do_generate:
         if data_args.packing:
             if data_args.neat_packing:  # hack datasets to have int32 attention mask
src/llamafactory/data/processors/pretrain.py

@@ -52,3 +52,8 @@ def preprocess_pretrain_dataset(
             result["input_ids"][i][0] = tokenizer.bos_token_id

     return result
+
+
+def print_pretrain_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
+    print("input_ids:\n{}".format(example["input_ids"]))
+    print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
src/llamafactory/data/processors/unsupervised.py

@@ -100,3 +100,5 @@ def preprocess_unsupervised_dataset(
 def print_unsupervised_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
     print("input_ids:\n{}".format(example["input_ids"]))
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
+    print("label_ids:\n{}".format(example["labels"]))
+    print("labels:\n{}".format(tokenizer.decode(example["labels"], skip_special_tokens=False)))
src/llamafactory/data/template.py

@@ -15,10 +15,10 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union

-from transformers.utils.versions import require_version
 from typing_extensions import override

 from ..extras import logging
+from ..extras.misc import check_version
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin

@@ -30,6 +30,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments
     from .formatter import SLOTS, Formatter
     from .mm_plugin import BasePlugin
+    from .tool_utils import FunctionCall


 logger = logging.get_logger(__name__)
@@ -43,7 +44,6 @@ class Template:
     format_function: "Formatter"
     format_observation: "Formatter"
     format_tools: "Formatter"
-    format_separator: "Formatter"
     format_prefix: "Formatter"
     default_system: str
     stop_words: List[str]
@@ -83,12 +83,22 @@ class Template:
         encoded_messages = self._encode(tokenizer, messages, system, tools)
         return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

-    def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+    def extract_tool(self, content: str) -> Union[str, List["FunctionCall"]]:
         r"""
         Extracts tool message.
         """
         return self.format_tools.extract(content)

+    def get_stop_token_ids(self, tokenizer: "PreTrainedTokenizer") -> List[int]:
+        r"""
+        Returns stop token ids.
+        """
+        stop_token_ids = {tokenizer.eos_token_id}
+        for token in self.stop_words:
+            stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
+
+        return list(stop_token_ids)
+
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
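`get_stop_token_ids` is what both engines now call in place of hand-building `[eos_token_id] + additional_special_tokens_ids`: collecting the eos id and the template's stop words into a set dedupes ids that appear in both. A hedged usage sketch with a stub tokenizer (real callers pass a PreTrainedTokenizer):

    class StubTokenizer:  # minimal stand-in for illustration
        eos_token_id = 2

        def convert_tokens_to_ids(self, token: str) -> int:
            return {"<|im_end|>": 2, "<|im_start|>": 3}[token]

    stop_words = ["<|im_end|>", "<|im_start|>"]
    stop_token_ids = {StubTokenizer.eos_token_id}
    for token in stop_words:
        stop_token_ids.add(StubTokenizer().convert_tokens_to_ids(token))

    print(sorted(stop_token_ids))  # [2, 3] -- the duplicate eos id collapses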
@@ -112,9 +122,6 @@ class Template:
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 elements += self.format_system.apply(content=(system + tool_text))

-            if i > 0 and i % 2 == 0:
-                elements += self.format_separator.apply()
-
             if message["role"] == Role.USER.value:
                 elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
             elif message["role"] == Role.ASSISTANT.value:
@@ -179,9 +186,6 @@ class Llama2Template(Template):
                 tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                 system_text = self.format_system.apply(content=(system + tool_text))[0]

-            if i > 0 and i % 2 == 0:
-                elements += self.format_separator.apply()
-
             if message["role"] == Role.USER.value:
                 elements += self.format_user.apply(content=system_text + message["content"])
             elif message["role"] == Role.ASSISTANT.value:
@@ -209,13 +213,12 @@ def _register_template(
     format_function: Optional["Formatter"] = None,
     format_observation: Optional["Formatter"] = None,
     format_tools: Optional["Formatter"] = None,
-    format_separator: Optional["Formatter"] = None,
     format_prefix: Optional["Formatter"] = None,
     default_system: str = "",
-    stop_words: Sequence[str] = [],
+    stop_words: Optional[Sequence[str]] = None,
     efficient_eos: bool = False,
     replace_eos: bool = False,
-    replace_jinja_template: bool = True,
+    replace_jinja_template: bool = False,
     mm_plugin: "BasePlugin" = get_mm_plugin(name="base"),
 ) -> None:
     r"""
@@ -223,34 +226,28 @@ def _register_template(
     To add the following chat template:
     ```
-    [HUMAN]:
-    user prompt here
-    [AI]:
-    model response here
-
-    [HUMAN]:
-    user prompt here
-    [AI]:
-    model response here
+    <s><user>user prompt here
+    <model>model response here</s>
+    <user>user prompt here
+    <model>model response here</s>
     ```

     The corresponding code should be:
     ```
     _register_template(
         name="custom",
-        format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
-        format_separator=EmptyFormatter(slots=["\n\n"]),
-        efficient_eos=True,
+        format_user=StringFormatter(slots=["<user>{{content}}\n<model>"]),
+        format_assistant=StringFormatter(slots=["{{content}}</s>\n"]),
+        format_prefix=EmptyFormatter("<s>"),
     )
     ```
     """
-    eos_slots = [] if efficient_eos else [{"eos_token"}]
-    template_class = Llama2Template if name.startswith("llama2") else Template
+    template_class = Llama2Template if any(k in name for k in ("llama2", "mistral", "pixtral")) else Template
+    default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
     default_user_formatter = StringFormatter(slots=["{{content}}"])
-    default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
-    default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
+    default_assistant_formatter = StringFormatter(slots=default_slots)
+    default_function_formatter = FunctionFormatter(slots=default_slots, tool_format="default")
     default_tool_formatter = ToolFormatter(tool_format="default")
-    default_separator_formatter = EmptyFormatter()
     default_prefix_formatter = EmptyFormatter()
     TEMPLATES[name] = template_class(
         format_user=format_user or default_user_formatter,
@@ -259,10 +256,9 @@ def _register_template(
         format_function=format_function or default_function_formatter,
         format_observation=format_observation or format_user or default_user_formatter,
         format_tools=format_tools or default_tool_formatter,
-        format_separator=format_separator or default_separator_formatter,
         format_prefix=format_prefix or default_prefix_formatter,
         default_system=default_system,
-        stop_words=stop_words,
+        stop_words=stop_words or [],
         efficient_eos=efficient_eos,
         replace_eos=replace_eos,
         replace_jinja_template=replace_jinja_template,
@@ -343,9 +339,7 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
         jinja_template += "{{ " + user_message + " }}"

     jinja_template += "{% elif message['role'] == 'assistant' %}"
-    assistant_message = _convert_slots_to_jinja(
-        template.format_assistant.apply() + template.format_separator.apply(), tokenizer
-    )
+    assistant_message = _convert_slots_to_jinja(template.format_assistant.apply(), tokenizer)
     jinja_template += "{{ " + assistant_message + " }}"
     jinja_template += "{% endif %}"
     jinja_template += "{% endfor %}"
@@ -364,15 +358,15 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         raise ValueError(f"Template {data_args.template} does not exist.")

     if template.mm_plugin.__class__.__name__ != "BasePlugin":
-        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+        check_version("transformers>=4.45.0")

     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")

     if data_args.tool_format is not None:
         logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
-        eos_slots = [] if template.efficient_eos else [{"eos_token"}]
-        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
+        default_slots = ["{{content}}"] if template.efficient_eos else ["{{content}}", {"eos_token"}]
+        template.format_function = FunctionFormatter(slots=default_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)

     stop_words = template.stop_words
@@ -410,24 +404,24 @@ _register_template(
 _register_template(
     name="alpaca",
     format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
-    format_separator=EmptyFormatter(slots=["\n\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
     default_system=(
-        "Below is an instruction that describes a task. "
-        "Write a response that appropriately completes the request.\n\n"
+        "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
     ),
+    replace_jinja_template=True,
 )


 _register_template(
     name="aquila",
     format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
-    format_separator=EmptyFormatter(slots=["###"]),
+    format_assistant=StringFormatter(slots=["{{content}}###"]),
     format_system=StringFormatter(slots=["System: {{content}}###"]),
     default_system=(
         "A chat between a curious human and an artificial intelligence assistant. "
         "The assistant gives helpful, detailed, and polite answers to the human's questions."
     ),
     stop_words=["</s>"],
     efficient_eos=True,
 )
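The separator formatter is gone throughout this file: trailing separators now live directly in each template's assistant slots, so a turn renders by concatenating the slot pieces in order. A sketch of how the new alpaca assistant slots flatten into text (the eos token value is illustrative):

    slots = ["{{content}}", {"eos_token"}, "\n\n"]  # new format_assistant slots for alpaca
    content, eos_token = "It is 42.", "</s>"

    rendered = ""
    for slot in slots:
        if isinstance(slot, str):
            rendered += slot.replace("{{content}}", content)
        elif isinstance(slot, set) and "eos_token" in slot:  # set slots name special tokens
            rendered += eos_token

    print(repr(rendered))  # 'It is 42.</s>\n\n'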
@@ -457,7 +451,7 @@ _register_template(
 _register_template(
     name="belle",
     format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
-    format_separator=EmptyFormatter(slots=["\n\n"]),
+    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n\n"]),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
 )
@@ -479,7 +473,6 @@ _register_template(
 _register_template(
     name="chatglm2",
     format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
-    format_separator=EmptyFormatter(slots=["\n\n"]),
     format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
     efficient_eos=True,
 )
@@ -490,7 +483,7 @@ _register_template(
     format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
     format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
     format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
-    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+    format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
     format_observation=StringFormatter(
         slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
     ),
@@ -504,23 +497,26 @@ _register_template(
_register_template(
    name="chatml",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<|im_end|>", "<|im_start|>"],
    replace_eos=True,
    replace_jinja_template=True,
)


# copied from chatml template
_register_template(
    name="chatml_de",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
    stop_words=["<|im_end|>", "<|im_start|>"],
    replace_eos=True,
    replace_jinja_template=True,
)
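As a rough illustration of what these slot strings produce (plain string substitution for brevity; the real StringFormatter also expands special-token dicts), a chatml user turn renders like this:

slot = "<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"
print(slot.replace("{{content}}", "Hello!"))
# <|im_start|>user
# Hello!<|im_end|>
# <|im_start|>assistant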
...
...
@@ -534,7 +530,7 @@ _register_template(
    name="codegeex4",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
-   format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+   format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>\n"]),
    format_tools=ToolFormatter(tool_format="glm4"),
    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
...
...
@@ -569,21 +565,24 @@ _register_template(
)


# copied from chatml template
_register_template(
    name="cpm3",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|im_end|>"],
)


# copied from chatml template
_register_template(
    name="dbrx",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system=(
        "You are DBRX, created by Databricks. You were last updated in December 2023. "
        "You answer questions based on information available up to that point.\n"
...
...
@@ -600,7 +599,6 @@ _register_template(
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words
=
[
"<|im_end|>"
],
replace_eos
=
True
,
)
...
...
@@ -612,11 +610,17 @@ _register_template(
)


_register_template(
    name="deepseek3",
    format_user=StringFormatter(slots=["<|User|>{{content}}<|Assistant|>"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


_register_template(
    name="deepseekcoder",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
-   format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>"]),
-   format_separator=EmptyFormatter(slots=["\n"]),
+   format_assistant=StringFormatter(slots=["\n{{content}}\n<|EOT|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    default_system=(
        "You are an AI programming assistant, utilizing the DeepSeek Coder model, "
...
...
@@ -630,8 +634,8 @@ _register_template(
_register_template(
    name="default",
    format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
-   format_system=StringFormatter(slots=["{{content}}\n"]),
-   format_separator=EmptyFormatter(slots=["\n"]),
+   format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
+   format_system=StringFormatter(slots=["System: {{content}}\n"]),
)
...
...
@@ -644,22 +648,22 @@ _register_template(
_register_template(
    name="exaone",
    format_user=StringFormatter(slots=["[|user|]{{content}}\n[|assistant|]"]),
    format_assistant=StringFormatter(slots=["{{content}}", {"eos_token"}, "\n"]),
    format_system=StringFormatter(slots=["[|system|]{{content}}[|endofturn|]\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
)


_register_template(
    name="falcon",
    format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
-   format_separator=EmptyFormatter(slots=["\n"]),
+   format_assistant=StringFormatter(slots=["{{content}}\n"]),
    efficient_eos=True,
)


_register_template(
    name="fewshot",
-   format_separator=EmptyFormatter(slots=["\n\n"]),
+   format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
    efficient_eos=True,
)
...
...
@@ -667,13 +671,11 @@ _register_template(
_register_template(
    name="gemma",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    efficient_eos=True,
    replace_jinja_template=False,
)
...
...
@@ -682,7 +684,7 @@ _register_template(
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
    format_assistant=StringFormatter(slots=["\n{{content}}"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
-   format_function=FunctionFormatter(slots=[], tool_format="glm4"),
+   format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
    format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
    format_tools=ToolFormatter(tool_format="glm4"),
    format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
...
...
@@ -691,6 +693,18 @@ _register_template(
)


_register_template(
    name="granite3",
    format_user=StringFormatter(
        slots=["<|start_of_role|>user<|end_of_role|>{{content}}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|end_of_text|>\n"]),
    format_system=StringFormatter(slots=["<|start_of_role|>system<|end_of_role|>{{content}}<|end_of_text|>\n"]),
)


_register_template(
    name="index",
    format_user=StringFormatter(slots=["reserved_0{{content}}reserved_1"]),
...
...
@@ -702,22 +716,31 @@ _register_template(
_register_template(
    name="intern",
    format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
    format_assistant=StringFormatter(slots=["{{content}}<eoa>\n"]),
    format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
    format_separator=EmptyFormatter(slots=["<eoa>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<eoa>"],
    efficient_eos=True,  # internlm tokenizer cannot set eos_token_id
)


_register_template(
    name="intern2",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|im_end|>"],
)


# copied from intern2 template
_register_template(
    name="intern3",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|im_end|>"],
    efficient_eos=True,  # internlm2 tokenizer cannot set eos_token_id
)
...
...
@@ -728,6 +751,7 @@ _register_template(
)


# copied from llama2 template
_register_template(
    name="llama2_zh",
    format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
...
...
@@ -746,22 +770,24 @@ _register_template(
            )
        ]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
    format_observation=StringFormatter(
        slots=[
            (
-               "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+               "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-   stop_words=["<|eot_id|>"],
    replace_eos=True,
    replace_jinja_template=False,
+   stop_words=["<|eot_id|>", "<|eom_id|>"],
)


# copied from llama3 template
_register_template(
    name="mllama",
    format_user=StringFormatter(
...
...
@@ -772,23 +798,25 @@ _register_template(
            )
        ]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
    format_observation=StringFormatter(
        slots=[
            (
-               "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+               "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-   stop_words=["<|eot_id|>"],
    replace_eos=True,
    replace_jinja_template=False,
+   stop_words=["<|eot_id|>", "<|eom_id|>"],
    mm_plugin=get_mm_plugin(name="mllama", image_token="<|image|>"),
)


# copied from vicuna template
_register_template(
    name="llava",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
...
...
@@ -800,6 +828,7 @@ _register_template(
)


# copied from vicuna template
_register_template(
    name="llava_next",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
...
...
@@ -811,6 +840,7 @@ _register_template(
)


# copied from llama3 template
_register_template(
    name="llava_next_llama3",
    format_user=StringFormatter(
...
...
@@ -821,56 +851,67 @@ _register_template(
            )
        ]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
    format_observation=StringFormatter(
        slots=[
            (
-               "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
+               "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
-   stop_words=["<|eot_id|>"],
    replace_eos=True,
    replace_jinja_template=False,
+   stop_words=["<|eot_id|>", "<|eom_id|>"],
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)


# copied from mistral template
_register_template(
    name="llava_next_mistral",
-   format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+   format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)


# copied from chatml template
_register_template(
    name="llava_next_qwen",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-   format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+   format_observation=StringFormatter(
+       slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+   ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    replace_jinja_template=False,
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)


# copied from chatml template
_register_template(
    name="llava_next_yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="llava_next", image_token="<image>"),
)


# copied from vicuna template
_register_template(
    name="llava_next_video",
    format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
...
...
@@ -882,28 +923,66 @@ _register_template(
)


# copied from mistral template
_register_template(
    name="llava_next_video_mistral",
-   format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+   format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)


# copied from chatml template
_register_template(
    name="llava_next_video_yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<|im_end|>"],
    replace_eos=True,
    mm_plugin=get_mm_plugin(name="llava_next_video", image_token="<image>", video_token="<video>"),
)


# copied from chatml template
_register_template(
    name="marco",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    default_system=(
        "你是一个经过良好训练的AI助手,你的名字是Marco-o1.由阿里国际数字商业集团的AI Business创造.\n## 重要!!!!!\n"
        "当你回答问题时,你的思考应该在<Thought>内完成,<Output>内输出你的结果。\n"
        "<Thought>应该尽可能是英文,但是有2个特例,一个是对原文中的引用,另一个是是数学应该使用markdown格式,<Output>内的输出需要遵循用户输入的语言。\n"
    ),
    stop_words=["<|im_end|>"],
)


# copied from chatml template
_register_template(
    name="minicpm_v",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    stop_words=["<|im_end|>"],
    mm_plugin=get_mm_plugin(name="minicpm_v", image_token="<image>", video_token="<video>"),
)


_register_template(
    name="mistral",
-   format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+   format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]),
    format_assistant=StringFormatter(slots=[" {{content}}", {"eos_token"}]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_function=FunctionFormatter(slots=["[TOOL_CALLS] ", "{{content}}", {"eos_token"}], tool_format="mistral"),
    format_observation=StringFormatter(slots=["""[TOOL_RESULTS] {"content": {{content}}}[/TOOL_RESULTS]"""]),
    format_tools=ToolFormatter(tool_format="mistral"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)
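The recurring `format_user` change above drops the space before `[/INST]`; a quick standalone check of the rendered difference (plain substitution, illustration only):

old = "[INST] {{content}} [/INST]".replace("{{content}}", "Hi")
new = "[INST] {{content}}[/INST]".replace("{{content}}", "Hi")
assert old == "[INST] Hi [/INST]" and new == "[INST] Hi[/INST]"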
...
...
@@ -934,20 +1013,18 @@ _register_template(
    ),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|eot_id|>"],
    replace_eos=True,
)


# copied from chatml template
_register_template(
    name="opencoder",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system="You are OpenCoder, created by OpenCoder Team.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    replace_jinja_template=False,
)
...
...
@@ -958,15 +1035,15 @@ _register_template(
)


# copied from gemma template
_register_template(
    name="paligemma",
    format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<end_of_turn>\n"]),
    format_observation=StringFormatter(
        slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
    ),
    format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    efficient_eos=True,
    mm_plugin=get_mm_plugin(name="paligemma", image_token="<image>"),
)
...
...
@@ -974,56 +1051,71 @@ _register_template(
_register_template(
    name="phi",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    stop_words=["<|end|>"],
    replace_eos=True,
)


_register_template(
    name="phi_small",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    format_prefix=EmptyFormatter(slots=[{"<|endoftext|>"}]),
    stop_words=["<|end|>"],
    replace_eos=True,
)


_register_template(
    name="phi4",
    format_user=StringFormatter(
        slots=["<|im_start|>user<|im_sep|>{{content}}<|im_end|><|im_start|>assistant<|im_sep|>"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>"]),
    format_system=StringFormatter(slots=["<|im_start|>system<|im_sep|>{{content}}<|im_end|>"]),
    stop_words=["<|im_end|>"],
)


_register_template(
    name="pixtral",
-   format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
+   format_user=StringFormatter(slots=["[INST]{{content}}[/INST]"]),
    format_system=StringFormatter(slots=["{{content}}\n\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    mm_plugin=get_mm_plugin(name="pixtral", image_token="[IMG]"),
)


# copied from chatml template
_register_template(
    name="qwen",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-   format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+   format_observation=StringFormatter(
+       slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+   ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    replace_jinja_template=False,
)


# copied from chatml template
_register_template(
    name="qwen2_vl",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
-   format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    format_function=FunctionFormatter(slots=["{{content}}<|im_end|>\n"], tool_format="qwen"),
+   format_observation=StringFormatter(
+       slots=["<|im_start|>user\n<tool_response>\n{{content}}\n</tool_response><|im_end|>\n<|im_start|>assistant\n"]
+   ),
    format_tools=ToolFormatter(tool_format="qwen"),
    default_system="You are a helpful assistant.",
    stop_words=["<|im_end|>"],
    replace_eos=True,
    replace_jinja_template=False,
    mm_plugin=get_mm_plugin(name="qwen2_vl", image_token="<|image_pad|>", video_token="<|video_pad|>"),
)
...
...
@@ -1031,14 +1123,48 @@ _register_template(
_register_template(
    name="sailor",
    format_user=StringFormatter(slots=["<|im_start|>question\n{{content}}<|im_end|>\n<|im_start|>answer\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system=(
        "You are an AI assistant named Sailor created by Sea AI Lab. "
        "Your answer should be friendly, unbiased, faithful, informative and detailed."
    ),
    stop_words=["<|im_end|>"],
    replace_eos=True,
)


# copied from llama3 template
_register_template(
    name="skywork_o1",
    format_user=StringFormatter(
        slots=[
            (
                "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|>"]),
    format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
    format_function=FunctionFormatter(slots=["{{content}}<|eot_id|>"], tool_format="llama3"),
    format_observation=StringFormatter(
        slots=[
            (
                "<|start_header_id|>ipython<|end_header_id|>\n\n{{content}}<|eot_id|>"
                "<|start_header_id|>assistant<|end_header_id|>\n\n"
            )
        ]
    ),
    format_tools=ToolFormatter(tool_format="llama3"),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    default_system=(
        "You are Skywork-o1, a thinking model developed by Skywork AI, specializing in solving complex problems "
        "involving mathematics, coding, and logical reasoning through deep thought. When faced with a user's request, "
        "you first engage in a lengthy and in-depth thinking process to explore possible solutions to the problem. "
        "After completing your thoughts, you then provide a detailed explanation of the solution process "
        "in your response."
    ),
    stop_words=["<|eot_id|>", "<|eom_id|>"],
)
...
...
@@ -1053,10 +1179,9 @@ _register_template(
_register_template(
    name="starchat",
    format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
    format_assistant=StringFormatter(slots=["{{content}}<|end|>\n"]),
    format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<|end|>"],
    replace_eos=True,
)
...
...
@@ -1064,8 +1189,16 @@ _register_template(
    name="telechat",
    format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
    format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
    stop_words=["<_end>"],
    replace_eos=True,
)


_register_template(
    name="telechat2",
    format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
    format_system=StringFormatter(slots=["<_system>{{content}}"]),
    default_system=(
        "你是中国电信星辰语义大模型,英文名是TeleChat,你是由中电信人工智能科技有限公司和中国电信人工智能研究院(TeleAI)研发的人工智能助手。"
    ),
)
...
...
@@ -1076,6 +1209,7 @@ _register_template(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
replace_jinja_template
=
True
,
)
...
...
@@ -1110,8 +1244,8 @@ _register_template(
_register_template(
    name="yayi",
    format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
    format_assistant=StringFormatter(slots=["{{content}}\n\n"]),
    format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    default_system=(
        "You are a helpful, respectful and honest assistant named YaYi "
        "developed by Beijing Wenge Technology Co.,Ltd. "
...
...
@@ -1127,20 +1261,20 @@ _register_template(
)


# copied from chatml template
_register_template(
    name="yi",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_assistant=StringFormatter(slots=["{{content}}<|im_end|>\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<|im_end|>"],
    replace_eos=True,
)


_register_template(
    name="yi_vl",
    format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
-   format_separator=EmptyFormatter(slots=["\n"]),
+   format_assistant=StringFormatter(slots=["{{content}}\n"]),
    default_system=(
        "This is a chat between an inquisitive human and an AI assistant. "
        "Assume the role of the AI assistant. Read all the images carefully, "
...
...
@@ -1157,9 +1291,8 @@ _register_template(
_register_template(
    name="yuan",
    format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
-   format_separator=EmptyFormatter(slots=["\n"]),
+   format_assistant=StringFormatter(slots=["{{content}}<eod>\n"]),
    stop_words=["<eod>"],
    replace_eos=True,
)
...
...
@@ -1174,5 +1307,5 @@ _register_template(
_register_template(
    name="ziya",
    format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
-   format_separator=EmptyFormatter(slots=["\n"]),
+   format_assistant=StringFormatter(slots=["{{content}}\n"]),
)
src/llamafactory/data/tool_utils.py
View file @
8293100a
...
...
@@ -15,15 +15,20 @@
import json
import re
from abc import ABC, abstractmethod
-from collections import namedtuple
from dataclasses import dataclass
+from datetime import datetime
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any, Dict, List, NamedTuple, Tuple, Union

from typing_extensions import override

from .data_utils import SLOTS


+class FunctionCall(NamedTuple):
+    name: str
+    arguments: str


DEFAULT_TOOL_PROMPT = (
    "You have access to the following tools:\n{tool_text}"
    "Use the following format if using a tool:\n"
...
...
@@ -34,14 +39,25 @@ DEFAULT_TOOL_PROMPT = (
"```
\n
"
)
GLM4_TOOL_PROMPT
=
(
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)
LLAMA3_TOOL_PROMPT
=
(
"Cutting Knowledge Date: December 2023
\n
Today Date: {date}
\n\n
"
"You have access to the following functions. To call a function, please respond with JSON for a function call. "
"""Respond in the format {{"name": function name, "parameters": dictionary of argument name and its value}}. """
"Do not use variables.
\n\n
{tool_text}"
)
FunctionCall
=
namedtuple
(
"FunctionCall"
,
[
"name"
,
"arguments"
])
QWEN_TOOL_PROMPT
=
(
"
\n\n
# Tools
\n\n
You may call one or more functions to assist with the user query.
\n\n
"
"You are provided with function signatures within <tools></tools> XML tags:
\n
<tools>{tool_text}"
"
\n
</tools>
\n\n
For each function call, return a json object with function name and arguments within "
"""<tool_call></tool_call> XML tags:
\n
<tool_call>
\n
{{"name": <function-name>, """
""""arguments": <args-json-object>}}
\n
</tool_call><|im_end|>
\n
"""
)
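The typed NamedTuple keeps the tuple behavior of the removed namedtuple, so existing `for name, arguments in functions` loops keep working; a small standalone check (the example values are ours):

import json
from typing import NamedTuple

class FunctionCall(NamedTuple):
    name: str
    arguments: str

call = FunctionCall("get_weather", json.dumps({"city": "Berlin"}))
name, arguments = call  # tuple unpacking, exactly as with the old namedtuple
assert name == call.name == "get_weather"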
@dataclass
...
...
@@ -52,17 +68,17 @@ class ToolUtils(ABC):
    @staticmethod
    @abstractmethod
-   def get_function_slots() -> SLOTS:
+   def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        r"""
-       Gets a list of slots corresponding to a single function call.
+       Generates the system message describing all the available tools.
        """
        ...

    @staticmethod
    @abstractmethod
-   def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+   def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
        r"""
-       Generates the system message describing all the available tools.
+       Generates the assistant message including all the tool calls.
        """
        ...
...
...
@@ -70,16 +86,17 @@ class ToolUtils(ABC):
    @abstractmethod
    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
        r"""
-       Extracts all the function calls from the response message.
+       Extracts all the function calls from the assistant message.
+       It should be an inverse function of `function_formatter`.
        """
        ...


class DefaultToolUtils(ToolUtils):
-   @override
-   @staticmethod
-   def get_function_slots() -> SLOTS:
-       return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
+   r"""
+   Default tool using template.
+   """

    @override
    @staticmethod
...
...
@@ -115,6 +132,15 @@ class DefaultToolUtils(ToolUtils):
        return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))

+   @override
+   @staticmethod
+   def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+       function_text = ""
+       for name, arguments in functions:
+           function_text += f"Action: {name}\nAction Input: {arguments}\n"
+
+       return [function_text]

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
...
...
@@ -129,7 +155,7 @@ class DefaultToolUtils(ToolUtils):
            tool_input = match[1].strip().strip('"').strip("```")
            try:
                arguments = json.loads(tool_input)
-               results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
+               results.append(FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False)))
            except json.JSONDecodeError:
                return content
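A sketch of the round trip this class now guarantees (tool name and arguments are invented): `function_formatter` writes the "Action / Action Input" lines, and `tool_extractor` is meant to parse them back into `FunctionCall` tuples.

import json

calls = [("get_weather", json.dumps({"city": "Berlin"}, ensure_ascii=False))]
function_text = ""
for name, arguments in calls:
    function_text += f"Action: {name}\nAction Input: {arguments}\n"
# function_text == 'Action: get_weather\nAction Input: {"city": "Berlin"}\n'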
...
...
@@ -137,10 +163,9 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils):
-   @override
-   @staticmethod
-   def get_function_slots() -> SLOTS:
-       return ["{{name}}\n{{arguments}}"]
+   r"""
+   GLM-4 tool using template.
+   """

    @override
    @staticmethod
...
...
@@ -153,6 +178,14 @@ class GLM4ToolUtils(ToolUtils):
        return GLM4_TOOL_PROMPT.format(tool_text=tool_text)

+   @override
+   @staticmethod
+   def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
+       if len(functions) > 1:
+           raise ValueError("GLM-4 does not support parallel functions.")
+
+       return [f"{functions[0].name}\n{functions[0].arguments}"]

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
...
...
@@ -161,16 +194,152 @@ class GLM4ToolUtils(ToolUtils):
        tool_name, tool_input = content.split("\n", maxsplit=1)
        try:
-           arguments = json.loads(tool_input)
+           arguments = json.loads(tool_input.strip())
        except json.JSONDecodeError:
            return content

-       return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
+       return [FunctionCall(tool_name, json.dumps(arguments, ensure_ascii=False))]


class Llama3ToolUtils(ToolUtils):
    r"""
    Llama 3.x tool using template with `tools_in_user_message=False`.
    Reference: https://www.llama.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling
    """

    @override
    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        date = datetime.now().strftime("%d %b %Y")
        tool_text = ""
        for tool in tools:
            wrapped_tool = {"type": "function", "function": tool}
            tool_text += json.dumps(wrapped_tool, indent=4, ensure_ascii=False) + "\n\n"

        return LLAMA3_TOOL_PROMPT.format(date=date, tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
        if len(functions) > 1:
            raise ValueError("Llama-3 does not support parallel functions.")

        return [f'{{"name": "{functions[0].name}", "parameters": {functions[0].arguments}}}']

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
        try:
            tool = json.loads(content.strip())
        except json.JSONDecodeError:
            return content

        if "name" not in tool or "parameters" not in tool:
            return content

        return [FunctionCall(tool["name"], json.dumps(tool["parameters"], ensure_ascii=False))]
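Because `tool_extractor` is the inverse of `function_formatter`, the Llama-3 JSON message parses straight back; a standalone illustration (example values are ours):

import json

name, parameters = "get_weather", {"city": "Berlin"}
message = f'{{"name": "{name}", "parameters": {json.dumps(parameters)}}}'
tool = json.loads(message)
assert tool["name"] == name and tool["parameters"] == parameters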
class MistralToolUtils(ToolUtils):
    r"""
    Mistral v0.3 tool using template.
    """

    @override
    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        wrapped_tools = []
        for tool in tools:
            wrapped_tools.append({"type": "function", "function": tool})

        return "[AVAILABLE_TOOLS] " + json.dumps(wrapped_tools, ensure_ascii=False) + "[/AVAILABLE_TOOLS]"

    @override
    @staticmethod
    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
        function_texts = []
        for name, arguments in functions:
            function_texts.append(f'{{"name": "{name}", "arguments": {arguments}}}')

        return ["[" + ", ".join(function_texts) + "]"]

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
        try:
            tools = json.loads(content.strip())
        except json.JSONDecodeError:
            return content

        if not isinstance(tools, list):
            tools = [tools]

        results = []
        for tool in tools:
            if "name" not in tool or "arguments" not in tool:
                return content

            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))

        return results
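Unlike GLM-4 and Llama-3 above, the Mistral format allows parallel calls by packing them into one JSON list; a standalone illustration (tool names are ours):

import json

calls = [("get_weather", '{"city": "Berlin"}'), ("get_time", '{"tz": "UTC"}')]
texts = [f'{{"name": "{name}", "arguments": {arguments}}}' for name, arguments in calls]
payload = "[" + ", ".join(texts) + "]"
assert [t["name"] for t in json.loads(payload)] == ["get_weather", "get_time"]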
class QwenToolUtils(ToolUtils):
    r"""
    Qwen 2.5 tool using template.
    """

    @override
    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        tool_text = ""
        for tool in tools:
            wrapped_tool = {"type": "function", "function": tool}
            tool_text += "\n" + json.dumps(wrapped_tool, ensure_ascii=False)

        return QWEN_TOOL_PROMPT.format(tool_text=tool_text)

    @override
    @staticmethod
    def function_formatter(functions: List["FunctionCall"]) -> SLOTS:
        function_texts = []
        for name, arguments in functions:
            function_texts.append(
                "<tool_call>\n" + f'{{"name": "{name}", "arguments": {arguments}}}' + "\n</tool_call>"
            )

        return ["\n".join(function_texts)]

    @override
    @staticmethod
    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
        regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
        tool_match: List[str] = re.findall(regex, content)
        if not tool_match:
            return content

        results = []
        for tool in tool_match:
            try:
                tool = json.loads(tool.strip())
            except json.JSONDecodeError:
                return content

            if "name" not in tool or "arguments" not in tool:
                return content

            results.append(FunctionCall(tool["name"], json.dumps(tool["arguments"], ensure_ascii=False)))

        return results
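The lookahead `(?=\s*<tool_call>|\s*$)` in the regex above is what lets several adjacent `<tool_call>` blocks match one by one; a quick standalone check (the payloads are ours):

import json
import re

content = (
    '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Berlin"}}\n</tool_call>\n'
    '<tool_call>\n{"name": "get_time", "arguments": {"tz": "UTC"}}\n</tool_call>'
)
regex = re.compile(r"<tool_call>(.+?)</tool_call>(?=\s*<tool_call>|\s*$)", re.DOTALL)
names = [json.loads(m.strip())["name"] for m in re.findall(regex, content)]
assert names == ["get_weather", "get_time"]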
TOOLS = {
    "default": DefaultToolUtils(),
    "glm4": GLM4ToolUtils(),
    "llama3": Llama3ToolUtils(),
    "mistral": MistralToolUtils(),
    "qwen": QwenToolUtils(),
}
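Callers presumably resolve a utils class through this registry; a minimal lookup sketch (the helper name is an assumption, not shown in this diff):

def get_tool_utils(name: str) -> "ToolUtils":
    if name not in TOOLS:
        raise ValueError(f"Tool utils `{name}` not found.")  # helper name assumed
    return TOOLS[name]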
...
...
src/llamafactory/eval/evaluator.py
View file @
8293100a
...
...
@@ -100,7 +100,7 @@ class Evaluator:
            cache_dir=self.model_args.cache_dir,
            download_mode=self.eval_args.download_mode,
            token=self.model_args.hf_hub_token,
-           trust_remote_code=True,
+           trust_remote_code=self.model_args.trust_remote_code,
        )
        pbar.set_postfix_str(categorys[subject]["name"])
        inputs, outputs, labels = [], [], []
...
...
src/llamafactory/extras/constants.py
View file @
8293100a
...
...
@@ -81,19 +81,6 @@ TRAINING_STAGES = {
STAGES_USE_PAIR_DATA = {"rm", "dpo"}

-SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN = {
-    "cohere",
-    "falcon",
-    "gemma",
-    "gemma2",
-    "llama",
-    "mistral",
-    "phi",
-    "phi3",
-    "qwen2",
-    "starcoder2",
-}

SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

VIDEO_PLACEHOLDER = os.environ.get("VIDEO_PLACEHOLDER", "<video>")
...
...
@@ -118,7 +105,7 @@ def register_model_group(
) -> None:
    for name, path in models.items():
        SUPPORTED_MODELS[name] = path
-       if template is not None and any(suffix in name for suffix in ("-Chat", "-Instruct")):
+       if template is not None and (
+           any(suffix in name for suffix in ("-Chat", "-Instruct")) or vision
+       ):
            DEFAULT_TEMPLATE[name] = template
        if vision:
            VISION_MODELS.add(name)
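The widened condition above means a vision model now gets its default template even without a "-Chat"/"-Instruct" suffix; a hypothetical registration showing the effect (model name and paths are invented):

register_model_group(
    models={
        "Example-VL-7B": {  # hypothetical entry: no "-Chat"/"-Instruct" suffix
            DownloadSource.DEFAULT: "example-org/Example-VL-7B",
        },
    },
    template="qwen2_vl",
    vision=True,  # with the new rule, DEFAULT_TEMPLATE["Example-VL-7B"] = "qwen2_vl"
)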
...
...
@@ -338,6 +325,7 @@ register_model_group(
models
=
{
"Codestral-22B-v0.1-Chat"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Codestral-22B-v0.1"
,
DownloadSource
.
MODELSCOPE
:
"swift/Codestral-22B-v0.1"
,
},
},
template
=
"mistral"
,
...
...
@@ -433,15 +421,19 @@ register_model_group(
        },
        "DeepSeek-Coder-V2-16B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
        },
        "DeepSeek-Coder-V2-236B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Base",
        },
        "DeepSeek-Coder-V2-16B-Instruct": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
        },
        "DeepSeek-Coder-V2-236B-Instruct": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
        },
    },
    template="deepseek",
...
...
@@ -456,6 +448,7 @@ register_model_group(
        },
        "DeepSeek-Coder-7B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-7b-base-v1.5",
        },
        "DeepSeek-Coder-33B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
...
...
@@ -467,6 +460,7 @@ register_model_group(
        },
        "DeepSeek-Coder-7B-Instruct": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
            DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
        },
        "DeepSeek-Coder-33B-Instruct": {
            DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
...
...
@@ -477,6 +471,33 @@ register_model_group(
)


register_model_group(
    models={
        "DeepSeek-V2-236B-Chat-0628": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat-0628",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat-0628",
        },
        "DeepSeek-V2.5-236B-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5",
        },
        "DeepSeek-V2.5-236B-Chat-1210": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2.5-1210",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2.5-1210",
        },
        "DeepSeek-V3-685B-Base": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3-Base",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3-Base",
        },
        "DeepSeek-V3-685B-Chat": {
            DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V3",
            DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V3",
        },
    },
    template="deepseek3",
)


register_model_group(
    models={
        "EXAONE-3.0-7.8B-Instruct": {
...
...
@@ -495,6 +516,7 @@ register_model_group(
        },
        "Falcon-11B": {
            DownloadSource.DEFAULT: "tiiuae/falcon-11B",
            DownloadSource.MODELSCOPE: "tiiuae/falcon-11B",
        },
        "Falcon-40B": {
            DownloadSource.DEFAULT: "tiiuae/falcon-40b",
...
...
@@ -598,14 +620,99 @@ register_model_group(
register_model_group(
    models={
-       "Index-1.9B-Chat": {
-           DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
-           DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
+       "GPT-2-Small": {
+           DownloadSource.DEFAULT: "openai-community/gpt2",
+           DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2",
        },
-       "Index-1.9B-Character-Chat": {
-           DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
-           DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
+       "GPT-2-Medium": {
+           DownloadSource.DEFAULT: "openai-community/gpt2-medium",
+           DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2-medium",
        },
+       "GPT-2-Large": {
+           DownloadSource.DEFAULT: "openai-community/gpt2-large",
+           DownloadSource.MODELSCOPE: "AI-ModelScope/gpt2-large",
+       },
+       "GPT-2-XL": {
+           DownloadSource.DEFAULT: "openai-community/gpt2-xl",
+           DownloadSource.MODELSCOPE: "goodbai95/GPT2-xl",
+       },
    },
)


register_model_group(
    models={
        "Granite-3.0-1B-A400M-Base": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.0-1b-a400m-base",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-1b-a400m-base",
        },
        "Granite-3.0-3B-A800M-Base": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.0-3b-a800m-base",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-3b-a800m-base",
        },
        "Granite-3.0-2B-Base": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.0-2b-base",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-2b-base",
        },
        "Granite-3.0-8B-Base": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.0-8b-base",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-8b-base",
        },
        "Granite-3.0-1B-A400M-Instruct": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.0-1b-a400m-instruct",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-1b-a400m-instruct",
        },
        "Granite-3.0-3B-A800M-Instruct": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.0-3b-a800m-instruct",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-3b-a800m-instruct",
        },
        "Granite-3.0-2B-Instruct": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.0-2b-instruct",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-2b-instruct",
        },
        "Granite-3.0-8B-Instruct": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.0-8b-instruct",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.0-8b-instruct",
        },
        "Granite-3.1-1B-A400M-Base": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.1-1b-a400m-base",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-1b-a400m-base",
        },
        "Granite-3.1-3B-A800M-Base": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.1-3b-a800m-base",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-3b-a800m-base",
        },
        "Granite-3.1-2B-Base": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.1-2b-base",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-2b-base",
        },
        "Granite-3.1-8B-Base": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-base",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-base",
        },
        "Granite-3.1-1B-A400M-Instruct": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.1-1b-a400m-instruct",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-1b-a400m-instruct",
        },
        "Granite-3.1-3B-A800M-Instruct": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.1-3b-a800m-instruct",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-3b-a800m-instruct",
        },
        "Granite-3.1-2B-Instruct": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.1-2b-instruct",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-2b-instruct",
        },
        "Granite-3.1-8B-Instruct": {
            DownloadSource.DEFAULT: "ibm-granite/granite-3.1-8b-instruct",
            DownloadSource.MODELSCOPE: "AI-ModelScope/granite-3.1-8b-instruct",
        },
    },
    template="granite3",
)


register_model_group(
    models={
        "Index-1.9B-Base": {
            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B",
            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B",
...
...
@@ -614,6 +721,14 @@ register_model_group(
            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Pure",
            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Pure",
        },
+       "Index-1.9B-Chat": {
+           DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Chat",
+           DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Chat",
+       },
+       "Index-1.9B-Character-Chat": {
+           DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-Character",
+           DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-Character",
+       },
        "Index-1.9B-Chat-32K": {
            DownloadSource.DEFAULT: "IndexTeam/Index-1.9B-32K",
            DownloadSource.MODELSCOPE: "IndexTeam/Index-1.9B-32K",
...
...
@@ -702,6 +817,15 @@ register_model_group(
    template="intern2",
)


register_model_group(
    models={
        "InternLM3-8B-Chat": {
            DownloadSource.DEFAULT: "internlm/internlm3-8b-instruct",
            DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm3-8b-instruct",
        },
    },
    template="intern3",
)


register_model_group(
    models={
...
...
@@ -850,6 +974,10 @@ register_model_group(
            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-3B-Instruct",
            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-3B-Instruct",
        },
+       "Llama-3.3-70B-Instruct": {
+           DownloadSource.DEFAULT: "meta-llama/Llama-3.3-70B-Instruct",
+           DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.3-70B-Instruct",
+       },
    },
    template="llama3",
)
...
...
@@ -857,10 +985,18 @@ register_model_group(
register_model_group(
    models={
+       "Llama-3.2-11B-Vision": {
+           DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision",
+           DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision",
+       },
        "Llama-3.2-11B-Vision-Instruct": {
            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-11B-Vision-Instruct",
            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-11B-Vision-Instruct",
        },
+       "Llama-3.2-90B-Vision": {
+           DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision",
+           DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision",
+       },
        "Llama-3.2-90B-Vision-Instruct": {
            DownloadSource.DEFAULT: "meta-llama/Llama-3.2-90B-Vision-Instruct",
            DownloadSource.MODELSCOPE: "LLM-Research/Llama-3.2-90B-Vision-Instruct",
...
...
@@ -998,6 +1134,17 @@ register_model_group(
)


register_model_group(
    models={
        "Marco-o1-Chat": {
            DownloadSource.DEFAULT: "AIDC-AI/Marco-o1",
            DownloadSource.MODELSCOPE: "AIDC-AI/Marco-o1",
        },
    },
    template="marco",
)


register_model_group(
    models={
        "MiniCPM-2B-SFT-Chat": {
...
...
@@ -1025,6 +1172,28 @@ register_model_group(
)


register_model_group(
    models={
        "MiniCPM-o-2_6-Chat": {
            DownloadSource.DEFAULT: "openbmb/MiniCPM-o-2_6",
            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-o-2_6",
        },
    },
    template="minicpm_v",
)


register_model_group(
    models={
        "MiniCPM-V-2_6-Chat": {
            DownloadSource.DEFAULT: "openbmb/MiniCPM-V-2_6",
            DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-V-2_6",
        },
    },
    template="minicpm_v",
)


register_model_group(
    models={
        "Mistral-7B-v0.1": {
...
...
@@ -1173,23 +1342,23 @@ register_model_group(
register_model_group(
    models={
-       "PaliGemma-3B-pt-224-Chat": {
+       "PaliGemma-3B-pt-224": {
            DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
        },
-       "PaliGemma-3B-pt-448-Chat": {
+       "PaliGemma-3B-pt-448": {
            DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
        },
-       "PaliGemma-3B-pt-896-Chat": {
+       "PaliGemma-3B-pt-896": {
            DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
        },
-       "PaliGemma-3B-mix-224-Chat": {
+       "PaliGemma-3B-mix-224": {
            DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
        },
-       "PaliGemma-3B-mix-448-Chat": {
+       "PaliGemma-3B-mix-448": {
            DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
        },
...
...
@@ -1199,6 +1368,50 @@ register_model_group(
)


register_model_group(
    models={
        "PaliGemma2-3B-pt-224": {
            DownloadSource.DEFAULT: "google/paligemma2-3b-pt-224",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-3b-pt-224",
        },
        "PaliGemma2-3B-pt-448": {
            DownloadSource.DEFAULT: "google/paligemma2-3b-pt-448",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-3b-pt-448",
        },
        "PaliGemma2-3B-pt-896": {
            DownloadSource.DEFAULT: "google/paligemma2-3b-pt-896",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-3b-pt-896",
        },
        "PaliGemma2-10B-pt-224": {
            DownloadSource.DEFAULT: "google/paligemma2-10b-pt-224",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-10b-pt-224",
        },
        "PaliGemma2-10B-pt-448": {
            DownloadSource.DEFAULT: "google/paligemma2-10b-pt-448",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-10b-pt-448",
        },
        "PaliGemma2-10B-pt-896": {
            DownloadSource.DEFAULT: "google/paligemma2-10b-pt-896",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-10b-pt-896",
        },
        "PaliGemma2-28B-pt-224": {
            DownloadSource.DEFAULT: "google/paligemma2-28b-pt-224",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-28b-pt-224",
        },
        "PaliGemma2-28B-pt-448": {
            DownloadSource.DEFAULT: "google/paligemma2-28b-pt-448",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-28b-pt-448",
        },
        "PaliGemma2-28B-pt-896": {
            DownloadSource.DEFAULT: "google/paligemma2-28b-pt-896",
            DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma2-28b-pt-896",
        },
    },
    template="paligemma",
    vision=True,
)


register_model_group(
    models={
        "Phi-1.5-1.3B": {
...
...
@@ -1231,6 +1444,14 @@ register_model_group(
            DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
            DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
        },
+       "Phi-3.5-4B-instruct": {
+           DownloadSource.DEFAULT: "microsoft/Phi-3.5-mini-instruct",
+           DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-mini-instruct",
+       },
+       "Phi-3.5-MoE-42B-A6.6B-instruct": {
+           DownloadSource.DEFAULT: "microsoft/Phi-3.5-MoE-instruct",
+           DownloadSource.MODELSCOPE: "LLM-Research/Phi-3.5-MoE-instruct",
+       },
    },
    template="phi",
)
...
...
@@ -1253,7 +1474,18 @@ register_model_group(
register_model_group(
    models={
-       "Pixtral-12B-Chat": {
+       "Phi-4-14B-Instruct": {
+           DownloadSource.DEFAULT: "microsoft/phi-4",
+           DownloadSource.MODELSCOPE: "LLM-Research/phi-4",
+       },
+   },
+   template="phi4",
+)
+
+
+register_model_group(
+   models={
+       "Pixtral-12B-Instruct": {
            DownloadSource.DEFAULT: "mistral-community/pixtral-12b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/pixtral-12b",
        }
...
...
@@ -1267,67 +1499,67 @@ register_model_group(
    models={
        "Qwen-1.8B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B",
        },
        "Qwen-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-7B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-7B",
        },
        "Qwen-14B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-14B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-14B",
        },
        "Qwen-72B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-72B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-72B",
        },
        "Qwen-1.8B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat",
        },
        "Qwen-7B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat",
        },
        "Qwen-14B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat",
        },
        "Qwen-72B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat",
        },
        "Qwen-1.8B-Chat-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int8",
        },
        "Qwen-1.8B-Chat-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-1_8B-Chat-Int4",
        },
        "Qwen-7B-Chat-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int8",
        },
        "Qwen-7B-Chat-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-7B-Chat-Int4",
        },
        "Qwen-14B-Chat-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int8",
        },
        "Qwen-14B-Chat-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-14B-Chat-Int4",
        },
        "Qwen-72B-Chat-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int8",
        },
        "Qwen-72B-Chat-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen-72B-Chat-Int4",
        },
    },
    template="qwen",
...
...
@@ -1338,147 +1570,147 @@ register_model_group(
    models={
        "Qwen1.5-0.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B",
        },
        "Qwen1.5-1.8B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B",
        },
        "Qwen1.5-4B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B",
        },
        "Qwen1.5-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B",
        },
        "Qwen1.5-14B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B",
        },
        "Qwen1.5-32B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B",
        },
        "Qwen1.5-72B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B",
        },
        "Qwen1.5-110B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B",
        },
        "Qwen1.5-MoE-A2.7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B",
        },
        "Qwen1.5-0.5B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat",
        },
        "Qwen1.5-1.8B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat",
        },
        "Qwen1.5-4B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat",
        },
        "Qwen1.5-7B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat",
        },
        "Qwen1.5-14B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat",
        },
        "Qwen1.5-32B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat",
        },
        "Qwen1.5-72B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat",
        },
        "Qwen1.5-110B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat",
        },
        "Qwen1.5-MoE-A2.7B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
        },
        "Qwen1.5-0.5B-Chat-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-0.5B-Chat-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
        },
        "Qwen1.5-1.8B-Chat-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-1.8B-Chat-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
        },
        "Qwen1.5-4B-Chat-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-4B-Chat-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-4B-Chat-AWQ",
        },
        "Qwen1.5-7B-Chat-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-7B-Chat-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-7B-Chat-AWQ",
        },
        "Qwen1.5-14B-Chat-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-14B-Chat-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-14B-Chat-AWQ",
        },
        "Qwen1.5-32B-Chat-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-32B-Chat-AWQ",
        },
        "Qwen1.5-72B-Chat-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-72B-Chat-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-72B-Chat-AWQ",
        },
        "Qwen1.5-110B-Chat-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-110B-Chat-AWQ",
        },
        "Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
        },
        "CodeQwen1.5-7B": {
            DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
-           DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
+           DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B",
        },
        "CodeQwen1.5-7B-Chat": {
            DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
-           DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
+           DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat",
        },
        "CodeQwen1.5-7B-Chat-AWQ": {
            DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
        },
    },
    template="qwen",
...
@@ -1489,122 +1721,122 @@ register_model_group(
    models={
        "Qwen2-0.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B",
        },
        "Qwen2-1.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B",
        },
        "Qwen2-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-7B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-7B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B",
        },
        "Qwen2-72B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-72B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B",
        },
        "Qwen2-MoE-57B-A14B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-57B-A14B",
        },
        "Qwen2-0.5B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct",
            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-0.5B-Instruct",
        },
        "Qwen2-1.5B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct",
            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-1.5B-Instruct",
        },
        "Qwen2-7B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct",
            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-7B-Instruct",
        },
        "Qwen2-72B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct",
        },
        "Qwen2-MoE-57B-A14B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-57B-A14B-Instruct",
        },
        "Qwen2-0.5B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
        },
        "Qwen2-0.5B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int4",
        },
        "Qwen2-0.5B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-0.5B-Instruct-AWQ",
        },
        "Qwen2-1.5B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
        },
        "Qwen2-1.5B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int4",
        },
        "Qwen2-1.5B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-1.5B-Instruct-AWQ",
        },
        "Qwen2-7B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
        },
        "Qwen2-7B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct-GPTQ-Int4",
        },
        "Qwen2-7B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-7B-Instruct-AWQ",
        },
        "Qwen2-72B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
        },
        "Qwen2-72B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct-GPTQ-Int4",
        },
        "Qwen2-72B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-72B-Instruct-AWQ",
        },
        "Qwen2-57B-A14B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
        },
        "Qwen2-Math-1.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-1.5B",
        },
        "Qwen2-Math-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-7B",
        },
        "Qwen2-Math-72B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-72B",
        },
        "Qwen2-Math-1.5B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-1.5B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-1.5B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-1.5B-Instruct",
        },
        "Qwen2-Math-7B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-7B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-7B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-7B-Instruct",
        },
        "Qwen2-Math-72B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-Math-72B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-Math-72B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-Math-72B-Instruct",
        },
    },
    template="qwen",
...
@@ -1615,215 +1847,219 @@ register_model_group(
    models={
        "Qwen2.5-0.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B",
        },
        "Qwen2.5-1.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B",
        },
        "Qwen2.5-3B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B",
        },
        "Qwen2.5-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B",
        },
        "Qwen2.5-14B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B",
        },
        "Qwen2.5-32B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B",
        },
        "Qwen2.5-72B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B",
        },
        "Qwen2.5-0.5B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct",
        },
        "Qwen2.5-1.5B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct",
        },
        "Qwen2.5-3B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct",
        },
        "Qwen2.5-7B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct",
        },
        "Qwen2.5-14B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct",
        },
        "Qwen2.5-32B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct",
        },
        "Qwen2.5-72B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct",
        },
        "Qwen2.5-0.5B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int8",
        },
        "Qwen2.5-0.5B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4",
        },
        "Qwen2.5-0.5B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-0.5B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-0.5B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-0.5B-Instruct-AWQ",
        },
        "Qwen2.5-1.5B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int8",
        },
        "Qwen2.5-1.5B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4",
        },
        "Qwen2.5-1.5B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-1.5B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-1.5B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-1.5B-Instruct-AWQ",
        },
        "Qwen2.5-3B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int8",
        },
        "Qwen2.5-3B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct-GPTQ-Int4",
        },
        "Qwen2.5-3B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-3B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-3B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-3B-Instruct-AWQ",
        },
        "Qwen2.5-7B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int8",
        },
        "Qwen2.5-7B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4",
        },
        "Qwen2.5-7B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-7B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-7B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-7B-Instruct-AWQ",
        },
        "Qwen2.5-14B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int8",
        },
        "Qwen2.5-14B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-GPTQ-Int4",
        },
        "Qwen2.5-14B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-14B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-14B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-14B-Instruct-AWQ",
        },
        "Qwen2.5-32B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int8",
        },
        "Qwen2.5-32B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct-GPTQ-Int4",
        },
        "Qwen2.5-32B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-32B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-32B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-32B-Instruct-AWQ",
        },
        "Qwen2.5-72B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int8",
        },
        "Qwen2.5-72B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct-GPTQ-Int4",
        },
        "Qwen2.5-72B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-72B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-72B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-72B-Instruct-AWQ",
        },
        "Qwen2.5-Coder-0.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-0.5B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-0.5B",
        },
        "Qwen2.5-Coder-1.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-1.5B",
        },
        "Qwen2.5-Coder-3B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-3B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-3B",
        },
        "Qwen2.5-Coder-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-7B",
        },
        "Qwen2.5-Coder-14B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-14B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-14B",
        },
        "Qwen2.5-Coder-32B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-32B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-32B",
        },
        "Qwen2.5-Coder-0.5B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-0.5B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-0.5B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-0.5B-Instruct",
        },
        "Qwen2.5-Coder-1.5B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-1.5B-Instruct",
        },
        "Qwen2.5-Coder-3B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-3B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-3B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-3B-Instruct",
        },
        "Qwen2.5-Coder-7B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-7B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-7B-Instruct",
        },
        "Qwen2.5-Coder-14B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-14B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-14B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-14B-Instruct",
        },
        "Qwen2.5-Coder-32B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Coder-32B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-32B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Coder-32B-Instruct",
        },
        "Qwen2.5-Math-1.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-1.5B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-1.5B",
        },
        "Qwen2.5-Math-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-7B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-7B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-7B",
        },
        "Qwen2.5-Math-72B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-72B",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Math-72B",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-72B",
        },
        "Qwen2.5-Math-1.5B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-1.5B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-1.5B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-1.5B-Instruct",
        },
        "Qwen2.5-Math-7B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-7B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-7B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-7B-Instruct",
        },
        "Qwen2.5-Math-72B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2.5-Math-72B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2.5-Coder-72B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2.5-Math-72B-Instruct",
        },
+       "QwQ-32B-Preview-Instruct": {
+           DownloadSource.DEFAULT: "Qwen/QwQ-32B-Preview",
+           DownloadSource.MODELSCOPE: "Qwen/QwQ-32B-Preview",
+       },
    },
    template="qwen",
...
@@ -1834,53 +2070,57 @@ register_model_group(
    models={
        "Qwen2-VL-2B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct",
            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-2B-Instruct",
        },
        "Qwen2-VL-7B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct",
            DownloadSource.OPENMIND: "LlamaFactory/Qwen2-VL-7B-Instruct",
        },
        "Qwen2-VL-72B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct",
        },
        "Qwen2-VL-2B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8",
        },
        "Qwen2-VL-2B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4",
        },
        "Qwen2-VL-2B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-2B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-2B-Instruct-AWQ",
        },
        "Qwen2-VL-7B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int8",
        },
        "Qwen2-VL-7B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4",
        },
        "Qwen2-VL-7B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-7B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct-AWQ",
        },
        "Qwen2-VL-72B-Instruct-GPTQ-Int8": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int8",
        },
        "Qwen2-VL-72B-Instruct-GPTQ-Int4": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct-GPTQ-Int4",
        },
        "Qwen2-VL-72B-Instruct-AWQ": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
-           DownloadSource.MODELSCOPE: "qwen/Qwen2-VL-72B-Instruct-AWQ",
+           DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-72B-Instruct-AWQ",
        },
+       "QVQ-72B-Preview": {
+           DownloadSource.DEFAULT: "Qwen/QVQ-72B-Preview",
+           DownloadSource.MODELSCOPE: "Qwen/QVQ-72B-Preview",
+       },
    },
    template="qwen2_vl",
...
@@ -1912,6 +2152,17 @@ register_model_group(
)


+register_model_group(
+    models={
+        "Skywork-o1-Open-Llama-3.1-8B": {
+            DownloadSource.DEFAULT: "Skywork/Skywork-o1-Open-Llama-3.1-8B",
+            DownloadSource.MODELSCOPE: "AI-ModelScope/Skywork-o1-Open-Llama-3.1-8B",
+        }
+    },
+    template="skywork_o1",
+)
+
+
register_model_group(
    models={
        "StarCoder2-3B": {
...
@@ -1942,19 +2193,40 @@ register_model_group(
            DownloadSource.OPENMIND: "TeleAI/TeleChat-7B-pt",
        },
        "TeleChat-12B-Chat": {
            DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
            DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
            DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
        },
        "TeleChat-12B-v2-Chat": {
            DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
            DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
            DownloadSource.OPENMIND: "TeleAI/TeleChat-12B-pt",
        },
        "TeleChat-52B-Chat": {
            DownloadSource.DEFAULT: "Tele-AI/TeleChat-52B",
        },
    },
    template="telechat",
)


+register_model_group(
+    models={
+        "TeleChat2-3B-Chat": {
+            DownloadSource.DEFAULT: "Tele-AI/TeleChat2-3B",
+            DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-3B",
+        },
+        "TeleChat2-7B-Chat": {
+            DownloadSource.DEFAULT: "Tele-AI/TeleChat2-7B",
+            DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-7B",
+        },
+        "TeleChat2-35B-Chat": {
+            DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-35B-Nov",
+        },
+        "TeleChat2-115B-Chat": {
+            DownloadSource.DEFAULT: "Tele-AI/TeleChat2-115B",
+            DownloadSource.MODELSCOPE: "TeleAI/TeleChat2-115B",
+        },
+    },
+    template="telechat2",
+)
+
+
register_model_group(
    models={
        "Vicuna-v1.5-7B-Chat": {
...
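The net effect of the constants.py changes: ModelScope organization names are normalized from lowercase "qwen" to "Qwen", and the Skywork-o1, TeleChat2, QwQ-32B-Preview and QVQ-72B-Preview checkpoints are registered. As a minimal sketch of how such a registry is consumed (illustrative names only, not the project's exact resolver), the hub-selection environment variables decide which DownloadSource key is looked up:

import os
from collections import OrderedDict
from enum import Enum, unique


@unique
class DownloadSource(str, Enum):
    DEFAULT = "hf"
    MODELSCOPE = "ms"
    OPENMIND = "om"


SUPPORTED_MODELS = OrderedDict()


def register_model_group(models, template=None):
    # template bookkeeping omitted in this sketch
    for name, paths in models.items():
        SUPPORTED_MODELS[name] = paths


register_model_group(
    models={
        "Qwen2-VL-7B-Instruct": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-VL-7B-Instruct",
            DownloadSource.MODELSCOPE: "Qwen/Qwen2-VL-7B-Instruct",
        }
    },
    template="qwen2_vl",
)


def resolve_path(name: str) -> str:
    # Fall back to the Hugging Face id when a hub-specific id is missing.
    paths = SUPPORTED_MODELS[name]
    if os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]:
        return paths.get(DownloadSource.MODELSCOPE, paths[DownloadSource.DEFAULT])
    if os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]:
        return paths.get(DownloadSource.OPENMIND, paths[DownloadSource.DEFAULT])
    return paths[DownloadSource.DEFAULT]


print(resolve_path("Qwen2-VL-7B-Instruct"))  # -> "Qwen/Qwen2-VL-7B-Instruct"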
src/llamafactory/extras/env.py
View file @ 8293100a
...
@@ -26,7 +26,7 @@ import trl
from transformers.utils import is_torch_cuda_available, is_torch_npu_available

-VERSION = "0.9.1"
+VERSION = "0.9.2.dev0"


def print_env() -> None:
...
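The version constant doubles as the project's runtime self-report. Assuming a source checkout is importable (the import path follows this file's location in the tree), a quick check that the bump took effect might look like:

from llamafactory.extras.env import VERSION, print_env

assert VERSION == "0.9.2.dev0"
print_env()  # prints the environment summary (Python, torch, transformers, ... versions)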
src/llamafactory/extras/logging.py
View file @ 8293100a
...
@@ -68,7 +68,7 @@ class LoggerHandler(logging.Handler):
class _Logger(logging.Logger):
    r"""
-    A logger that supports info_rank0 and warning_once.
+    A logger that supports rank0 logging.
    """

    def info_rank0(self, *args, **kwargs) -> None:
...
@@ -77,7 +77,7 @@ class _Logger(logging.Logger):
    def warning_rank0(self, *args, **kwargs) -> None:
        self.warning(*args, **kwargs)

-   def warning_once(self, *args, **kwargs) -> None:
+   def warning_rank0_once(self, *args, **kwargs) -> None:
        self.warning(*args, **kwargs)
...
@@ -163,11 +163,11 @@ def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
@lru_cache(None)
-def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
+def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        self.warning(*args, **kwargs)


logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
-logging.Logger.warning_once = warning_once
+logging.Logger.warning_rank0_once = warning_rank0_once
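The rename makes the semantics explicit: each distinct message is emitted at most once per process (via the lru_cache) and only on local rank 0. A self-contained sketch of the same pattern, mirroring the patched-in method above:

import logging
import os
from functools import lru_cache


@lru_cache(None)  # memoizes on (logger, message): each unique warning fires once per process
def warning_rank0_once(self: "logging.Logger", *args, **kwargs) -> None:
    if int(os.getenv("LOCAL_RANK", "0")) == 0:  # only the local rank-0 worker logs
        self.warning(*args, **kwargs)


logging.Logger.warning_rank0_once = warning_rank0_once

logger = logging.getLogger(__name__)
logger.warning_rank0_once("fires once")
logger.warning_rank0_once("fires once")  # suppressed by the cache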
src/llamafactory/extras/misc.py
View file @ 8293100a
...
@@ -17,7 +17,7 @@
import gc
import os
-from typing import TYPE_CHECKING, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union

import torch
import torch.distributed as dist
...
@@ -73,18 +73,46 @@ class AverageMeter:
        self.avg = self.sum / self.count


+def check_version(requirement: str, mandatory: bool = False) -> None:
+    r"""
+    Optionally checks the package version.
+    """
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"] and not mandatory:
+        logger.warning_rank0_once("Version checking has been disabled, may lead to unexpected behaviors.")
+        return
+
+    if mandatory:
+        hint = f"To fix: run `pip install {requirement}`."
+    else:
+        hint = f"To fix: run `pip install {requirement}` or set `DISABLE_VERSION_CHECK=1` to skip this check."
+
+    require_version(requirement, hint)
+
+
def check_dependencies() -> None:
    r"""
    Checks the version of the required packages.
    """
-    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
-    else:
-        require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
-        require_version("datasets>=2.16.0,<=3.1.0", "To fix: pip install datasets>=2.16.0,<=3.1.0")
-        require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
-        require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
-        require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
+    check_version("transformers>=4.41.2,<=4.46.1")
+    check_version("datasets>=2.16.0,<=3.1.0")
+    check_version("accelerate>=0.34.0,<=1.0.1")
+    check_version("peft>=0.11.1,<=0.12.0")
+    check_version("trl>=0.8.6,<=0.9.6")
+
+
+def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
+    r"""
+    Calculates effective tokens per second.
+    """
+    effective_token_num = 0
+    for data in dataset:
+        if stage == "sft":
+            effective_token_num += len(data["input_ids"])
+        elif stage == "rm":
+            effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])
+
+    result = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
+    return result / dist.get_world_size() if dist.is_initialized() else result


def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
...
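A hypothetical call pattern for the two new helpers (import path taken from this file's location; the dataset rows imitate the tokenized features the trainer produces):

from llamafactory.extras.misc import calculate_tps, check_version

check_version("transformers>=4.41.2,<=4.46.1")       # soft check: skippable via DISABLE_VERSION_CHECK=1
check_version("modelscope>=1.11.0", mandatory=True)  # hard check: always enforced

# 2,000 examples of 512 tokens each, trained for 3 epochs in 600 seconds:
dataset = [{"input_ids": [0] * 512} for _ in range(2000)]
metrics = {"epoch": 3.0, "train_runtime": 600.0}
tps = calculate_tps(dataset, metrics, stage="sft")
print(tps)  # 2000 * 512 * 3 / 600 = 5120.0, divided by the world size if torch.distributed is initialized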
@@ -213,7 +241,7 @@ def skip_check_imports() -> None:
    r"""
    Avoids flash attention import error in custom model files.
    """
-    if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
+    if os.getenv("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
        transformers.dynamic_module_utils.check_imports = get_relative_imports
...
@@ -237,7 +265,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
        return model_args.model_name_or_path

    if use_modelscope():
-        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
+        check_version("modelscope>=1.11.0", mandatory=True)
        from modelscope import snapshot_download  # type: ignore

        revision = "master" if model_args.model_revision == "main" else model_args.model_revision
...
@@ -248,7 +276,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
    )

    if use_openmind():
-        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        check_version("openmind>=0.8.0", mandatory=True)
        from openmind.utils.hub import snapshot_download  # type: ignore

        return snapshot_download(
...
@@ -259,16 +287,12 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
def use_modelscope() -> bool:
-    return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
+    return os.getenv("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]


def use_openmind() -> bool:
-    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+    return os.getenv("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]


-def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
-    r"""
-    calculate effective tokens.
-    """
-    result = effective_token_num * epoch / train_runtime
-    return result / dist.get_world_size() if dist.is_initialized() else result
-
-
+def use_ray() -> bool:
+    return os.getenv("USE_RAY", "0").lower() in ["true", "1"]
src/llamafactory/extras/packages.py
View file @ 8293100a
...
@@ -50,6 +50,10 @@ def is_galore_available():
    return _is_package_available("galore_torch")


+def is_apollo_available():
+    return _is_package_available("apollo_torch")
+
+
def is_gradio_available():
    return _is_package_available("gradio")
...
@@ -62,6 +66,10 @@ def is_pillow_available():
    return _is_package_available("PIL")


+def is_ray_available():
+    return _is_package_available("ray")
+
+
def is_requests_available():
    return _is_package_available("requests")
...
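Both additions reuse the module's _is_package_available helper, so they probe importability without actually importing the package. A hypothetical consumer, following how such probes are typically used, gates optional features on them instead of failing at import time:

from llamafactory.extras.packages import is_apollo_available, is_ray_available

if is_ray_available():
    import ray  # safe: the probe confirmed the package can be found

    ray.init(ignore_reinit_error=True)

if not is_apollo_available():
    raise ImportError("APOLLO optimizer requested but `apollo_torch` is not installed.")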