OpenDAS / LLaMA-Factory · Commits

Commit 7ea81099, authored Apr 07, 2025 by chenych
Parent: 84987715

    update llama4

Showing 20 changed files with 1344 additions and 780 deletions (+1344, -780)
src/llamafactory/chat/chat_model.py                    +38   -47
src/llamafactory/chat/hf_engine.py                     +47   -53
src/llamafactory/chat/sglang_engine.py                 +275  -0
src/llamafactory/chat/vllm_engine.py                   +47   -37
src/llamafactory/cli.py                                +2    -3
src/llamafactory/data/__init__.py                      +4    -4
src/llamafactory/data/collator.py                      +61   -35
src/llamafactory/data/converter.py                     +27   -29
src/llamafactory/data/data_utils.py                    +10   -15
src/llamafactory/data/formatter.py                     +6    -9
src/llamafactory/data/loader.py                        +11   -19
src/llamafactory/data/mm_plugin.py                     +732  -446
src/llamafactory/data/parser.py                        +9    -13
src/llamafactory/data/processor/__init__.py            +15   -1
src/llamafactory/data/processor/feedback.py            +12   -12
src/llamafactory/data/processor/pairwise.py            +9    -9
src/llamafactory/data/processor/pretrain.py            +4    -4
src/llamafactory/data/processor/processor_utils.py     +12   -24
src/llamafactory/data/processor/supervised.py          +14   -11
src/llamafactory/data/processor/unsupervised.py        +9    -9
src/llamafactory/chat/chat_model.py

-# Copyright 2024 THUDM and the LlamaFactory team.
+# Copyright 2025 THUDM and the LlamaFactory team.
 #
 # This code is inspired by the THUDM's ChatGLM implementation.
 # https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py

@@ -17,13 +17,15 @@
 import asyncio
 import os
+from collections.abc import AsyncGenerator, Generator
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Optional

 from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..hparams import get_infer_args
 from .hf_engine import HuggingfaceEngine
+from .sglang_engine import SGLangEngine
 from .vllm_engine import VllmEngine

@@ -38,20 +40,21 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
 class ChatModel:
-    r"""
-    General class for chat models. Backed by huggingface or vllm engines.
+    r"""General class for chat models. Backed by huggingface or vllm engines.

     Supports both sync and async methods.
     Sync methods: chat(), stream_chat() and get_scores().
     Async methods: achat(), astream_chat() and aget_scores().
     """

-    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
         if model_args.infer_backend == EngineName.HF:
-            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
+            self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == EngineName.VLLM:
-            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+        elif model_args.infer_backend == EngineName.SGLANG:
+            self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

@@ -61,17 +64,15 @@ class ChatModel:
     def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
         task = asyncio.run_coroutine_threadsafe(
             self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
         )

@@ -79,32 +80,28 @@ class ChatModel:
     async def achat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Asynchronously gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Asynchronously get a list of responses of the chat model."""
         return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)

     def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        r"""
-        Gets the response token-by-token of the chat model.
-        """
+        r"""Get the response token-by-token of the chat model."""
         generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
         while True:
             try:

@@ -115,17 +112,15 @@ class ChatModel:
     async def astream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        r"""
-        Asynchronously gets the response token-by-token of the chat model.
-        """
+        r"""Asynchronously get the response token-by-token of the chat model."""
         async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, audios, **input_kwargs):

@@ -133,23 +128,19 @@ class ChatModel:
     def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
         task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
         return task.result()

     async def aget_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Asynchronously gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Asynchronously get a list of scores of the reward model."""
         return await self.engine.get_scores(batch_input, **input_kwargs)
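For orientation, a minimal usage sketch of the reworked ChatModel front end follows. The argument keys and values shown here (model_name_or_path, template, infer_backend and the backend strings) are illustrative assumptions, not part of this diff.

    # Minimal sketch; the args dict keys below are assumptions, not taken from this commit.
    from llamafactory.chat.chat_model import ChatModel

    chat_model = ChatModel({
        "model_name_or_path": "meta-llama/Llama-4-Scout-17B-16E-Instruct",  # hypothetical model id
        "template": "llama4",                                               # hypothetical template name
        "infer_backend": "sglang",                                          # assumed string; selects the new SGLangEngine
    })

    messages = [{"role": "user", "content": "Hello!"}]
    print(chat_model.chat(messages)[0].response_text)   # sync wrapper around achat()
    for token in chat_model.stream_chat(messages):      # sync wrapper around astream_chat()
        print(token, end="", flush=True)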
src/llamafactory/chat/hf_engine.py

@@ -13,10 +13,10 @@
 # limitations under the License.

 import asyncio
-import concurrent.futures
 import os
+from collections.abc import AsyncGenerator
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import torch
 from transformers import GenerationConfig, TextIteratorStreamer

@@ -76,15 +76,15 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> Tuple[Dict[str, Any], int]:
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> tuple[dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
         if images is not None:
             mm_input_dict.update({"images": images, "imglens": [len(images)]})

@@ -130,7 +130,7 @@ class HuggingfaceEngine(BaseEngine):
         skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
-        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

         if stop is not None:
             logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")

@@ -217,15 +217,15 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> List["Response"]:
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> list["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
             model,
             tokenizer,

@@ -272,14 +272,14 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        input_kwargs: Optional[dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
             model,

@@ -317,12 +317,12 @@ class HuggingfaceEngine(BaseEngine):
     def _get_scores(
         model: "PreTrainedModelWrapper",
         tokenizer: "PreTrainedTokenizer",
-        batch_input: List[str],
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> List[float]:
+        batch_input: list[str],
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> list[float]:
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         device = getattr(model.pretrained_model, "device", "cuda")
-        inputs: Dict[str, "torch.Tensor"] = tokenizer(
+        inputs: dict[str, torch.Tensor] = tokenizer(
             batch_input,
             padding=True,
             truncation=True,

@@ -330,25 +330,24 @@ class HuggingfaceEngine(BaseEngine):
             return_tensors="pt",
             add_special_tokens=False,
         ).to(device)
-        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
+        values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
         scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return scores

     @override
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
+    ) -> list["Response"]:
         if not self.can_generate:
             raise ValueError("The current model does not support `chat`.")

-        loop = asyncio.get_running_loop()
         input_args = (
             self.model,
             self.tokenizer,

@@ -364,24 +363,22 @@ class HuggingfaceEngine(BaseEngine):
             input_kwargs,
         )
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                return await loop.run_in_executor(pool, self._chat, *input_args)
+            return await asyncio.to_thread(self._chat, *input_args)

     @override
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
             raise ValueError("The current model does not support `stream_chat`.")

-        loop = asyncio.get_running_loop()
         input_args = (
             self.model,
             self.tokenizer,

@@ -397,25 +394,22 @@ class HuggingfaceEngine(BaseEngine):
             input_kwargs,
         )
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                stream = self._stream_chat(*input_args)
-                while True:
-                    try:
-                        yield await loop.run_in_executor(pool, stream)
-                    except StopAsyncIteration:
-                        break
+            stream = self._stream_chat(*input_args)
+            while True:
+                try:
+                    yield await asyncio.to_thread(stream)
+                except StopAsyncIteration:
+                    break

     @override
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
+    ) -> list[float]:
         if self.can_generate:
             raise ValueError("Cannot get scores using an auto-regressive model.")

-        loop = asyncio.get_running_loop()
         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                return await loop.run_in_executor(pool, self._get_scores, *input_args)
+            return await asyncio.to_thread(self._get_scores, *input_args)
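The recurring pattern in this file replaces a hand-managed ThreadPoolExecutor with asyncio.to_thread. A self-contained sketch of that swap, using generic stand-in code rather than the engine itself:

    import asyncio
    import concurrent.futures
    import time

    def blocking_generate(prompt: str) -> str:
        time.sleep(0.1)          # stand-in for a blocking model.generate() call
        return prompt.upper()

    async def old_style(prompt: str) -> str:
        # before: explicit running loop plus a manually created thread pool
        loop = asyncio.get_running_loop()
        with concurrent.futures.ThreadPoolExecutor() as pool:
            return await loop.run_in_executor(pool, blocking_generate, prompt)

    async def new_style(prompt: str) -> str:
        # after: asyncio.to_thread (Python 3.9+) runs the callable in the default executor
        return await asyncio.to_thread(blocking_generate, prompt)

    print(asyncio.run(new_style("hello")))  # -> HELLO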
src/llamafactory/chat/sglang_engine.py (new file, 0 → 100644; entire contents added)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import atexit
import json
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union

import requests
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_sglang_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from .base_engine import BaseEngine, Response


if is_sglang_available():
    from sglang.utils import launch_server_cmd, terminate_process, wait_for_server  # type: ignore


if TYPE_CHECKING:
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class SGLangEngine(BaseEngine):
    """Inference engine for SGLang models.

    This class wraps the SGLang engine to provide a consistent interface for text generation
    that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
    better interaction and performance. The engine launches a server process and communicates
    with it via HTTP requests.

    For more details on the SGLang HTTP server approach, see:
    https://docs.sglang.ai/backend/send_request.html
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.SGLANG
        self.model_args = model_args
        config = load_config(model_args)  # may download model from ms hub
        if getattr(config, "quantization_config", None):  # gptq models should use float16
            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
        self.generating_args = generating_args.to_dict()

        launch_cmd = [
            "python3 -m sglang.launch_server",
            f"--model-path {model_args.model_name_or_path}",
            f"--dtype {model_args.infer_dtype}",
            f"--context-length {model_args.sglang_maxlen}",
            f"--mem-fraction-static {model_args.sglang_mem_fraction}",
            f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
            f"--download-dir {model_args.cache_dir}",
            "--log-level error",
        ]
        launch_cmd = " ".join(launch_cmd)
        logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
        try:
            torch_gc()
            self.server_process, port = launch_server_cmd(launch_cmd)
            self.base_url = f"http://localhost:{port}"
            atexit.register(self._cleanup_server)

            logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
            wait_for_server(self.base_url, timeout=300)
            logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")

            try:
                response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
                if response.status_code == 200:
                    model_info = response.json()
                    logger.info(f"SGLang server model info: {model_info}")
            except Exception as e:
                logger.debug(f"Note: could not get model info: {str(e)}")

        except Exception as e:
            logger.error(f"Failed to start SGLang server: {str(e)}")
            self._cleanup_server()  # make sure to clean up any started process
            raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")

    def _cleanup_server(self):
        r"""Clean up the server process when the engine is destroyed."""
        if hasattr(self, "server_process") and self.server_process:
            try:
                logger.info("Terminating SGLang server process")
                terminate_process(self.server_process)
                logger.info("SGLang server process terminated")
            except Exception as e:
                logger.warning(f"Error terminating SGLang server: {str(e)}")

    async def _generate(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncIterator[dict[str, Any]]:
        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = self.template.mm_plugin.process_messages(
            messages, images or [], videos or [], audios or [], self.processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

        if num_return_sequences != 1:
            raise NotImplementedError("SGLang only supports n=1.")

        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            if self.generating_args["max_length"] > prompt_length:
                max_tokens = self.generating_args["max_length"] - prompt_length
            else:
                max_tokens = 1

        if max_length:
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = {
            "temperature": temperature if temperature is not None else self.generating_args["temperature"],
            "top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
            "top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must > 0
            "stop": stop,
            "stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
            "max_new_tokens": max_tokens,
            "repetition_penalty": (
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must > 0
            "skip_special_tokens": skip_special_tokens
            if skip_special_tokens is not None
            else self.generating_args["skip_special_tokens"],
        }

        def stream_request():
            json_data = {
                "input_ids": prompt_ids,
                "sampling_params": sampling_params,
                "stream": True,
            }
            response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
            if response.status_code != 200:
                raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")

            for chunk in response.iter_lines(decode_unicode=False):
                chunk = str(chunk.decode("utf-8"))
                if chunk == "data: [DONE]":
                    break

                if chunk and chunk.startswith("data:"):
                    yield json.loads(chunk[5:].strip("\n"))

        return await asyncio.to_thread(stream_request)

    @override
    async def chat(
        self,
        messages: Sequence[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        final_output = None
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        for request_output in generator:
            final_output = request_output

        results = [
            Response(
                response_text=final_output["text"],
                response_length=final_output["meta_info"]["completion_tokens"],
                prompt_length=final_output["meta_info"]["prompt_tokens"],
                finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
            )
        ]
        return results

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        for result in generator:
            delta_text = result["text"][len(generated_text):]
            generated_text = result["text"]
            yield delta_text

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        raise NotImplementedError("SGLang engine does not support `get_scores`.")

    def __del__(self):
        r"""Ensure server is cleaned up when object is deleted."""
        self._cleanup_server()
        try:
            atexit.unregister(self._cleanup_server)
        except Exception:
            pass
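The new engine drives the SGLang server over plain HTTP. A stripped-down sketch of the same streaming request outside the class follows; the port is an assumption (SGLangEngine obtains the real one from launch_server_cmd), and the payload keys mirror stream_request above:

    import json
    import requests

    BASE_URL = "http://localhost:30000"  # assumed address of a running SGLang server

    payload = {
        "input_ids": [1, 15043, 29991],          # pre-tokenized prompt, as in stream_request()
        "sampling_params": {"temperature": 0.7, "max_new_tokens": 64},
        "stream": True,
    }
    response = requests.post(f"{BASE_URL}/generate", json=payload, stream=True)
    response.raise_for_status()

    generated = ""
    for line in response.iter_lines(decode_unicode=False):
        line = line.decode("utf-8")
        if line == "data: [DONE]":
            break
        if line.startswith("data:"):
            chunk = json.loads(line[5:].strip())
            delta = chunk["text"][len(generated):]   # the server returns cumulative text
            generated = chunk["text"]
            print(delta, end="", flush=True)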
src/llamafactory/chat/vllm_engine.py

@@ -13,7 +13,8 @@
 # limitations under the License.

 import uuid
-from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import TYPE_CHECKING, Any, Optional, Union

 from typing_extensions import override

@@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
         self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
-            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
             quant_method = quantization_config.get("quant_method", "")
             if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                 model_args.infer_dtype = "float16"

@@ -82,7 +83,7 @@ class VllmEngine(BaseEngine):
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }

         if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
-            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}

         if isinstance(model_args.vllm_config, dict):
             engine_args.update(model_args.vllm_config)

@@ -101,33 +102,26 @@ class VllmEngine(BaseEngine):
     async def _generate(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"
-        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
-        if images is not None:
-            mm_input_dict.update({"images": images, "imglens": [len(images)]})
-            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

-        if videos is not None:
-            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
-            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

-        if audios is not None:
-            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
-            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

         messages = self.template.mm_plugin.process_messages(
-            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
+            messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]

@@ -143,7 +137,7 @@ class VllmEngine(BaseEngine):
         skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
-        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

         if length_penalty is not None:
             logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")

@@ -185,8 +179,24 @@ class VllmEngine(BaseEngine):
                     images,
                     image_max_pixels=self.model_args.image_max_pixels,
                     image_min_pixels=self.model_args.image_min_pixels,
-                )
+                )["images"]
             }
+        elif videos is not None:
+            multi_modal_data = {
+                "video": self.template.mm_plugin._regularize_videos(
+                    videos,
+                    image_max_pixels=self.model_args.video_max_pixels,
+                    image_min_pixels=self.model_args.video_min_pixels,
+                    video_fps=self.model_args.video_fps,
+                    video_maxlen=self.model_args.video_maxlen,
+                )["videos"]
+            }
+        elif audios is not None:
+            audio_data = self.template.mm_plugin._regularize_audios(
+                audios,
+                sampling_rate=self.model_args.audio_sampling_rate,
+            )
+            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
         else:
             multi_modal_data = None

@@ -201,14 +211,14 @@ class VllmEngine(BaseEngine):
     @override
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
+    ) -> list["Response"]:
         final_output = None
         generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
         async for request_output in generator:

@@ -230,12 +240,12 @@ class VllmEngine(BaseEngine):
     @override
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""

@@ -248,7 +258,7 @@ class VllmEngine(BaseEngine):
     @override
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        raise NotImplementedError("vLLM engine does not support get_scores.")
+    ) -> list[float]:
+        raise NotImplementedError("vLLM engine does not support `get_scores`.")
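The rewritten _generate prepends one multimodal placeholder per attachment when the first user message contains none. The same check in isolation; the placeholder string here is illustrative, the real constants come from ..extras.constants:

    IMAGE_PLACEHOLDER = "<image>"  # illustrative value; not taken from this diff

    def prepend_image_placeholders(messages, images):
        # Mirrors: if images is not None and not any(IMAGE_PLACEHOLDER in m["content"] ...)
        if images is not None and not any(IMAGE_PLACEHOLDER in m["content"] for m in messages):
            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
        return messages

    msgs = [{"role": "user", "content": "Describe these pictures."}]
    print(prepend_image_placeholders(msgs, images=["a.png", "b.png"])[0]["content"])
    # -> "<image><image>Describe these pictures."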
src/llamafactory/cli.py

@@ -13,7 +13,6 @@
 # limitations under the License.

 import os
-import random
 import subprocess
 import sys
 from enum import Enum, unique

@@ -24,7 +23,7 @@ from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
 from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.misc import get_device_count, is_env_enabled, use_ray
+from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui

@@ -92,7 +91,7 @@ def main():
         node_rank = os.getenv("NODE_RANK", "0")
         nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
         master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
-        master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
+        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
         logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
         if int(nnodes) > 1:
             print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
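find_available_port is imported from .extras.misc but its body is not part of this diff. A plausible sketch (an assumption, not the repository's implementation) asks the OS for a free port by binding to port 0:

    import socket

    def find_available_port() -> int:
        # Hypothetical implementation: bind to port 0 so the OS picks an unused port.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind(("127.0.0.1", 0))
            return sock.getsockname()[1]

    print(find_available_port())  # e.g. 54823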
src/llamafactory/data/__init__.py

@@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 __all__ = [
-    "TEMPLATES",
     "KTODataCollatorWithPadding",
     "MultiModalDataCollatorForSeq2Seq",
     "PairwiseDataCollatorWithPadding",
-    "SFTDataCollatorWith4DAttentionMask",
     "Role",
-    "split_dataset",
-    "get_dataset",
+    "SFTDataCollatorWith4DAttentionMask",
+    "TEMPLATES",
     "Template",
+    "get_dataset",
     "get_template_and_fix_tokenizer",
+    "split_dataset",
 ]
src/llamafactory/data/collator.py

-# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
+# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
 #
 # This code is inspired by the OpenAccess AI Collective's axolotl library.
 # https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py

@@ -16,7 +16,7 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Literal, Optional

 import numpy as np
 import torch

@@ -24,6 +24,7 @@ import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq

 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.misc import get_current_device
 from ..extras.packages import is_pillow_available

@@ -38,9 +39,10 @@ if TYPE_CHECKING:
 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
-    r"""
-    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
-    while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
+    r"""Expand 2d attention mask to 4d attention mask.
+
+    Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
+    handle packed sequences and transforms the mask to lower triangular form to prevent future peeking.

     e.g.
     ```python

@@ -62,24 +64,37 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     ```
     where `o` equals to `0.0`, `x` equals to `min_dtype`.
     """
-    bsz, seq_len = attention_mask_with_indices.size()
+    _, seq_len = attention_mask_with_indices.size()
+    # Move to compute device if the source is CPU.
+    source_device = attention_mask_with_indices.device
+    compute_device = get_current_device() if source_device.type == "cpu" else source_device
+    if compute_device != source_device:
+        attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
+
     min_dtype = torch.finfo(dtype).min
-    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
-    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
-    padding_mask = torch.where(expanded_mask != 0, 1, 0)
-    # Create a block-diagonal mask.
-    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
-    # Use the lower triangular mask to zero out the upper triangular part
-    attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
+    zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
+
+    # Create a non-padding mask.
+    non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
+    # Create indices for comparison.
+    indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
+    indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
+    # Create a lower triangular mask.
+    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
+    attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
     # Invert the attention mask.
-    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
+    attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
+    # Move back to original device if needed.
+    if compute_device != source_device:
+        attention_mask_4d = attention_mask_4d.to(source_device)
+
     return attention_mask_4d
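A quick sanity check of the rewritten mask construction on a packed batch; the input row follows the docstring's example style and the expected output was derived from the logic above:

    import torch
    from llamafactory.data.collator import prepare_4d_attention_mask

    # One packed row: two sequences labeled 1 and 2, followed by one padding position (0).
    attention_mask_with_indices = torch.tensor([[1, 1, 2, 2, 0]])

    mask_4d = prepare_4d_attention_mask(attention_mask_with_indices, dtype=torch.float32)
    print(mask_4d.shape)  # torch.Size([1, 1, 5, 5])

    # Positions may only attend within their own block and only to earlier tokens;
    # everything else (including the padding column) is filled with torch.finfo(dtype).min.
    print((mask_4d[0, 0] == 0).int())
    # tensor([[1, 0, 0, 0, 0],
    #         [1, 1, 0, 0, 0],
    #         [0, 0, 1, 0, 0],
    #         [0, 0, 1, 1, 0],
    #         [0, 0, 0, 0, 0]])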
@@ -91,7 +106,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 @dataclass
 class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
-    r"""
-    Data collator that supports VLMs.
+    r"""Data collator that supports VLMs.

     Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
     """

         if self.template is None:
             raise ValueError("Template is required for MultiModalDataCollator.")

-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         batch_images, batch_videos, batch_audios = [], [], []
         batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
         for feature in features:

@@ -166,7 +181,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         for i, feature in enumerate(features):
             feature["token_type_ids"] = token_type_ids[i]

-        features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        features: dict[str, torch.Tensor] = super().__call__(features)

         if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
             rope_index_kwargs = {

@@ -175,9 +190,27 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 "video_grid_thw": mm_inputs.get("video_grid_thw"),
                 "attention_mask": features["attention_mask"],
             }
-            if "second_per_grid_ts" in mm_inputs:
+            if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")

-            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
+            if "video_second_per_grid" in mm_inputs:  # for qwen2omni
+                rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
+
+            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
+                feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
+                if feature_attention_mask is not None:
+                    audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)  # FIXME need to get video image lengths
+                    rope_index_kwargs["audio_seqlens"] = audio_feature_lengths  # prepare for input
+
+                delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)  # avoid conflict
+                new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
+                features["position_ids"], features["rope_deltas"] = (
+                    new_position_ids.clone(),
+                    rope_deltas - delta0,
+                )  # avoid inplace operation FIXME
+            else:  # for qwen2vl
+                features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)

         if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled

@@ -198,15 +231,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 @dataclass
 class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
-    r"""
-    Data collator for 4d attention mask.
-    """
+    r"""Data collator for 4d attention mask."""

     block_diag_attn: bool = False
     attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
     compute_dtype: "torch.dtype" = torch.float32

-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         features = super().__call__(features)
         if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
             features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

@@ -220,13 +251,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
 @dataclass
 class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
-    r"""
-    Data collator for pairwise data.
-    """
+    r"""Data collator for pairwise data."""

-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        r"""
-        Pads batched data to the longest sequence in the batch.
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
+        r"""Pad batched data to the longest sequence in the batch.

         We generate 2 * n examples where the first n examples represent chosen examples and
         the last n examples represent rejected examples.

@@ -249,11 +277,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
 @dataclass
 class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
-    r"""
-    Data collator for KTO data.
-    """
+    r"""Data collator for KTO data."""

-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         target_features = []
         kl_features = []
         kto_tags = []
src/llamafactory/data/converter.py

@@ -15,7 +15,7 @@
 import os
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 from ..extras import logging
 from .data_utils import Role

@@ -26,8 +26,12 @@ if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments

     from ..hparams import DataArguments
+    from .mm_plugin import AudioInput, ImageInput, VideoInput
     from .parser import DatasetAttr

+    MediaType = Union[ImageInput, VideoInput, AudioInput]
+

 logger = logging.get_logger(__name__)

@@ -36,12 +40,12 @@ class DatasetConverter:
     dataset_attr: "DatasetAttr"
     data_args: "DataArguments"

-    def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
-        r"""
-        Optionally concatenates media path to media dir when loading from local disk.
-        """
-        if not isinstance(medias, list):
-            medias = [medias] if medias is not None else []
+    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
+        r"""Optionally concatenate media path to media dir when loading from local disk."""
+        if medias is None:
+            return None
+        elif not isinstance(medias, list):
+            medias = [medias]
         elif len(medias) == 0:
             return None
         else:

@@ -57,16 +61,14 @@ class DatasetConverter:
         return medias

     @abstractmethod
-    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
-        r"""
-        Converts a single example in the dataset to the standard format.
-        """
+    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
+        r"""Convert a single example in the dataset to the standard format."""
         ...


 @dataclass
 class AlpacaDatasetConverter(DatasetConverter):
-    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
+    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
         prompt = []
         if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
             for old_prompt, old_response in example[self.dataset_attr.history]:

@@ -116,7 +118,7 @@ class AlpacaDatasetConverter(DatasetConverter):
 @dataclass
 class SharegptDatasetConverter(DatasetConverter):
-    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
+    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
         tag_mapping = {
             self.dataset_attr.user_tag: Role.USER.value,
             self.dataset_attr.assistant_tag: Role.ASSISTANT.value,

@@ -216,10 +218,8 @@ DATASET_CONVERTERS = {
 }


-def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
-    r"""
-    Register a new dataset converter.
-    """
+def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
+    r"""Register a new dataset converter."""
     if name in DATASET_CONVERTERS:
         raise ValueError(f"Dataset converter {name} already exists.")

@@ -227,9 +227,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
 def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
-    r"""
-    Gets a dataset converter.
-    """
+    r"""Get a dataset converter."""
     if name not in DATASET_CONVERTERS:
         raise ValueError(f"Dataset converter {name} not found.")

@@ -242,17 +240,17 @@ def align_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
-    r"""
+    r"""Align the dataset to a specific format.

     Aligned dataset:
         _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
         _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
         _system: "..."
-        _tools: "...",
-        _images: [],
-        _videos: [],
-        _audios: [],
+        _tools: "..."
+        _images: []
+        _videos: []
+        _audios: []
     """
     column_names = list(next(iter(dataset)).keys())
     kwargs = {}
     if not data_args.streaming:
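The reworked _find_medias now distinguishes "no media" from "a single media item". The normalization rule in isolation, without the media-dir path handling the full method performs afterwards:

    from typing import Any, Optional, Union

    def normalize_medias(medias: Union[Any, list[Any], None]) -> Optional[list[Any]]:
        # Same branching as DatasetConverter._find_medias, minus the os.path.join step.
        if medias is None:
            return None
        elif not isinstance(medias, list):
            medias = [medias]
        elif len(medias) == 0:
            return None
        return medias

    print(normalize_medias(None))                 # None
    print(normalize_medias("cat.png"))            # ['cat.png']
    print(normalize_medias([]))                   # None
    print(normalize_medias(["a.png", "b.png"]))   # ['a.png', 'b.png']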
src/llamafactory/data/data_utils.py

@@ -13,7 +13,7 @@
 # limitations under the License.

 from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
+from typing import TYPE_CHECKING, Optional, TypedDict, Union

 from datasets import DatasetDict, concatenate_datasets, interleave_datasets

@@ -29,7 +29,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)


-SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+SLOTS = list[Union[str, set[str], dict[str, str]]]


 @unique

@@ -43,15 +43,13 @@ class Role(str, Enum):
 class DatasetModule(TypedDict):
     train_dataset: Optional[Union["Dataset", "IterableDataset"]]
-    eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
+    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]


 def merge_dataset(
-    all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
+    all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
 ) -> Union["Dataset", "IterableDataset"]:
-    r"""
-    Merges multiple datasets to a unified dataset.
-    """
+    r"""Merge multiple datasets to a unified dataset."""
     if len(all_datasets) == 1:
         return all_datasets[0]

@@ -78,14 +76,13 @@ def merge_dataset(
 def split_dataset(
     dataset: Optional[Union["Dataset", "IterableDataset"]],
-    eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
+    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
     data_args: "DataArguments",
     seed: int,
 ) -> "DatasetDict":
-    r"""
-    Splits the dataset and returns a dataset dict containing train set and validation set.
+    r"""Split the dataset and returns a dataset dict containing train set and validation set.

-    Supports both map dataset and iterable dataset.
+    Support both map dataset and iterable dataset.
     """
     if eval_dataset is not None and data_args.val_size > 1e-6:
         raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")

@@ -120,10 +117,8 @@ def split_dataset(
 def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
-    r"""
-    Converts dataset or dataset dict to dataset module.
-    """
-    dataset_module: "DatasetModule" = {}
+    r"""Convert dataset or dataset dict to dataset module."""
+    dataset_module: DatasetModule = {}
     if isinstance(dataset, DatasetDict):  # dataset dict
         if "train" in dataset:
             dataset_module["train_dataset"] = dataset["train"]
src/llamafactory/data/formatter.py

@@ -16,7 +16,7 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import List, Optional, Union
+from typing import Optional, Union

 from typing_extensions import override

@@ -31,14 +31,11 @@ class Formatter(ABC):
     @abstractmethod
     def apply(self, **kwargs) -> SLOTS:
-        r"""
-        Forms a list of slots according to the inputs to encode.
-        """
+        r"""Forms a list of slots according to the inputs to encode."""
         ...

-    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
-        r"""
-        Extract a list of tuples from the response message if using tools.
+    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+        r"""Extract a list of tuples from the response message if using tools.

         Each tuple consists of function name and function arguments.
         """

@@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter):
         if thought:
             content = content.replace(thought.group(0), "")

-        functions: List["FunctionCall"] = []
+        functions: list[FunctionCall] = []
         try:
             tool_calls = json.loads(content)
             if not isinstance(tool_calls, list):  # parallel function call

@@ -141,5 +138,5 @@ class ToolFormatter(Formatter):
             raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string

     @override
-    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
         return self.tool_utils.tool_extractor(content)
src/llamafactory/data/loader.py
View file @ 7ea81099
...
@@ -13,7 +13,7 @@
 # limitations under the License.

 import os
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union

 import numpy as np
 from datasets import load_dataset, load_from_disk
...
@@ -54,9 +54,7 @@ def _load_single_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
-    r"""
-    Loads a single dataset and aligns it to the standard format.
-    """
+    r"""Load a single dataset and aligns it to the standard format."""
     logger.info_rank0(f"Loading dataset {dataset_attr}...")
     data_path, data_name, data_dir, data_files = None, None, None, None

     if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
...
@@ -133,10 +131,12 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=data_args.streaming,
             num_proc=data_args.preprocessing_num_workers,
             trust_remote_code=model_args.trust_remote_code,
+            streaming=data_args.streaming and dataset_attr.load_from != "file",
         )
+        if data_args.streaming and dataset_attr.load_from == "file":
+            dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)

     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
...
@@ -158,16 +158,14 @@ def _load_single_dataset(


 def _get_merged_dataset(
-    dataset_names: Optional[Sequence[str]],
+    dataset_names: Optional[list[str]],
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     merge: bool = True,
-) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
-    r"""
-    Returns the merged datasets in the standard format.
-    """
+) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
+    r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
         return None
...
@@ -192,9 +190,7 @@ def _get_dataset_processor(
     processor: Optional["ProcessorMixin"],
     do_generate: bool = False,
 ) -> "DatasetProcessor":
-    r"""
-    Returns the corresponding dataset processor.
-    """
+    r"""Return the corresponding dataset processor."""
     if stage == "pt":
         dataset_processor_class = PretrainDatasetProcessor
     elif stage == "sft" and not do_generate:
...
@@ -236,9 +232,7 @@ def _get_preprocessed_dataset(
     processor: Optional["ProcessorMixin"] = None,
     is_eval: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
-    r"""
-    Preprocesses the dataset, including format checking and tokenization.
-    """
+    r"""Preprocesses the dataset, including format checking and tokenization."""
     if dataset is None:
         return None
...
@@ -284,9 +278,7 @@ def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> "DatasetModule":
-    r"""
-    Gets the train dataset and optionally gets the evaluation dataset.
-    """
+    r"""Get the train dataset and optionally gets the evaluation dataset."""
     # Load tokenized dataset if path exists
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
...
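The streaming change above can be exercised in isolation; a small sketch assuming a local `data.json` file and four dataloader workers (the file name and shard count are placeholders):

from datasets import load_dataset

# Local files are loaded eagerly, then converted into a sharded iterable dataset,
# mirroring the `to_iterable_dataset(num_shards=...)` call added in the diff.
dataset = load_dataset("json", data_files="data.json", split="train", streaming=False)
dataset = dataset.to_iterable_dataset(num_shards=4)
for example in dataset.take(2):
    print(example)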
src/llamafactory/data/mm_plugin.py
View file @ 7ea81099
# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's Transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llava/processing_llava.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

 import inspect
 import math
 import re
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Type, TypedDict, Union
+from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union

 import numpy as np
 import torch
...
@@ -51,28 +68,65 @@ if TYPE_CHECKING:
         path: Optional[str]
         bytes: Optional[bytes]

-    ImageInput = Union[str, bytes, EncodedImage, ImageObject]
-    VideoInput = str
-    AudioInput = Union[str, NDArray]
+    ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
+    VideoInput = Union[str, BinaryIO]
+    AudioInput = Union[str, BinaryIO, NDArray]
+
+    class MMProcessor(ProcessorMixin):
+        patch_size: int
+        image_seq_length: int
+        num_additional_image_tokens: int
+        vision_feature_select_strategy: Literal["default", "full"]
+
+        def _get_number_of_features(self, orig_height: int, orig_width: int, height: int, width: int) -> int:
+            pass


-def _get_paligemma_token_type_ids(imglens: Sequence[int], seqlens: Sequence[int], processor: "ProcessorMixin") -> List[List[int]]:
-    r"""
-    Gets paligemma token type ids for computing loss.
+def _get_paligemma_token_type_ids(imglens: list[int], seqlens: list[int], processor: "MMProcessor") -> list[list[int]]:
+    r"""Get paligemma token type ids for computing loss.

     It is slightly different with the original token type ids where the prompt part is 0.

     Returns:
-        batch_token_type_ids: shape (batch_size, sequence_length)
+        batch_token_type_ids: shape (batch_size, seq_length)
     """
     batch_token_type_ids = []
     for imglen, seqlen in zip(imglens, seqlens):
-        image_seqlen = imglen * getattr(processor, "image_seqlen")
+        image_seqlen = imglen * processor.image_seq_length
         batch_token_type_ids.append([0] * image_seqlen + [1] * (seqlen - image_seqlen))

     return batch_token_type_ids


+def _get_gemma3_token_type_ids(batch_ids: list[list[int]], processor: "MMProcessor"):
+    r"""Get gemma3 token type ids for computing loss.
+
+    Returns:
+        batch_token_type_ids: shape (batch_size, seq_length)
+    """
+    image_token_id: int = getattr(processor, "image_token_id")
+    batch_token_type_ids = []
+    for token_ids in batch_ids:
+        token_ids = np.array(token_ids)
+        token_type_ids = np.zeros_like(token_ids)
+        token_type_ids[token_ids == image_token_id] = 1
+        batch_token_type_ids.append(token_type_ids.tolist())
+
+    return batch_token_type_ids
+
+
+def _make_batched_images(images: list["ImageObject"], imglens: list[int]) -> list[list["ImageObject"]]:
+    r"""Make nested list of images."""
+    batch_images = []
+    for imglen in imglens:
+        batch_images.append(images[:imglen])
+        images = images[imglen:]
+
+    return batch_images
+

 @dataclass
 class MMPluginMixin:
     image_token: Optional[str]
...
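A toy run of the gemma3-style token type ids introduced above: positions holding the image token id become 1, everything else stays 0 (the token ids below are made up):

import numpy as np

image_token_id = 7  # arbitrary example id
batch_ids = [[1, 7, 7, 5], [2, 3, 7, 4]]
batch_token_type_ids = []
for token_ids in batch_ids:
    token_ids = np.array(token_ids)
    token_type_ids = np.zeros_like(token_ids)
    token_type_ids[token_ids == image_token_id] = 1  # mark image-token positions
    batch_token_type_ids.append(token_type_ids.tolist())

print(batch_token_type_ids)  # [[0, 1, 1, 0], [0, 0, 1, 0]]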
@@ -82,16 +136,17 @@ class MMPluginMixin:

     def _validate_input(
         self,
-        processor: Optional["ProcessorMixin"],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        processor: Optional["MMProcessor"],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
     ) -> None:
-        r"""
-        Validates if this model accepts the input modalities.
-        """
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
-        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
+        r"""Validate if this model accepts the input modalities."""
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        video_processor: BaseImageProcessor = getattr(
+            processor, "video_processor", getattr(processor, "image_processor", None)
+        )
+        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
         if len(images) != 0 and self.image_token is None:
             raise ValueError(
                 "This model does not support image input. Please check whether the correct `template` is used."
...
@@ -113,15 +168,16 @@ class MMPluginMixin:
         if self.image_token is not None and image_processor is None:
             raise ValueError("Image processor was not found, please check and update your processor config.")

+        if self.video_token is not None and video_processor is None:
+            raise ValueError("Video processor was not found, please check and update your processor config.")
+
         if self.audio_token is not None and feature_extractor is None:
             raise ValueError("Audio feature extractor was not found, please check and update your processor config.")

     def _preprocess_image(self, image: "ImageObject", image_max_pixels: int, image_min_pixels: int, **kwargs) -> "ImageObject":
-        r"""
-        Pre-processes a single image.
-        """
+        r"""Pre-process a single image."""
         if (image.width * image.height) > image_max_pixels:
             resize_factor = math.sqrt(image_max_pixels / (image.width * image.height))
             width, height = int(image.width * resize_factor), int(image.height * resize_factor)
...
@@ -139,10 +195,8 @@ class MMPluginMixin:

     def _get_video_sample_indices(
         self, video_stream: "Stream", video_fps: float, video_maxlen: int, **kwargs
-    ) -> List[int]:
-        r"""
-        Computes video sample indices according to fps.
-        """
+    ) -> list[int]:
+        r"""Compute video sample indices according to fps."""
         total_frames = video_stream.frames
         if total_frames == 0:  # infinite video
             return np.linspace(0, video_maxlen - 1, video_maxlen).astype(np.int32)
...
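The fps-based sampling above boils down to evenly spaced frame indices capped by `video_maxlen`; a standalone sketch with example numbers (frame count, duration and caps are illustrative):

import numpy as np

total_frames, duration_s = 240, 10.0
video_fps, video_maxlen = 2.0, 16
sample_frames = min(total_frames, video_maxlen, int(duration_s * video_fps))
indices = np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)
print(len(indices), indices[:4])  # 16 [ 0 15 31 47]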
@@ -151,13 +205,11 @@ class MMPluginMixin:
         sample_frames = min(total_frames, video_maxlen, sample_frames)
         return np.linspace(0, total_frames - 1, sample_frames).astype(np.int32)

-    def _regularize_images(self, images: Sequence["ImageInput"], **kwargs) -> List["ImageObject"]:
-        r"""
-        Regularizes images to avoid error. Including reading and pre-processing.
-        """
+    def _regularize_images(self, images: list["ImageInput"], **kwargs) -> dict[str, list["ImageObject"]]:
+        r"""Regularize images to avoid error. Including reading and pre-processing."""
         results = []
         for image in images:
-            if isinstance(image, str):
+            if isinstance(image, (str, BinaryIO)):
                 image = Image.open(image)
             elif isinstance(image, bytes):
                 image = Image.open(BytesIO(image))
...
@@ -172,53 +224,52 @@ class MMPluginMixin:
             results.append(self._preprocess_image(image, **kwargs))

-        return results
+        return {"images": results}

-    def _regularize_videos(self, videos: Sequence["VideoInput"], **kwargs) -> List[List["ImageObject"]]:
-        r"""
-        Regularizes videos to avoid error. Including reading, resizing and converting.
-        """
+    def _regularize_videos(self, videos: list["VideoInput"], **kwargs) -> dict[str, list[list["ImageObject"]]]:
+        r"""Regularizes videos to avoid error. Including reading, resizing and converting."""
         results = []
         for video in videos:
             container = av.open(video, "r")
             video_stream = next(stream for stream in container.streams if stream.type == "video")
             sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
-            frames: List["ImageObject"] = []
+            frames: list[ImageObject] = []
             container.seek(0)
             for frame_idx, frame in enumerate(container.decode(video_stream)):
                 if frame_idx in sample_indices:
                     frames.append(frame.to_image())

-            frames = self._regularize_images(frames, **kwargs)
+            frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)

-        return results
+        return {"videos": results}

-    def _regularize_audios(self, audios: Sequence["AudioInput"], sampling_rate: float, **kwargs) -> List["NDArray"]:
-        r"""
-        Regularizes audios to avoid error. Including reading and resampling.
-        """
-        results = []
+    def _regularize_audios(
+        self, audios: list["AudioInput"], sampling_rate: float, **kwargs
+    ) -> dict[str, Union[list["NDArray"], list[float]]]:
+        r"""Regularizes audios to avoid error. Including reading and resampling."""
+        results, sampling_rates = [], []
         for audio in audios:
-            if isinstance(audio, str):
-                audio = librosa.load(audio, sr=sampling_rate)[0]
+            if isinstance(audio, (str, BinaryIO)):
+                audio, sampling_rate = librosa.load(audio, sr=sampling_rate)

             if not isinstance(audio, np.ndarray):
                 raise ValueError(f"Expect input is a list of audios, but got {type(audio)}.")

             results.append(audio)
+            sampling_rates.append(sampling_rate)

-        return results
+        return {"audios": results, "sampling_rates": sampling_rates}

     def _get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: "ProcessorMixin",
-    ) -> Dict[str, "torch.Tensor"]:
-        r"""
-        Processes visual inputs.
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
+        imglens: Optional[list[int]] = None,
+    ) -> dict[str, "torch.Tensor"]:
+        r"""Process visual inputs.

         Returns: (llava and paligemma)
             pixel_values: tensor with shape (B, C, H, W)
...
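The regularizers now hand back dicts rather than bare lists so that metadata (for example the per-clip sampling rates) travels with the payload; a generic, self-contained illustration of that design choice (not project code):

def regularize(items: list[str]) -> dict[str, list]:
    payload, lengths = [], []
    for item in items:
        payload.append(item.lower())
        lengths.append(len(item))  # metadata kept in lockstep with the payload
    return {"items": payload, "lengths": lengths}


out = regularize(["Foo", "BarBaz"])
print(out["items"], out["lengths"])  # ['foo', 'barbaz'] [3, 6]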
@@ -226,44 +277,67 @@ class MMPluginMixin:
         Returns: (qwen2-vl)
             pixel_values: tensor with shape (num_patches, patch_dim)
             image_grid_thw: tensor with shape (num_images, 3), where the three numbers are time, width, height
-                            where num_patches == torch.prod(image_grid_thw)
-
-        Returns: (mllama)
-            pixel_values: tensor with shape
-                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
-                          For example, (2, 1, 4, 3, 560, 560).
-            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
-            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
-            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
+                            It holds num_patches == torch.prod(image_grid_thw)
         """
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
-        video_processor: "BaseImageProcessor" = getattr(processor, "video_processor", image_processor)
-        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
         mm_inputs = {}
-
         if len(images) != 0:
+            image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
             images = self._regularize_images(
                 images,
                 image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                 image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
-            )
-            mm_inputs.update(image_processor(images, return_tensors="pt"))
+            )["images"]
+            if imglens is not None:  # if imglens are provided, make batched images
+                images = _make_batched_images(images, imglens)
+
+            image_processor_kwargs = {}
+            if getattr(processor, "image_do_pan_and_scan", False):  # gemma3 image processor
+                image_processor_kwargs.update(
+                    {
+                        "do_pan_and_scan": True,
+                        "pan_and_scan_min_crop_size": 256,
+                        "pan_and_scan_max_num_crops": 4,
+                        "pan_and_scan_min_ratio_to_activate": 1.2,
+                    }
+                )
+
+            mm_inputs.update(image_processor(images, return_tensors="pt", **image_processor_kwargs))

         if len(videos) != 0:
+            video_processor: BaseImageProcessor = getattr(
+                processor, "video_processor", getattr(processor, "image_processor", None)
+            )
             videos = self._regularize_videos(
                 videos,
                 image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                 image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 128),
-            )
+            )["videos"]
             if "videos" in inspect.signature(video_processor.preprocess).parameters:  # for qwen2_vl and video_llava
                 mm_inputs.update(video_processor(images=None, videos=videos, return_tensors="pt"))
             else:  # for llava_next_video
                 mm_inputs.update(video_processor(videos, return_tensors="pt"))

         if len(audios) != 0:
+            feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
             audios = self._regularize_audios(
                 audios,
-                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
-            )
+                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+            )["audios"]
             mm_inputs.update(
                 feature_extractor(
                     audios,
-                    sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
+                    sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
                     return_attention_mask=True,
                     padding="max_length",
                     return_tensors="pt",
...
@@ -278,83 +352,95 @@ class MMPluginMixin:

 class BasePlugin(MMPluginMixin):
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
-        r"""
-        Pre-processes input messages before tokenization for VLMs.
-        """
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        r"""Pre-process input messages before tokenization for VLMs."""
         self._validate_input(processor, images, videos, audios)
         return messages

     def process_token_ids(
         self,
-        input_ids: List[int],
-        labels: Optional[List[int]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        input_ids: list[int],
+        labels: Optional[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
-        processor: Optional["ProcessorMixin"],
-    ) -> Tuple[List[int], Optional[List[int]]]:
-        r"""
-        Pre-processes token ids after tokenization for VLMs.
-        """
+        processor: Optional["MMProcessor"],
+    ) -> tuple[list[int], Optional[list[int]]]:
+        r"""Pre-process token ids after tokenization for VLMs."""
         self._validate_input(processor, images, videos, audios)
         return input_ids, labels

     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        r"""
-        Builds batched multimodal inputs for VLMs.
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        r"""Build batched multimodal inputs for VLMs.

         Arguments:
             images: a list of image inputs, shape (num_images,)
             videos: a list of video inputs, shape (num_videos,)
             audios: a list of audio inputs, shape (num_audios,)
             imglens: number of images in each sample, shape (batch_size,)
             vidlens: number of videos in each sample, shape (batch_size,)
             audlens: number of audios in each sample, shape (batch_size,)
             batch_ids: token ids of input samples, shape (batch_size, seq_len)
             processor: a processor for pre-processing images and videos
         """
         self._validate_input(processor, images, videos, audios)
-        return {}
+        return self._get_mm_inputs(images, videos, audios, processor)


 @dataclass
-class LlavaPlugin(BasePlugin):
+class Gemma3Plugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
-        image_seqlen = getattr(processor, "image_seqlen") if self.expand_mm_tokens else 1
         messages = deepcopy(messages)
+        boi_token: str = getattr(processor, "boi_token")
+        full_image_sequence: str = getattr(processor, "full_image_sequence")
+        image_str = full_image_sequence if self.expand_mm_tokens else boi_token
+
+        do_pan_and_scan: bool = getattr(processor, "image_do_pan_and_scan", False)
+        if do_pan_and_scan:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                if do_pan_and_scan:
+                    image_placeholder_str = (
+                        "Here is the original image {{image}} and here are some crops to help you see better "
+                        + " ".join(["{{image}}"] * mm_inputs["num_crops"][0][num_image_tokens])
+                    )
+                else:
+                    image_placeholder_str = "{{image}}"
+
+                content = content.replace(IMAGE_PLACEHOLDER, image_placeholder_str, 1)
                 num_image_tokens += 1

-            message["content"] = content.replace("{{image}}", self.image_token)
+            message["content"] = content.replace("{{image}}", image_str)

         if len(images) != num_image_tokens:
             raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
...
@@ -364,17 +450,127 @@ class LlavaPlugin(BasePlugin):
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
-        return self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs.pop("num_crops", None)
+        mm_inputs["token_type_ids"] = _get_gemma3_token_type_ids(batch_ids, processor)
+        return mm_inputs
+
+
+@dataclass
+class Llama4Plugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            if "pixel_values" in mm_inputs:
+                image_height, image_width = mm_inputs["pixel_values"][0].shape[-2:]
+                num_patches_per_chunk = int(
+                    (image_height // processor.patch_size)
+                    * (image_width // processor.patch_size)
+                    // processor.downsample_ratio
+                )
+                aspect_ratios = mm_inputs.pop("aspect_ratios")
+
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            placeholder_count = content.count(IMAGE_PLACEHOLDER)
+            if self.expand_mm_tokens:
+                prompt_splits = content.split(IMAGE_PLACEHOLDER)
+                new_content = []
+                for local_image_index, split_part in enumerate(prompt_splits):
+                    new_content.append(split_part)
+                    if local_image_index < placeholder_count:
+                        tokens_for_this_image = processor._prompt_split_image(
+                            aspect_ratios[num_image_tokens], num_patches_per_chunk
+                        )
+                        num_image_tokens += 1
+                        new_content.append(tokens_for_this_image)
+
+                content = "".join(new_content)
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
+        self._validate_input(processor, images, videos, audios)
+        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        mm_inputs.pop("aspect_ratios", None)
+        return mm_inputs
+
+
+@dataclass
+class LlavaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
+        self._validate_input(processor, images, videos, audios)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        if self.expand_mm_tokens:
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            if "pixel_values" in mm_inputs:
+                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0]))
+                image_seqlen = (
+                    (height // processor.patch_size) * (width // processor.patch_size)
+                    + processor.num_additional_image_tokens
+                )
+                if processor.vision_feature_select_strategy == "default":
+                    image_seqlen -= 1
+        else:
+            image_seqlen = 1
+
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)
+                num_image_tokens += 1
+
+            message["content"] = content.replace("{{image}}", self.image_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+

 @dataclass
...
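The LLaVA-style `image_seqlen` above is plain patch arithmetic; a worked example with assumed values (336x336 input, patch size 14, one extra token that the "default" strategy drops again):

height, width, patch_size = 336, 336, 14
num_additional_image_tokens = 1
vision_feature_select_strategy = "default"

image_seqlen = (height // patch_size) * (width // patch_size) + num_additional_image_tokens
if vision_feature_select_strategy == "default":
    image_seqlen -= 1

print(image_seqlen)  # 576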
@@ -382,15 +578,16 @@ class LlavaNextPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
             if "pixel_values" in mm_inputs:
                 image_sizes = iter(mm_inputs["image_sizes"].tolist())
...
@@ -402,7 +599,7 @@ class LlavaNextPlugin(BasePlugin):
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                    if getattr(processor, "vision_feature_select_strategy", "default") == "default":
+                    if processor.vision_feature_select_strategy == "default":
                         image_seqlen -= 1
                 else:
                     image_seqlen = 1
...
@@ -417,47 +614,34 @@ class LlavaNextPlugin(BasePlugin):

         return messages

-    @override
-    def get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(processor, images, videos, audios)
-        return self._get_mm_inputs(images, videos, audios, processor)
-

 @dataclass
 class LlavaNextVideoPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
             if "pixel_values" in mm_inputs:
                 image_sizes = iter(mm_inputs["image_sizes"].tolist())
                 height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))

         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if self.expand_mm_tokens:
                     orig_height, orig_width = next(image_sizes)
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                    if getattr(processor, "vision_feature_select_strategy", "default") == "default":
+                    if processor.vision_feature_select_strategy == "default":
                         image_seqlen -= 1
                 else:
                     image_seqlen = 1
...
@@ -467,11 +651,11 @@ class LlavaNextVideoPlugin(BasePlugin):

             message["content"] = content.replace("{{image}}", self.image_token)

-        if "pixel_values_videos" in mm_inputs:
-            if self.expand_mm_tokens:
-                pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                height, width = get_image_size(pixel_values_video[0])
-                num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
+            if "pixel_values_videos" in mm_inputs:
+                one_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
+                height, width = get_image_size(one_video[0])
+                num_frames = one_video.shape[0]  # frame dim is always after batch dim
                 image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
                 video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
         else:
...
@@ -480,8 +664,8 @@ class LlavaNextVideoPlugin(BasePlugin):
         for message in messages:
             content = message["content"]
             while VIDEO_PLACEHOLDER in content:
-                num_video_tokens += 1
                 content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)
+                num_video_tokens += 1

             message["content"] = content.replace("{{video}}", self.video_token)
...
@@ -493,37 +677,22 @@ class LlavaNextVideoPlugin(BasePlugin):

         return messages

-    @override
-    def get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(processor, images, videos, audios)
-        return self._get_mm_inputs(images, videos, audios, processor)
-

 @dataclass
 class MiniCPMVPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens, num_audio_tokens = 0, 0, 0
         messages = deepcopy(messages)
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
         mm_inputs = {}
         audio_inputs = {}
         if len(images) != 0 and len(videos) != 0:
...
@@ -614,21 +783,20 @@ class MiniCPMVPlugin(BasePlugin):
     @override
     def _get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: "ProcessorMixin",
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
         **kwargs,
-    ) -> Dict[str, "torch.Tensor"]:
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        feature_extractor: "SequenceFeatureExtractor" = getattr(processor, "feature_extractor", None)
+    ) -> dict[str, "torch.Tensor"]:
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
         mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
                 images,
                 image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                 image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
-            )
+            )["images"]
             if "valid_image_nums_ls" in kwargs:
                 valid_image_nums_ls = kwargs["valid_image_nums_ls"]
                 new_images = []
...
@@ -651,15 +819,15 @@ class MiniCPMVPlugin(BasePlugin):
                 image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 128),
-            )
+            )["videos"]
             video_inputs = image_processor(videos, do_pad=True, max_slice_nums=2, return_tensors="pt")
             mm_inputs.update(video_inputs)

         if len(audios) != 0:
             audios = self._regularize_audios(
                 audios,
-                sampling_rate=getattr(feature_extractor, "sampling_rate", 16000),
-            )
+                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+            )["audios"]
             if "valid_audio_nums_ls" in kwargs:
                 valid_audio_nums_ls = kwargs["valid_audio_nums_ls"]
                 audios_ls = []
...
@@ -673,7 +841,7 @@ class MiniCPMVPlugin(BasePlugin):
             audio_features, audio_feature_lens, audio_phs = processor.audio_feature_extract(
                 audios_ls,
                 chunk_input=True,
-                sampling_rate=16000,
+                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
             )
             audio_feature_lens = [torch.tensor(audio_feature_len) for audio_feature_len in audio_feature_lens]
             mm_inputs.update({"audio_features": audio_features, "audio_feature_lens": audio_feature_lens})
...
@@ -685,15 +853,15 @@ class MiniCPMVPlugin(BasePlugin):
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
         # image bound
         image_bounds_list = []
...
@@ -756,12 +924,12 @@ class MllamaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
...
@@ -775,61 +943,24 @@ class MllamaPlugin(BasePlugin):

         return messages

-    @override
-    def _get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: "ProcessorMixin",
-        imglens: List[int],
-    ) -> Dict[str, "torch.Tensor"]:
-        r"""
-        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
-
-        Returns:
-            pixel_values: tensor with shape
-                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
-                          For example, (2, 1, 4, 3, 560, 560).
-            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
-            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
-            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
-        """
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-        mm_inputs = {}
-        if len(images) > 0:
-            images = self._regularize_images(
-                images,
-                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
-                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
-            )
-            batch_images = []
-            for image_length in imglens:
-                batch_images.append(images[:image_length])
-                images = images[image_length:]
-
-            mm_inputs.update(image_processor(batch_images, return_tensors="pt"))
-
-        return mm_inputs
-
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor, imglens)
         if mm_inputs:
             num_tiles = mm_inputs.pop("num_tiles")
-            image_token_id = getattr(processor, "image_token_id")
-            max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+            image_token_id: int = getattr(processor, "image_token_id")
+            max_image_tiles: int = getattr(processor.image_processor, "max_image_tiles")
             cross_attention_token_mask = [
                 get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
             ]
...
@@ -850,22 +981,22 @@ class PaliGemmaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         num_image_tokens = 0
         messages = deepcopy(messages)
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
-                content = content.replace(IMAGE_PLACEHOLDER, "{{image}}", 1)
+                content = content.replace(IMAGE_PLACEHOLDER, "", 1)
                 num_image_tokens += 1

-            message["content"] = content.replace("{{image}}", "")
+            message["content"] = content

         if len(images) != num_image_tokens:
             raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
...
@@ -875,36 +1006,36 @@ class PaliGemmaPlugin(BasePlugin):
     @override
     def process_token_ids(
         self,
-        input_ids: List[int],
-        labels: Optional[List[int]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
+        input_ids: list[int],
+        labels: Optional[list[int]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
-        processor: Optional["ProcessorMixin"],
-    ) -> Tuple[List[int], Optional[List[int]]]:
+        processor: Optional["MMProcessor"],
+    ) -> tuple[list[int], Optional[list[int]]]:
         self._validate_input(processor, images, videos, audios)
         num_images = len(images)
-        image_seqlen = num_images * getattr(processor, "image_seqlen") if self.expand_mm_tokens else 0  # skip mm token
+        image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0  # skip mm token
         image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
-        input_ids = [image_token_id] * image_seqlen + input_ids
+        input_ids = [image_token_id] * num_images * image_seqlen + input_ids
         if labels is not None:
-            labels = [IGNORE_INDEX] * image_seqlen + labels
+            labels = [IGNORE_INDEX] * num_images * image_seqlen + labels

         return input_ids, labels

     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
         seqlens = [len(input_ids) for input_ids in batch_ids]
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
...
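Worked example of the PaliGemma prefix expansion above: each image contributes `image_seq_length` copies of the image token at the front of `input_ids`, and the same span is masked out of the labels. All numbers below are illustrative:

IGNORE_INDEX = -100
image_token_id, image_seq_length, num_images = 99, 4, 2  # example values
input_ids = [10, 11, 12]
labels = [10, 11, 12]

input_ids = [image_token_id] * num_images * image_seq_length + input_ids
labels = [IGNORE_INDEX] * num_images * image_seq_length + labels
print(len(input_ids), labels[:3])  # 11 [-100, -100, -100]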
@@ -917,37 +1048,39 @@ class PixtralPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
-        patch_size = getattr(processor, "patch_size")
-        image_token = getattr(processor, "image_token")
-        image_break_token = getattr(processor, "image_break_token")
-        image_end_token = getattr(processor, "image_end_token")
-
         num_image_tokens = 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
             if "pixel_values" in mm_inputs:
                 # BC for transformers < 4.49.0
                 if isinstance(mm_inputs["image_sizes"], list):
                     image_sizes = iter(mm_inputs["image_sizes"][0])
                 else:
                     image_sizes = iter(mm_inputs["image_sizes"].tolist())

+        image_break_token: str = getattr(processor, "image_break_token")
+        image_end_token: str = getattr(processor, "image_end_token")
+
         for message in messages:
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if self.expand_mm_tokens:
                     height, width = next(image_sizes)
-                    num_height_tokens = height // patch_size
-                    num_width_tokens = width // patch_size
-                    replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
+                    num_height_tokens = height // processor.patch_size
+                    num_width_tokens = width // processor.patch_size
+                    replace_tokens = [[self.image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
                     replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
                     replace_tokens[-1] = image_end_token
                     replace_str = "".join(replace_tokens)
                 else:
-                    replace_str = image_token
+                    replace_str = self.image_token

                 content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
                 num_image_tokens += 1
...
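How the Pixtral replacement string above is assembled, as a small standalone example: each row of patches becomes `num_width_tokens` image tokens plus a break token, and the final token of the grid is swapped for the end token (the token strings and sizes here are illustrative):

height, width, patch_size = 32, 48, 16
image_token, image_break_token, image_end_token = "[IMG]", "[IMG_BREAK]", "[IMG_END]"

num_height_tokens, num_width_tokens = height // patch_size, width // patch_size
replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
replace_tokens[-1] = image_end_token
print("".join(replace_tokens))
# [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]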
@@ -962,17 +1095,21 @@ class PixtralPlugin(BasePlugin):
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+        # ref to this commit https://github.com/huggingface/transformers/pull/35122
+        # after transformers 4.49.0, the `image_sizes` is mandatory as an input parameter for Pixtral VisionEncoder forwarding.
+        # it can be passed into `LlavaConditionalGeneration` as a parameter.
+        if not is_transformers_version_greater_than("4.49.0"):
+            mm_inputs.pop("image_sizes", None)
+
         return mm_inputs
...
@@ -982,21 +1119,22 @@ class Qwen2AudioPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         bos_token: str = getattr(processor, "audio_bos_token")
         eos_token: str = getattr(processor, "audio_eos_token")
-        num_audio_tokens = 0
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs([], [], audios, processor)
             if "feature_attention_mask" in mm_inputs:
                 audio_lengths = mm_inputs["feature_attention_mask"].sum(-1).tolist()

+        num_audio_tokens = 0
         for message in messages:
             content = message["content"]
             while AUDIO_PLACEHOLDER in content:
...
@@ -1022,15 +1160,15 @@ class Qwen2AudioPlugin(BasePlugin):
     @override
     def get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        imglens: list[int],
+        vidlens: list[int],
+        audlens: list[int],
+        batch_ids: list[list[int]],
+        processor: Optional["MMProcessor"],
+    ) -> dict[str, Union[list[int], "torch.Tensor"]]:
         self._validate_input(processor, images, videos, audios)
         return self._get_mm_inputs(images, videos, audios, processor)
...
@@ -1056,14 +1194,14 @@ class Qwen2VLPlugin(BasePlugin):
     @override
     def _regularize_videos(
-        self, videos: Sequence["VideoInput"], **kwargs
-    ) -> Tuple[List[List["ImageObject"]], List[float]]:
+        self, videos: list["VideoInput"], **kwargs
+    ) -> dict[str, Union[list[list["ImageObject"]], list[float]]]:
         results, fps_per_video = [], []
         for video in videos:
             container = av.open(video, "r")
             video_stream = next(stream for stream in container.streams if stream.type == "video")
             sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
-            frames: List["ImageObject"] = []
+            frames: list[ImageObject] = []
             container.seek(0)
             for frame_idx, frame in enumerate(container.decode(video_stream)):
                 if frame_idx in sample_indices:
...
@@ -1072,59 +1210,61 @@ class Qwen2VLPlugin(BasePlugin):
             if len(frames) % 2 != 0:  # qwen2-vl requires even number of frames
                 frames.append(frames[-1])

-            frames = self._regularize_images(frames, **kwargs)
+            frames = self._regularize_images(frames, **kwargs)["images"]
             results.append(frames)
             if video_stream.duration is None:
                 fps_per_video.append(2.0)
             else:
                 fps_per_video.append(len(sample_indices) / float(video_stream.duration * video_stream.time_base))

-        return results, fps_per_video
+        return {"videos": results, "fps_per_video": fps_per_video}

     @override
     def _get_mm_inputs(
         self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: "ProcessorMixin",
-    ) -> Dict[str, "torch.Tensor"]:
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor", None)
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
+    ) -> dict[str, "torch.Tensor"]:
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
         mm_inputs = {}
         if len(images) != 0:
             images = self._regularize_images(
                 images,
                 image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
                 image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
-            )
+            )["images"]
             mm_inputs.update(image_processor(images, return_tensors="pt"))

         if len(videos) != 0:
-            videos, fps_per_video = self._regularize_videos(
+            video_data = self._regularize_videos(
                 videos,
                 image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
                 image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
                 video_fps=getattr(processor, "video_fps", 2.0),
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
-            mm_inputs.update(image_processor(images=None, videos=videos, return_tensors="pt"))
-            mm_inputs["fps_per_video"] = fps_per_video
+            mm_inputs.update(image_processor(images=None, videos=video_data["videos"], return_tensors="pt"))
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            if "second_per_grid_ts" in processor.model_input_names:
+                mm_inputs["second_per_grid_ts"] = [temporal_patch_size / fps for fps in video_data["fps_per_video"]]

         return mm_inputs

     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
-        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor")
         merge_length: int = getattr(image_processor, "merge_size") ** 2
         if self.expand_mm_tokens:
...
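The `second_per_grid_ts` values added above are simply `temporal_patch_size / fps` per sampled video; a quick numeric check with made-up frame rates:

temporal_patch_size = 2
fps_per_video = [2.0, 1.6]  # example values returned by _regularize_videos
second_per_grid_ts = [temporal_patch_size / fps for fps in fps_per_video]
print(second_per_grid_ts)  # [1.0, 1.25]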
@@ -1167,26 +1307,188 @@ class Qwen2VLPlugin(BasePlugin):
         return messages


 class Qwen2OmniPlugin(Qwen2VLPlugin):
     @override
-    def get_mm_inputs(
+    def _get_mm_inputs(
         self,
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: "MMProcessor",
+    ) -> dict[str, "torch.Tensor"]:
+        image_processor: BaseImageProcessor = getattr(processor, "image_processor", None)
+        feature_extractor: SequenceFeatureExtractor = getattr(processor, "feature_extractor", None)
+        mm_inputs = {}
+        if len(images) != 0:
+            images = self._regularize_images(
+                images,
+                image_max_pixels=getattr(processor, "image_max_pixels", 768 * 768),
+                image_min_pixels=getattr(processor, "image_min_pixels", 32 * 32),
+            )["images"]
+            mm_inputs.update(image_processor(images, return_tensors="pt"))
+
+        if len(videos) != 0:
+            video_dict = self._regularize_videos(
+                videos,
+                image_max_pixels=getattr(processor, "video_max_pixels", 256 * 256),
+                image_min_pixels=getattr(processor, "video_min_pixels", 16 * 16),
+                video_fps=getattr(processor, "video_fps", 2.0),
+                video_maxlen=getattr(processor, "video_maxlen", 128),
+            )
+            mm_inputs.update(image_processor(images=None, videos=video_dict["videos"], return_tensors="pt"))
+            temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2)
+            mm_inputs["video_second_per_grid"] = torch.tensor(
+                [temporal_patch_size / fps for fps in video_dict["fps_per_video"]]
+            )
+
+        if len(audios) != 0:
+            audios = self._regularize_audios(
+                audios,
+                sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+            )["audios"]
+            mm_inputs.update(
+                feature_extractor(
+                    audios,
+                    sampling_rate=getattr(processor, "audio_sampling_rate", 16000),
+                    return_attention_mask=True,
+                    padding="max_length",
+                    return_tensors="pt",
+                )
+            )
+            mm_inputs["feature_attention_mask"] = mm_inputs.pop("attention_mask")  # prevent conflicts
+
+        return mm_inputs
+
+    @override
+    def process_messages(
+        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         messages = deepcopy(messages)
         if self.expand_mm_tokens:
             mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
-            fps_per_video = mm_inputs.pop("fps_per_video", [])
-            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-            if "second_per_grid_ts" in processor.model_input_names and fps_per_video:
-                mm_inputs["second_per_grid_ts"] = [image_processor.temporal_patch_size / fps for fps in fps_per_video]
-        else:
-            mm_inputs = {}
-
-        return mm_inputs
+
+        num_audio_tokens, num_image_tokens, num_video_tokens = 0, 0, 0
+        use_audio_in_video = getattr(processor, "use_audio_in_video", False)
+
+        # get length or size from mm_inputs
+        if "feature_attention_mask" in mm_inputs:
+            input_lengths = (mm_inputs["feature_attention_mask"].sum(-1).numpy() - 1) // 2 + 1
+            audio_lengths = (input_lengths - 2) // 2 + 1
+
+        if mm_inputs.get("image_grid_thw", None) is not None:
+            image_grid_thw = mm_inputs["image_grid_thw"]
+            merge_length = processor.image_processor.merge_size**2
+
+        if mm_inputs.get("video_grid_thw", None) is not None:
+            video_grid_thw = mm_inputs["video_grid_thw"]
+            merge_length = processor.image_processor.merge_size**2
+
+        if use_audio_in_video:
+            if audio_lengths is None:
+                raise ValueError("audio_lengths should exist when use_audio_in_video is `True`.")
+
+            if not mm_inputs.get("video_grid_thw", None):
+                raise ValueError("video_grid_thw should exist when use_audio_in_video is `True`.")
+
+            positions_list = []
+            for i, message in enumerate(messages):  # get multimodal index when use_audio
+                positions = []
+                for special_token in [self.audio_token, self.image_token, self.video_token]:
+                    start = 0
+                    while True:
+                        pos = message[i].find(special_token, start)
+                        if pos == -1:
+                            break
+
+                        positions.append((pos, special_token))
+                        start = pos + len(special_token)
+
+                positions_list.append(positions.sort(key=lambda x: x[0]))
+
+        for message in messages:
+            content = message["content"]
+            # separate with audio-video
+            while IMAGE_PLACEHOLDER in content:
+                image_token_replace_length = image_grid_thw[num_image_tokens].prod() // merge_length
+                content = content.replace(
+                    IMAGE_PLACEHOLDER,
+                    f"<|vision_bos|>{self.image_token * image_token_replace_length}<|vision_eos|>",
+                    1,
+                )
+                num_image_tokens += 1
+
+            if not use_audio_in_video:
+                while AUDIO_PLACEHOLDER in content:
+                    audio_token_replace_length = audio_lengths[num_audio_tokens]
+                    content = content.replace(
+                        AUDIO_PLACEHOLDER,
+                        f"<|audio_bos|>{self.audio_token * audio_token_replace_length}<|audio_eos|>",
+                        1,
+                    )
+                    num_audio_tokens += 1
+
+                # TODO handle video_input and use_audio_in_video
+                while VIDEO_PLACEHOLDER in content:
+                    video_replace_length = video_grid_thw[num_video_tokens].prod() // merge_length
+                    content = content.replace(
+                        VIDEO_PLACEHOLDER, f"<|vision_bos|>{self.video_token * video_replace_length}<|vision_eos|>", 1
+                    )
+                    num_video_tokens += 1
+            else:  # if use the audio of video # deal video token and audio token togather
+                while VIDEO_PLACEHOLDER in content:
+                    audio_t_index = torch.arange(audio_lengths[num_audio_tokens])
+                    video_t_index = (
+                        torch.arange(video_grid_thw[num_video_tokens][0])
+                        .view(-1, 1, 1)
+                        .expand(
+                            -1,
+                            video_grid_thw[num_video_tokens][1] // self.image_processor.merge_size,
+                            video_grid_thw[num_video_tokens][2] // self.image_processor.merge_size,
+                        )
+                        .flatten()
+                        * mm_inputs["video_second_per_grid"][num_video_tokens]
+                        * 25  # FIXME hardcode of position_id_per_seconds=25
+                    ).long()
+                    t_ntoken_per_chunk = 50  # FIXME hardcode: [25 * 2]
+                    video_chunk_indices = processor.get_chunked_index(video_t_index, t_ntoken_per_chunk)
+                    audio_chunk_indices = self.get_chunked_index(audio_t_index, t_ntoken_per_chunk)
+                    placeholder_string = ""
+                    for j in range(max(len(video_chunk_indices), len(audio_chunk_indices))):
+                        video_chunk_index = video_chunk_indices[j] if j < len(video_chunk_indices) else None
+                        audio_chunk_index = audio_chunk_indices[j] if j < len(audio_chunk_indices) else None
+                        placeholder_string = "<|vision_bos|>" + "<|audio_bos|>"
+                        if video_chunk_index is not None:
+                            placeholder_string += self.video_token * (video_chunk_index[1] - video_chunk_index[0])
+
+                        if audio_chunk_index is not None:
+                            placeholder_string += self.audio_token * (audio_chunk_index[1] - audio_chunk_index[0])
+
+                        placeholder_string += "<|audio_eos|>" + "<|vision_eos|>"
+
+                    content = content.replace(VIDEO_PLACEHOLDER, placeholder_string, 1)
+                    content = content.replace(AUDIO_PLACEHOLDER, "", 1)
+                    num_audio_tokens += 1
+                    num_video_tokens += 1
+
+            message["content"] = content
+
+        if len(audios) != num_audio_tokens:
+            raise ValueError(f"The number of audios does not match the number of {AUDIO_PLACEHOLDER} tokens.")
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        if len(videos) != num_video_tokens:
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")
+
+        return messages


 @dataclass
...
@@ -1194,33 +1496,33 @@ class VideoLlavaPlugin(BasePlugin):
     @override
     def process_messages(
         self,
-        messages: Sequence[Dict[str, str]],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        processor: Optional["ProcessorMixin"],
-    ) -> List[Dict[str, str]]:
+        messages: list[dict[str, str]],
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+        processor: Optional["MMProcessor"],
+    ) -> list[dict[str, str]]:
         self._validate_input(processor, images, videos, audios)
         num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
-        mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         num_frames = 0
-        has_images = "pixel_values_images" in mm_inputs
-        has_videos = "pixel_values_videos" in mm_inputs
-        if has_images or has_videos:
         if self.expand_mm_tokens:
-                if has_images:
-                    height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
+            mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
+            if "pixel_values_images" in mm_inputs:
+                height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values_images"][0]))
                 num_frames = 1
-                if has_videos:
-                    pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
-                    height, width = get_image_size(pixel_values_video[0])
-                    num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
-                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
+            if "pixel_values_videos" in mm_inputs:
+                one_video = to_numpy_array(mm_inputs["pixel_values_videos"][0])
+                height, width = get_image_size(one_video[0])
+                num_frames = one_video.shape[0]  # frame dim is always after batch dim
+            if "pixel_values_images" in mm_inputs or "pixel_values_videos" in mm_inputs:
+                image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + processor.num_additional_image_tokens
                 video_seqlen = image_seqlen * num_frames
-                if getattr(processor, "vision_feature_select_strategy", "default") == "default":
+                if processor.vision_feature_select_strategy == "default":
                     image_seqlen -= 1
         else:
             image_seqlen, video_seqlen = 1, 1
...
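For a concrete sense of the image_seqlen arithmetic in the new branch above, here is a toy calculation; the numbers are illustrative only, not values the plugin hard-codes:

# Hypothetical numbers: a 224x224 frame, ViT patch size 14,
# one additional (CLS-style) image token, 8 sampled video frames.
height = width = 224
patch_size = 14
num_additional_image_tokens = 1
num_frames = 8

image_seqlen = (height // patch_size) * (width // patch_size) + num_additional_image_tokens  # 16 * 16 + 1 = 257
video_seqlen = image_seqlen * num_frames  # computed before the subtraction below: 257 * 8 = 2056
image_seqlen -= 1  # the "default" vision_feature_select_strategy drops one token -> 256
print(image_seqlen, video_seqlen)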
@@ -1246,24 +1548,11 @@ class VideoLlavaPlugin(BasePlugin):
         return messages

-    @override
-    def get_mm_inputs(
-        self,
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-        imglens: Sequence[int],
-        vidlens: Sequence[int],
-        audlens: Sequence[int],
-        batch_ids: Sequence[List[int]],
-        processor: Optional["ProcessorMixin"],
-    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
-        self._validate_input(processor, images, videos, audios)
-        return self._get_mm_inputs(images, videos, audios, processor)

 PLUGINS = {
     "base": BasePlugin,
     "gemma3": Gemma3Plugin,
     "llama4": Llama4Plugin,
     "llava": LlavaPlugin,
     "llava_next": LlavaNextPlugin,
     "llava_next_video": LlavaNextVideoPlugin,
...
@@ -1272,15 +1561,14 @@ PLUGINS = {
     "paligemma": PaliGemmaPlugin,
     "pixtral": PixtralPlugin,
     "qwen2_audio": Qwen2AudioPlugin,
     "qwen2_omni": Qwen2OmniPlugin,
     "qwen2_vl": Qwen2VLPlugin,
     "video_llava": VideoLlavaPlugin,
 }


-def register_mm_plugin(name: str, plugin_class: Type["BasePlugin"]) -> None:
-    r"""
-    Registers a multimodal plugin.
-    """
+def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:
+    r"""Register a multimodal plugin."""
     if name in PLUGINS:
         raise ValueError(f"Multimodal plugin {name} already exists.")
...
@@ -1293,9 +1581,7 @@ def get_mm_plugin(
     video_token: Optional[str] = None,
     audio_token: Optional[str] = None,
 ) -> "BasePlugin":
-    r"""
-    Gets plugin for multimodal inputs.
-    """
+    r"""Get plugin for multimodal inputs."""
     if name not in PLUGINS:
         raise ValueError(f"Multimodal plugin `{name}` not found.")
...
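register_mm_plugin and get_mm_plugin are the extension point for adding a model family to the table above. A hypothetical registration is sketched below; MyVLMPlugin and the "my_vlm" key are made-up names, and the name/image_token parameters of get_mm_plugin are assumed from the collapsed part of its signature:

from llamafactory.data.mm_plugin import BasePlugin, get_mm_plugin, register_mm_plugin

class MyVLMPlugin(BasePlugin):
    # Override process_messages / get_mm_inputs here as needed; the base class
    # provides the default (text-only) behavior.
    pass

register_mm_plugin("my_vlm", MyVLMPlugin)  # raises ValueError if the name is already taken
plugin = get_mm_plugin(name="my_vlm", image_token="<image>")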
src/llamafactory/data/parser.py (View file @ 7ea81099)
...
@@ -15,9 +15,9 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Dict, List, Literal, Optional, Sequence
+from typing import Any, Literal, Optional

-from transformers.utils import cached_file
+from huggingface_hub import hf_hub_download

 from ..extras.constants import DATA_CONFIG
 from ..extras.misc import use_modelscope, use_openmind
...
@@ -25,9 +25,7 @@ from ..extras.misc import use_modelscope, use_openmind
 @dataclass
 class DatasetAttr:
-    r"""
-    Dataset attributes.
-    """
+    r"""Dataset attributes."""

     # basic configs
     load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
...
@@ -68,10 +66,10 @@ class DatasetAttr:
     def __repr__(self) -> str:
         return self.dataset_name

-    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
+    def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
         setattr(self, key, obj.get(key, default))

-    def join(self, attr: Dict[str, Any]) -> None:
+    def join(self, attr: dict[str, Any]) -> None:
         self.set_attr("formatting", attr, default="alpaca")
         self.set_attr("ranking", attr, default=False)
         self.set_attr("subset", attr)
...
@@ -92,10 +90,8 @@ class DatasetAttr:
             self.set_attr(tag, attr["tags"])

-def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
-    r"""
-    Gets the attributes of the datasets.
-    """
+def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
+    r"""Get the attributes of the datasets."""
     if dataset_names is None:
         dataset_names = []
...
@@ -103,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
         dataset_info = None
     else:
         if dataset_dir.startswith("REMOTE:"):
-            config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
+            config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
         else:
             config_path = os.path.join(dataset_dir, DATA_CONFIG)
...
@@ -116,7 +112,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
             dataset_info = None

-    dataset_list: List["DatasetAttr"] = []
+    dataset_list: list["DatasetAttr"] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
             if use_modelscope():
...
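The switch from transformers' cached_file to huggingface_hub.hf_hub_download only changes how the remote data config is fetched when dataset_dir carries the REMOTE: prefix. A small sketch of the new code path; the repo id is made up, and dataset_info.json is assumed to be the value of DATA_CONFIG:

from huggingface_hub import hf_hub_download

dataset_dir = "REMOTE:org/my-datasets"  # hypothetical dataset repo id
config_path = hf_hub_download(
    repo_id=dataset_dir[7:],       # strip the 7-character "REMOTE:" prefix
    filename="dataset_info.json",  # assumed value of DATA_CONFIG
    repo_type="dataset",
)
print(config_path)  # local path of the cached dataset_info.json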
src/llamafactory/data/processor/__init__.py (View file @ 7ea81099)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .feedback import FeedbackDatasetProcessor
from .pairwise import PairwiseDatasetProcessor
from .pretrain import PretrainDatasetProcessor
...
@@ -9,9 +23,9 @@ from .unsupervised import UnsupervisedDatasetProcessor
 __all__ = [
     "DatasetProcessor",
     "FeedbackDatasetProcessor",
+    "PackedSupervisedDatasetProcessor",
     "PairwiseDatasetProcessor",
     "PretrainDatasetProcessor",
-    "PackedSupervisedDatasetProcessor",
     "SupervisedDatasetProcessor",
     "UnsupervisedDatasetProcessor",
 ]
src/llamafactory/data/processor/feedback.py (View file @ 7ea81099)
...
@@ -13,7 +13,7 @@
 # limitations under the License.

 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
...
@@ -30,15 +30,15 @@ logger = logging.get_logger(__name__)
 class FeedbackDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[Dict[str, str]],
-        response: Sequence[Dict[str, str]],
-        kl_response: Sequence[Dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
+        kl_response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-    ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ) -> tuple[list[int], list[int], list[int], list[int], bool]:
         if response[0]["content"]:  # desired example
             kto_tag = True
             messages = prompt + [response[0]]
...
@@ -82,9 +82,9 @@ class FeedbackDatasetProcessor(DatasetProcessor):
         kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
         return input_ids, labels, kl_input_ids, kl_labels, kto_tag

-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
-        # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
-        kl_response = examples["_response"][::-1]
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+        # Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
+        kl_response = [examples["_response"][-1]] + examples["_response"][:-1]
         model_inputs = defaultdict(list)
         for i in range(len(examples["_prompt"])):
             if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
...
@@ -121,7 +121,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
         return model_inputs

-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
         valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
...
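The reworded comment above also reflects a behaviour change: the KL responses are no longer the reversed response list but the list rotated by one position, so every prompt is paired with a neighbouring example's completion. A toy illustration of the difference:

responses = ["r0", "r1", "r2", "r3"]

old_kl = responses[::-1]                   # ['r3', 'r2', 'r1', 'r0']; an odd-length list keeps its middle pair matched
new_kl = [responses[-1]] + responses[:-1]  # ['r3', 'r0', 'r1', 'r2']; every index i now gets response i - 1
print(old_kl, new_kl)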
src/llamafactory/data/processor/pairwise.py (View file @ 7ea81099)
...
@@ -13,7 +13,7 @@
 # limitations under the License.

 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
...
@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
 class PairwiseDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[Dict[str, str]],
-        response: Sequence[Dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-    ) -> Tuple[List[int], List[int], List[int], List[int]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ) -> tuple[list[int], list[int], list[int], list[int]]:
         chosen_messages = self.template.mm_plugin.process_messages(prompt + [response[0]], images, videos, audios, self.processor)
...
@@ -68,7 +68,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
         rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
         return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels

-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
         # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
         model_inputs = defaultdict(list)
         for i in range(len(examples["_prompt"])):
...
@@ -99,7 +99,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
         return model_inputs

-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
         valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
         valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
         print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
...
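The encoding above keeps the loss on the responses only: the shared prompt prefix is masked with IGNORE_INDEX in both the chosen and the rejected labels. A toy illustration with made-up token ids:

IGNORE_INDEX = -100

prompt_ids = [1, 101, 102]   # <bos> X
chosen_ids = [201, 202, 2]   # Y1 <eos>
rejected_ids = [301, 2]      # Y2 <eos>

source_len = len(prompt_ids)
chosen_input_ids = prompt_ids + chosen_ids
chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
rejected_input_ids = prompt_ids + rejected_ids
rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids

print(chosen_labels)    # [-100, -100, -100, 201, 202, 2]
print(rejected_labels)  # [-100, -100, -100, 301, 2]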
src/llamafactory/data/processor/pretrain.py (View file @ 7ea81099)
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...
@@ -17,14 +17,14 @@
 from dataclasses import dataclass
 from itertools import chain
-from typing import Any, Dict, List
+from typing import Any

 from .processor_utils import DatasetProcessor


 @dataclass
 class PretrainDatasetProcessor(DatasetProcessor):
-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
         # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
         eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
         text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
...
@@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor):
         return result

-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
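The "grouped texts" comment refers to the usual run_clm-style packing: an EOS token is appended to every document, everything is concatenated, and the stream is cut into fixed-size blocks. A rough sketch with made-up token ids; the block size here is illustrative, and in the real processor it is presumably derived from cutoff_len:

from itertools import chain

tokenized_examples = [[1, 5, 6, 2], [1, 7, 8, 9, 2], [1, 3, 2]]  # three toy documents, already tokenized
block_size = 4

concatenated = list(chain(*tokenized_examples))
total_length = (len(concatenated) // block_size) * block_size    # drop the ragged tail
blocks = [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]
print(blocks)  # [[1, 5, 6, 2], [1, 7, 8, 9], [2, 1, 3, 2]]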
src/llamafactory/data/processor/processor_utils.py (View file @ 7ea81099)
...
@@ -15,7 +15,7 @@
 import bisect
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Optional


 if TYPE_CHECKING:
...
@@ -27,9 +27,7 @@ if TYPE_CHECKING:
 @dataclass
 class DatasetProcessor(ABC):
-    r"""
-    A class for data processors.
-    """
+    r"""A class for data processors."""

     template: "Template"
     tokenizer: "PreTrainedTokenizer"
...
@@ -37,32 +35,24 @@ class DatasetProcessor(ABC):
     data_args: "DataArguments"

     @abstractmethod
-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
-        r"""
-        Builds model inputs from the examples.
-        """
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+        r"""Build model inputs from the examples."""
         ...

     @abstractmethod
-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
-        r"""
-        Print a data example to stdout.
-        """
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
+        r"""Print a data example to stdout."""
         ...


-def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
-    r"""
-    Finds the index of largest number that fits into the knapsack with the given capacity.
-    """
+def search_for_fit(numbers: list[int], capacity: int) -> int:
+    r"""Find the index of largest number that fits into the knapsack with the given capacity."""
     index = bisect.bisect(numbers, capacity)
     return -1 if index == 0 else (index - 1)


-def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
-    r"""
-    An efficient greedy algorithm with binary search for the knapsack problem.
-    """
+def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
+    r"""Implement efficient greedy algorithm with binary search for the knapsack problem."""
     numbers.sort()  # sort numbers in ascending order for binary search
     knapsacks = []
...
@@ -83,10 +73,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
     return knapsacks


-def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
-    r"""
-    Computes the real sequence length after truncation by the cutoff_len.
-    """
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
+    r"""Compute the real sequence length after truncation by the cutoff_len."""
     if target_len * 2 < cutoff_len:  # truncate source
         max_target_len = cutoff_len
     elif source_len * 2 < cutoff_len:  # truncate target
...
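To see what greedy_knapsack produces, here is a simplified re-implementation that follows the docstrings above (sort the lengths ascending, then repeatedly take the largest length that still fits), plus a toy run; the lengths and the capacity are made up:

import bisect

def greedy_knapsack_sketch(numbers: list[int], capacity: int) -> list[list[int]]:
    numbers.sort()  # ascending order enables binary search
    knapsacks = []
    while numbers:
        knapsack, remaining = [], capacity
        while True:
            index = bisect.bisect(numbers, remaining)  # position after the largest number <= remaining
            if index == 0:
                break  # nothing fits anymore
            knapsack.append(numbers.pop(index - 1))
            remaining -= knapsack[-1]
        if not knapsack:  # a length larger than capacity: place it alone to guarantee progress
            knapsack.append(numbers.pop())
        knapsacks.append(knapsack)
    return knapsacks

print(greedy_knapsack_sketch([700, 600, 500, 300, 200, 100], capacity=1024))
# [[700, 300], [600, 200, 100], [500]]; each bin stays within the cutoff length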
src/llamafactory/data/processor/supervised.py (View file @ 7ea81099)
...
@@ -14,7 +14,7 @@
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
...
@@ -32,14 +32,14 @@ logger = logging.get_logger(__name__)
 class SupervisedDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[Dict[str, str]],
-        response: Sequence[Dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-    ) -> Tuple[List[int], List[int]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ) -> tuple[list[int], list[int]]:
         messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
         input_ids, labels = self.template.mm_plugin.process_token_ids([], [], images, videos, audios, self.tokenizer, self.processor
...
@@ -85,7 +85,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
         return input_ids, labels

-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
         # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
         # for multiturn examples, we only mask the prompt part in each prompt-response pair.
         model_inputs = defaultdict(list)
...
@@ -114,7 +114,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
         return model_inputs

-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
         valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
...
@@ -124,7 +124,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
 @dataclass
 class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
         # TODO: use `position_ids` to achieve packing
         # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
         # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
...
@@ -165,7 +165,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
         knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
         for knapsack in knapsacks:
             packed_input_ids, packed_attention_masks, packed_labels = [], [], []
-            packed_images, packed_videos, packed_audios = [], [], []
+            packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
             for i, length in enumerate(knapsack):
                 index = length2indexes[length].pop()
                 packed_input_ids += batch_input_ids[index]
...
@@ -175,6 +175,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
                 packed_audios += batch_audios[index]
                 if self.data_args.neat_packing:
                     packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
+                    packed_position_ids += list(range(len(batch_input_ids[index])))
                 else:
                     packed_attention_masks += [1] * len(batch_input_ids[index])
...
@@ -184,6 +185,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
                 packed_labels += [IGNORE_INDEX] * pad_length
                 if self.data_args.neat_packing:
                     packed_attention_masks += [0] * pad_length
+                    packed_position_ids += [0] * pad_length
                 else:
                     packed_attention_masks += [1] * pad_length  # more efficient flash_attn
...
@@ -196,5 +198,6 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
             model_inputs["images"].append(packed_images or None)
             model_inputs["videos"].append(packed_videos or None)
             model_inputs["audios"].append(packed_audios or None)
+            model_inputs["position_ids"].append(packed_position_ids or None)

         return model_inputs
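The new packed_position_ids make the neat_packing layout explicit: position ids restart at 0 for every sequence packed into a block, and the attention mask stores a distinct segment id per sequence (with 0 for padding). A toy illustration with made-up lengths:

knapsack_lengths = [3, 4, 2]  # lengths of the sequences packed into one block
packed_attention_masks, packed_position_ids = [], []

for seg_id, length in enumerate(knapsack_lengths, start=1):  # segment ids start from 1
    packed_attention_masks += [seg_id] * length
    packed_position_ids += list(range(length))

print(packed_attention_masks)  # [1, 1, 1, 2, 2, 2, 2, 3, 3]
print(packed_position_ids)     # [0, 1, 2, 0, 1, 2, 3, 0, 1]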
src/llamafactory/data/processor/unsupervised.py (View file @ 7ea81099)
...
@@ -13,7 +13,7 @@
 # limitations under the License.

 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
 from ..data_utils import Role
...
@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
 class UnsupervisedDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[Dict[str, str]],
-        response: Sequence[Dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-    ) -> Tuple[List[int], List[int]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ) -> tuple[list[int], list[int]]:
         if len(response) == 1:
             messages = prompt + response
         else:
...
@@ -56,7 +56,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
         labels = labels[:target_len]
         return input_ids, labels

-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
         # build inputs with format `<bos> X` and labels with format `Y <eos>`
         model_inputs = defaultdict(list)
         for i in range(len(examples["_prompt"])):
...
@@ -84,7 +84,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
         return model_inputs

-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
         print("label_ids:\n{}".format(example["labels"]))
...