OpenDAS / LLaMA-Factory · Commits

Commit 7ea81099, authored Apr 07, 2025 by chenych

    update llama4

Parent: 84987715

Changes: 139 files in total; showing 20 changed files with 1344 additions and 780 deletions.
src/llamafactory/chat/chat_model.py                  +38   -47
src/llamafactory/chat/hf_engine.py                   +47   -53
src/llamafactory/chat/sglang_engine.py               +275  -0
src/llamafactory/chat/vllm_engine.py                 +47   -37
src/llamafactory/cli.py                              +2    -3
src/llamafactory/data/__init__.py                    +4    -4
src/llamafactory/data/collator.py                    +61   -35
src/llamafactory/data/converter.py                   +27   -29
src/llamafactory/data/data_utils.py                  +10   -15
src/llamafactory/data/formatter.py                   +6    -9
src/llamafactory/data/loader.py                      +11   -19
src/llamafactory/data/mm_plugin.py                   +732  -446
src/llamafactory/data/parser.py                      +9    -13
src/llamafactory/data/processor/__init__.py          +15   -1
src/llamafactory/data/processor/feedback.py          +12   -12
src/llamafactory/data/processor/pairwise.py          +9    -9
src/llamafactory/data/processor/pretrain.py          +4    -4
src/llamafactory/data/processor/processor_utils.py   +12   -24
src/llamafactory/data/processor/supervised.py        +14   -11
src/llamafactory/data/processor/unsupervised.py      +9    -9
src/llamafactory/chat/chat_model.py:

-# Copyright 2024 THUDM and the LlamaFactory team.
+# Copyright 2025 THUDM and the LlamaFactory team.
 #
 # This code is inspired by the THUDM's ChatGLM implementation.
 # https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
@@ -17,13 +17,15 @@
 import asyncio
 import os
+from collections.abc import AsyncGenerator, Generator
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Optional
 
 from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..hparams import get_infer_args
 from .hf_engine import HuggingfaceEngine
+from .sglang_engine import SGLangEngine
 from .vllm_engine import VllmEngine
@@ -38,20 +40,21 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
 class ChatModel:
-    r"""
-    General class for chat models. Backed by huggingface or vllm engines.
+    r"""General class for chat models. Backed by huggingface or vllm engines.
 
     Supports both sync and async methods.
     Sync methods: chat(), stream_chat() and get_scores().
     Async methods: achat(), astream_chat() and aget_scores().
     """
 
-    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, args: Optional[dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
         if model_args.infer_backend == EngineName.HF:
-            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
+            self.engine: BaseEngine = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == EngineName.VLLM:
-            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+            self.engine: BaseEngine = VllmEngine(model_args, data_args, finetuning_args, generating_args)
+        elif model_args.infer_backend == EngineName.SGLANG:
+            self.engine: BaseEngine = SGLangEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
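For orientation, a minimal usage sketch of the dispatch this hunk adds. The sketch is not part of the commit; the model id, template name, and argument keys are illustrative assumptions in LLaMA-Factory's usual infer-args style:

```python
# Hypothetical usage of ChatModel with the new SGLang backend; the args
# keys and values below are assumptions for illustration, not commit content.
from llamafactory.chat.chat_model import ChatModel

chat_model = ChatModel(
    args={
        "model_name_or_path": "meta-llama/Llama-4-Scout-17B-16E-Instruct",  # assumed model id
        "template": "llama4",       # assumed template name
        "infer_backend": "sglang",  # now dispatches to SGLangEngine
    }
)

messages = [{"role": "user", "content": "Hello!"}]
print(chat_model.chat(messages)[0].response_text)  # sync API
for token in chat_model.stream_chat(messages):     # token-by-token streaming
    print(token, end="", flush=True)
```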
@@ -61,17 +64,15 @@ class ChatModel:
     def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
         task = asyncio.run_coroutine_threadsafe(
             self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
         )
@@ -79,32 +80,28 @@ class ChatModel:
     async def achat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Asynchronously gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Asynchronously get a list of responses of the chat model."""
         return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)
 
     def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        r"""
-        Gets the response token-by-token of the chat model.
-        """
+        r"""Get the response token-by-token of the chat model."""
         generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
         while True:
             try:
@@ -115,17 +112,15 @@ class ChatModel:
     async def astream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        r"""
-        Asynchronously gets the response token-by-token of the chat model.
-        """
+        r"""Asynchronously get the response token-by-token of the chat model."""
         async for new_token in self.engine.stream_chat(
             messages, system, tools, images, videos, audios, **input_kwargs
         ):
@@ -133,23 +128,19 @@ class ChatModel:
     def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
         task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
         return task.result()
 
     async def aget_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Asynchronously gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Asynchronously get a list of scores of the reward model."""
         return await self.engine.get_scores(batch_input, **input_kwargs)
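The sync wrappers above (`chat()`, `stream_chat()`, `get_scores()`) submit coroutines to a dedicated background event loop via `asyncio.run_coroutine_threadsafe(..., self._loop)`; the loop thread is started by the `_start_background_loop` helper visible in the hunk headers. A minimal self-contained sketch of that pattern, with the surrounding setup assumed from the hunk context:

```python
import asyncio
from threading import Thread


def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
    # Pin the loop to this thread and run it forever (mirrors the helper
    # named in the hunk headers; exact body assumed).
    asyncio.set_event_loop(loop)
    loop.run_forever()


loop = asyncio.new_event_loop()
Thread(target=_start_background_loop, args=(loop,), daemon=True).start()


async def acompute() -> int:
    return 42


# A sync caller blocks on the future while the coroutine runs on the loop thread.
future = asyncio.run_coroutine_threadsafe(acompute(), loop)
print(future.result())  # 42
```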
src/llamafactory/chat/hf_engine.py:

@@ -13,10 +13,10 @@
 # limitations under the License.
 
 import asyncio
-import concurrent.futures
 import os
+from collections.abc import AsyncGenerator
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
@@ -76,15 +76,15 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> Tuple[Dict[str, Any], int]:
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> tuple[dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
         if images is not None:
             mm_input_dict.update({"images": images, "imglens": [len(images)]})
@@ -130,7 +130,7 @@ class HuggingfaceEngine(BaseEngine):
...
@@ -130,7 +130,7 @@ class HuggingfaceEngine(BaseEngine):
skip_special_tokens
:
Optional
[
bool
]
=
input_kwargs
.
pop
(
"skip_special_tokens"
,
None
)
skip_special_tokens
:
Optional
[
bool
]
=
input_kwargs
.
pop
(
"skip_special_tokens"
,
None
)
max_length
:
Optional
[
int
]
=
input_kwargs
.
pop
(
"max_length"
,
None
)
max_length
:
Optional
[
int
]
=
input_kwargs
.
pop
(
"max_length"
,
None
)
max_new_tokens
:
Optional
[
int
]
=
input_kwargs
.
pop
(
"max_new_tokens"
,
None
)
max_new_tokens
:
Optional
[
int
]
=
input_kwargs
.
pop
(
"max_new_tokens"
,
None
)
stop
:
Optional
[
Union
[
str
,
L
ist
[
str
]]]
=
input_kwargs
.
pop
(
"stop"
,
None
)
stop
:
Optional
[
Union
[
str
,
l
ist
[
str
]]]
=
input_kwargs
.
pop
(
"stop"
,
None
)
if
stop
is
not
None
:
if
stop
is
not
None
:
logger
.
warning_rank0
(
"Stop parameter is not supported by the huggingface engine yet."
)
logger
.
warning_rank0
(
"Stop parameter is not supported by the huggingface engine yet."
)
@@ -217,15 +217,15 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> List["Response"]:
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> list["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
             model,
             tokenizer,
@@ -272,14 +272,14 @@ class HuggingfaceEngine(BaseEngine):
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["ProcessorMixin"],
         template: "Template",
-        generating_args: Dict[str, Any],
-        messages: Sequence[Dict[str, str]],
+        generating_args: dict[str, Any],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
-        input_kwargs: Optional[Dict[str, Any]] = {},
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
+        input_kwargs: Optional[dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
             model,
@@ -317,12 +317,12 @@ class HuggingfaceEngine(BaseEngine):
     def _get_scores(
         model: "PreTrainedModelWrapper",
         tokenizer: "PreTrainedTokenizer",
-        batch_input: List[str],
-        input_kwargs: Optional[Dict[str, Any]] = {},
-    ) -> List[float]:
+        batch_input: list[str],
+        input_kwargs: Optional[dict[str, Any]] = {},
+    ) -> list[float]:
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         device = getattr(model.pretrained_model, "device", "cuda")
-        inputs: Dict[str, "torch.Tensor"] = tokenizer(
+        inputs: dict[str, torch.Tensor] = tokenizer(
             batch_input,
             padding=True,
             truncation=True,
@@ -330,25 +330,24 @@ class HuggingfaceEngine(BaseEngine):
             return_tensors="pt",
             add_special_tokens=False,
         ).to(device)
-        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
+        values: torch.Tensor = model(**inputs, return_dict=True, use_cache=False)[-1]
         scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return scores
 
     @override
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
+    ) -> list["Response"]:
         if not self.can_generate:
             raise ValueError("The current model does not support `chat`.")
 
-        loop = asyncio.get_running_loop()
         input_args = (
             self.model,
             self.tokenizer,
@@ -364,24 +363,22 @@ class HuggingfaceEngine(BaseEngine):
             input_kwargs,
         )
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                return await loop.run_in_executor(pool, self._chat, *input_args)
+            return await asyncio.to_thread(self._chat, *input_args)
 
     @override
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
             raise ValueError("The current model does not support `stream_chat`.")
 
-        loop = asyncio.get_running_loop()
         input_args = (
             self.model,
             self.tokenizer,
@@ -397,25 +394,22 @@ class HuggingfaceEngine(BaseEngine):
             input_kwargs,
         )
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                stream = self._stream_chat(*input_args)
-                while True:
-                    try:
-                        yield await loop.run_in_executor(pool, stream)
-                    except StopAsyncIteration:
-                        break
+            stream = self._stream_chat(*input_args)
+            while True:
+                try:
+                    yield await asyncio.to_thread(stream)
+                except StopAsyncIteration:
+                    break
 
     @override
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
+    ) -> list[float]:
         if self.can_generate:
             raise ValueError("Cannot get scores using an auto-regressive model.")
 
-        loop = asyncio.get_running_loop()
         input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
         async with self.semaphore:
-            with concurrent.futures.ThreadPoolExecutor() as pool:
-                return await loop.run_in_executor(pool, self._get_scores, *input_args)
+            return await asyncio.to_thread(self._get_scores, *input_args)
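The recurring change in this file replaces a per-call `ThreadPoolExecutor` plus `loop.run_in_executor` with `asyncio.to_thread` (Python >= 3.9), which hands the blocking call to asyncio's shared default executor instead of creating and tearing down a pool on every request. A small sketch of the equivalence; `blocking_generate` is an illustrative stand-in, not LLaMA-Factory code:

```python
import asyncio
import concurrent.futures
import time


def blocking_generate(prompt: str) -> str:
    """Illustrative stand-in for a blocking HF generate call."""
    time.sleep(0.1)  # simulate work that holds the thread
    return f"response to {prompt!r}"


async def old_style(prompt: str) -> str:
    # Before: create a fresh pool per call and dispatch explicitly.
    loop = asyncio.get_running_loop()
    with concurrent.futures.ThreadPoolExecutor() as pool:
        return await loop.run_in_executor(pool, blocking_generate, prompt)


async def new_style(prompt: str) -> str:
    # After: one line, reusing asyncio's default executor.
    return await asyncio.to_thread(blocking_generate, prompt)


assert asyncio.run(old_style("hi")) == asyncio.run(new_style("hi"))
```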
src/llamafactory/chat/sglang_engine.py (new file, mode 100644):

```python
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import atexit
import json
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from typing import TYPE_CHECKING, Any, Optional, Union

import requests
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
from ..extras.misc import get_device_count, torch_gc
from ..extras.packages import is_sglang_available
from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from .base_engine import BaseEngine, Response


if is_sglang_available():
    from sglang.utils import launch_server_cmd, terminate_process, wait_for_server  # type: ignore


if TYPE_CHECKING:
    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput


logger = logging.get_logger(__name__)


class SGLangEngine(BaseEngine):
    """Inference engine for SGLang models.

    This class wraps the SGLang engine to provide a consistent interface for text generation
    that matches LLaMA Factory's requirements. It uses the SGLang HTTP server approach for
    better interaction and performance. The engine launches a server process and communicates
    with it via HTTP requests.

    For more details on the SGLang HTTP server approach, see:
    https://docs.sglang.ai/backend/send_request.html
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.name = EngineName.SGLANG
        self.model_args = model_args
        config = load_config(model_args)  # may download model from ms hub
        if getattr(config, "quantization_config", None):  # gptq models should use float16
            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.template.mm_plugin.expand_mm_tokens = False  # for sglang generate
        self.generating_args = generating_args.to_dict()

        launch_cmd = [
            "python3 -m sglang.launch_server",
            f"--model-path {model_args.model_name_or_path}",
            f"--dtype {model_args.infer_dtype}",
            f"--context-length {model_args.sglang_maxlen}",
            f"--mem-fraction-static {model_args.sglang_mem_fraction}",
            f"--tp-size {model_args.sglang_tp_size if model_args.sglang_tp_size != -1 else get_device_count() or 1}",
            f"--download-dir {model_args.cache_dir}",
            "--log-level error",
        ]
        launch_cmd = " ".join(launch_cmd)
        logger.info_rank0(f"Starting SGLang server with command: {launch_cmd}")
        try:
            torch_gc()
            self.server_process, port = launch_server_cmd(launch_cmd)
            self.base_url = f"http://localhost:{port}"
            atexit.register(self._cleanup_server)

            logger.info_rank0(f"Waiting for SGLang server to be ready at {self.base_url}")
            wait_for_server(self.base_url, timeout=300)
            logger.info_rank0(f"SGLang server initialized successfully at {self.base_url}")

            try:
                response = requests.get(f"{self.base_url}/get_model_info", timeout=5)
                if response.status_code == 200:
                    model_info = response.json()
                    logger.info(f"SGLang server model info: {model_info}")
            except Exception as e:
                logger.debug(f"Note: could not get model info: {str(e)}")

        except Exception as e:
            logger.error(f"Failed to start SGLang server: {str(e)}")
            self._cleanup_server()  # make sure to clean up any started process
            raise RuntimeError(f"SGLang server initialization failed: {str(e)}.")

    def _cleanup_server(self):
        r"""Clean up the server process when the engine is destroyed."""
        if hasattr(self, "server_process") and self.server_process:
            try:
                logger.info("Terminating SGLang server process")
                terminate_process(self.server_process)
                logger.info("SGLang server process terminated")
            except Exception as e:
                logger.warning(f"Error terminating SGLang server: {str(e)}")

    async def _generate(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncIterator[dict[str, Any]]:
        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]

        messages = self.template.mm_plugin.process_messages(
            messages, images or [], videos or [], audios or [], self.processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
        prompt_length = len(prompt_ids)

        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)

        if num_return_sequences != 1:
            raise NotImplementedError("SGLang only supports n=1.")

        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            if self.generating_args["max_length"] > prompt_length:
                max_tokens = self.generating_args["max_length"] - prompt_length
            else:
                max_tokens = 1

        if max_length:
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = {
            "temperature": temperature if temperature is not None else self.generating_args["temperature"],
            "top_p": (top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
            "top_k": (top_k if top_k is not None else self.generating_args["top_k"]) or -1,  # top_k must > 0
            "stop": stop,
            "stop_token_ids": self.template.get_stop_token_ids(self.tokenizer),
            "max_new_tokens": max_tokens,
            "repetition_penalty": (
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must > 0
            "skip_special_tokens": skip_special_tokens
            if skip_special_tokens is not None
            else self.generating_args["skip_special_tokens"],
        }

        def stream_request():
            json_data = {
                "input_ids": prompt_ids,
                "sampling_params": sampling_params,
                "stream": True,
            }
            response = requests.post(f"{self.base_url}/generate", json=json_data, stream=True)
            if response.status_code != 200:
                raise RuntimeError(f"SGLang server error: {response.status_code}, {response.text}")

            for chunk in response.iter_lines(decode_unicode=False):
                chunk = str(chunk.decode("utf-8"))
                if chunk == "data: [DONE]":
                    break

                if chunk and chunk.startswith("data:"):
                    yield json.loads(chunk[5:].strip("\n"))

        return await asyncio.to_thread(stream_request)

    @override
    async def chat(
        self,
        messages: Sequence[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> list["Response"]:
        final_output = None
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        for request_output in generator:
            final_output = request_output

        results = [
            Response(
                response_text=final_output["text"],
                response_length=final_output["meta_info"]["completion_tokens"],
                prompt_length=final_output["meta_info"]["prompt_tokens"],
                finish_reason="stop" if final_output["meta_info"]["finish_reason"] == "stop" else "length",
            )
        ]
        return results

    @override
    async def stream_chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list["ImageInput"]] = None,
        videos: Optional[list["VideoInput"]] = None,
        audios: Optional[list["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        for result in generator:
            delta_text = result["text"][len(generated_text):]
            generated_text = result["text"]
            yield delta_text

    @override
    async def get_scores(
        self,
        batch_input: list[str],
        **input_kwargs,
    ) -> list[float]:
        raise NotImplementedError("SGLang engine does not support `get_scores`.")

    def __del__(self):
        r"""Ensure server is cleaned up when object is deleted."""
        self._cleanup_server()
        try:
            atexit.unregister(self._cleanup_server)
        except Exception:
            pass
```
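The `/generate` streaming protocol that `stream_request()` consumes can be exercised directly against a running SGLang server. A sketch assuming a server already listening on localhost:30000 and using a plain `text` prompt, which the SGLang docs linked in the module docstring describe as an alternative to `input_ids`:

```python
# Stream tokens from an assumed SGLang server at localhost:30000, mirroring
# the "data: ..." line protocol parsed by stream_request() above.
import json

import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0.0, "max_new_tokens": 32},
        "stream": True,
    },
    stream=True,
)
resp.raise_for_status()

previous = ""
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith("data:"):
        continue
    if line == "data: [DONE]":
        break
    payload = json.loads(line[5:].strip())
    print(payload["text"][len(previous):], end="", flush=True)  # "text" is cumulative
    previous = payload["text"]
```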
src/llamafactory/chat/vllm_engine.py:

@@ -13,7 +13,8 @@
 # limitations under the License.
 
 import uuid
-from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 from typing_extensions import override
@@ -53,7 +54,7 @@ class VllmEngine(BaseEngine):
         self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
-            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
+            quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
             quant_method = quantization_config.get("quant_method", "")
             if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                 model_args.infer_dtype = "float16"
@@ -82,7 +83,7 @@ class VllmEngine(BaseEngine):
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
         if self.template.mm_plugin.__class__.__name__ != "BasePlugin":
-            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+            engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
 
         if isinstance(model_args.vllm_config, dict):
             engine_args.update(model_args.vllm_config)
@@ -101,33 +102,26 @@ class VllmEngine(BaseEngine):
     async def _generate(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = f"chatcmpl-{uuid.uuid4().hex}"
-        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
-        if images is not None:
-            mm_input_dict.update({"images": images, "imglens": [len(images)]})
-            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+        if images is not None and not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
 
-        if videos is not None:
-            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
-            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
+        if videos is not None and not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]
 
-        if audios is not None:
-            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
-            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
-                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+        if audios is not None and not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+            messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
 
         messages = self.template.mm_plugin.process_messages(
-            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
+            messages, images or [], videos or [], audios or [], self.processor
         )
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
@@ -143,7 +137,7 @@ class VllmEngine(BaseEngine):
         skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
-        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+        stop: Optional[Union[str, list[str]]] = input_kwargs.pop("stop", None)
 
         if length_penalty is not None:
             logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
@@ -185,8 +179,24 @@ class VllmEngine(BaseEngine):
                     images,
                     image_max_pixels=self.model_args.image_max_pixels,
                     image_min_pixels=self.model_args.image_min_pixels,
-                )
+                )["images"]
             }
+        elif videos is not None:
+            multi_modal_data = {
+                "video": self.template.mm_plugin._regularize_videos(
+                    videos,
+                    image_max_pixels=self.model_args.video_max_pixels,
+                    image_min_pixels=self.model_args.video_min_pixels,
+                    video_fps=self.model_args.video_fps,
+                    video_maxlen=self.model_args.video_maxlen,
+                )["videos"]
+            }
+        elif audios is not None:
+            audio_data = self.template.mm_plugin._regularize_audios(
+                audios,
+                sampling_rate=self.model_args.audio_sampling_rate,
+            )
+            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
         else:
             multi_modal_data = None
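In the audio branch added above, `multi_modal_data["audio"]` is a `zip` of waveforms and sampling rates, i.e. an iterable of `(waveform, sampling_rate)` pairs, the shape vLLM accepts for audio inputs. A toy illustration with made-up values; the real waveforms come from `_regularize_audios`:

```python
# Made-up stand-in for the dict returned by _regularize_audios above.
import numpy as np

audio_data = {
    "audios": [np.zeros(16000, dtype=np.float32)],  # one second of silence
    "sampling_rates": [16000],
}
multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
print(list(multi_modal_data["audio"]))  # [(array([...], dtype=float32), 16000)]
```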
@@ -201,14 +211,14 @@ class VllmEngine(BaseEngine):
     @override
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
+    ) -> list["Response"]:
         final_output = None
         generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
         async for request_output in generator:
@@ -230,12 +240,12 @@ class VllmEngine(BaseEngine):
     @override
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
@@ -248,7 +258,7 @@ class VllmEngine(BaseEngine):
     @override
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        raise NotImplementedError("vLLM engine does not support get_scores.")
+    ) -> list[float]:
+        raise NotImplementedError("vLLM engine does not support `get_scores`.")
src/llamafactory/cli.py:

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import random
 import subprocess
 import sys
 from enum import Enum, unique
@@ -24,7 +23,7 @@ from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
 from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.misc import get_device_count, is_env_enabled, use_ray
+from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
@@ -92,7 +91,7 @@ def main():
         node_rank = os.getenv("NODE_RANK", "0")
         nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
         master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
-        master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
+        master_port = os.getenv("MASTER_PORT", str(find_available_port()))
         logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
         if int(nnodes) > 1:
             print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
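The old default drew a random port from 20001-29999, which can collide with a port already in use. `find_available_port` lives in `..extras.misc` and its body is not shown in this diff; a sketch of how such a helper is typically written (an assumption, not the actual implementation) binds port 0 so the OS hands out a free one:

```python
# Assumed shape of a find_available_port helper; the real one in
# ..extras.misc is not part of this diff.
import socket


def find_available_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))  # port 0: let the OS choose a free port
        return sock.getsockname()[1]


print(find_available_port())
```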
src/llamafactory/data/__init__.py:

@@ -24,14 +24,14 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 __all__ = [
-    "TEMPLATES",
     "KTODataCollatorWithPadding",
     "MultiModalDataCollatorForSeq2Seq",
     "PairwiseDataCollatorWithPadding",
-    "SFTDataCollatorWith4DAttentionMask",
     "Role",
-    "split_dataset",
-    "get_dataset",
+    "SFTDataCollatorWith4DAttentionMask",
+    "TEMPLATES",
     "Template",
+    "get_dataset",
     "get_template_and_fix_tokenizer",
+    "split_dataset",
 ]
src/llamafactory/data/collator.py:

-# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
+# Copyright 2025 OpenAccess AI Collective and the LlamaFactory team.
 #
 # This code is inspired by the OpenAccess AI Collective's axolotl library.
 # https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py

@@ -16,7 +16,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Literal, Optional
 
 import numpy as np
 import torch
@@ -24,6 +24,7 @@ import torch.nn.functional as F
 from transformers import DataCollatorForSeq2Seq
 
 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.misc import get_current_device
 from ..extras.packages import is_pillow_available
@@ -38,9 +39,10 @@ if TYPE_CHECKING:
 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
-    r"""
-    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
-    while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
+    r"""Expand 2d attention mask to 4d attention mask.
+
+    Expand the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
+    handle packed sequences and transforms the mask to lower triangular form to prevent future peeking.
 
     e.g.
     ```python
@@ -62,24 +64,37 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
     ```
     where `o` equals to `0.0`, `x` equals to `min_dtype`.
     """
-    bsz, seq_len = attention_mask_with_indices.size()
+    _, seq_len = attention_mask_with_indices.size()
+    # Move to compute device if the source is CPU.
+    source_device = attention_mask_with_indices.device
+    compute_device = get_current_device() if source_device.type == "cpu" else source_device
+    if compute_device != source_device:
+        attention_mask_with_indices = attention_mask_with_indices.to(compute_device)
+
     min_dtype = torch.finfo(dtype).min
-    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
-    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
-    padding_mask = torch.where(expanded_mask != 0, 1, 0)
-    # Create a block-diagonal mask.
-    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
-    # Use the lower triangular mask to zero out the upper triangular part
-    attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
+    zero_tensor = torch.tensor(0, dtype=dtype, device=compute_device)
+
+    # Create a non-padding mask.
+    non_padding = (attention_mask_with_indices != 0).unsqueeze(1).unsqueeze(2)
+    # Create indices for comparison.
+    indices = attention_mask_with_indices.unsqueeze(1).unsqueeze(2)  # [bsz, 1, 1, seq_len]
+    indices_t = attention_mask_with_indices.unsqueeze(1).unsqueeze(3)  # [bsz, 1, seq_len, 1]
+    # Create a lower triangular mask.
+    tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=compute_device))
+    attention_mask_4d = (indices == indices_t) & non_padding & tril_mask
     # Invert the attention mask.
-    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
+    attention_mask_4d = torch.where(attention_mask_4d, zero_tensor, min_dtype)
+
+    # Move back to original device if needed.
+    if compute_device != source_device:
+        attention_mask_4d = attention_mask_4d.to(source_device)
+
     return attention_mask_4d
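A small worked example of the rewritten masking math, re-implemented standalone on CPU (device handling and the min_dtype inversion omitted); the toy batch packs two sequences followed by one pad token:

```python
import torch

a = torch.tensor([[1, 1, 2, 2, 0]])  # token-to-sequence indices, (bsz=1, seq_len=5)
_, seq_len = a.size()

non_padding = (a != 0).unsqueeze(1).unsqueeze(2)            # drop padded keys
indices = a.unsqueeze(1).unsqueeze(2)                       # [bsz, 1, 1, seq_len]
indices_t = a.unsqueeze(1).unsqueeze(3)                     # [bsz, 1, seq_len, 1]
tril_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
mask_4d = (indices == indices_t) & non_padding & tril_mask  # block-diagonal + causal

print(mask_4d[0, 0].int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [0, 0, 1, 0, 0],
#         [0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 0]])
```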
 @dataclass
 class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
-    r"""
-    Data collator that supports VLMs.
+    r"""Data collator that supports VLMs.
 
     Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
     """
@@ -91,7 +106,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if self.template is None:
             raise ValueError("Template is required for MultiModalDataCollator.")
 
-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         batch_images, batch_videos, batch_audios = [], [], []
         batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
         for feature in features:

@@ -166,7 +181,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         for i, feature in enumerate(features):
             feature["token_type_ids"] = token_type_ids[i]
 
-        features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        features: dict[str, torch.Tensor] = super().__call__(features)
 
         if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
             rope_index_kwargs = {
@@ -175,10 +190,28 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 "video_grid_thw": mm_inputs.get("video_grid_thw"),
                 "attention_mask": features["attention_mask"],
             }
-            if "second_per_grid_ts" in mm_inputs:
+            if "second_per_grid_ts" in mm_inputs:  # for qwen2vl
                 rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
 
-            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
+            if "video_second_per_grid" in mm_inputs:  # for qwen2omni
+                rope_index_kwargs["second_per_grids"] = mm_inputs.get("video_second_per_grid")
+
+            if getattr(self.model.config, "model_type", None) == "qwen2_5_omni_thinker":  # for qwen2omni
+                feature_attention_mask = mm_inputs.get("feature_attention_mask", None)
+                if feature_attention_mask is not None:
+                    audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)  # FIXME need to get video image lengths
+                    rope_index_kwargs["audio_seqlens"] = audio_feature_lengths  # prepare for input
+
+                delta0 = (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(1)  # avoid conflict
+                new_position_ids, rope_deltas = self.model.get_rope_index(**rope_index_kwargs)
+                features["position_ids"], features["rope_deltas"] = (
+                    new_position_ids.clone(),
+                    rope_deltas - delta0,
+                )  # avoid inplace operation FIXME
+            else:  # for qwen2vl
+                features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)
 
         if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
             cross_attention_mask = mm_inputs.pop("cross_attention_mask")
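The `delta0` term added above compensates for left padding: `(1 - attention_mask).sum(dim=-1)` counts the pad tokens in each row, and subtracting it from `rope_deltas` re-bases the deltas onto the unpadded positions. A toy illustration:

```python
import torch

attention_mask = torch.tensor([
    [0, 0, 1, 1, 1],  # two pad tokens on the left
    [1, 1, 1, 1, 1],  # no padding
])
delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
print(delta0)  # tensor([[2], [0]]): per-row pad counts subtracted from rope_deltas
```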
@@ -198,15 +231,13 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
 @dataclass
 class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
-    r"""
-    Data collator for 4d attention mask.
-    """
+    r"""Data collator for 4d attention mask."""
 
     block_diag_attn: bool = False
     attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
     compute_dtype: "torch.dtype" = torch.float32
 
-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, "torch.Tensor"]:
         features = super().__call__(features)
         if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
             features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)
@@ -220,13 +251,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
...
@@ -220,13 +251,10 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
@
dataclass
@
dataclass
class
PairwiseDataCollatorWithPadding
(
MultiModalDataCollatorForSeq2Seq
):
class
PairwiseDataCollatorWithPadding
(
MultiModalDataCollatorForSeq2Seq
):
r
"""
r
"""Data collator for pairwise data."""
Data collator for pairwise data.
"""
def
__call__
(
self
,
features
:
Sequence
[
Dict
[
str
,
Any
]])
->
Dict
[
str
,
"torch.Tensor"
]:
def
__call__
(
self
,
features
:
list
[
dict
[
str
,
Any
]])
->
dict
[
str
,
"torch.Tensor"
]:
r
"""
r
"""Pad batched data to the longest sequence in the batch.
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
the last n examples represent rejected examples.
...
@@ -249,11 +277,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
...
@@ -249,11 +277,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
@
dataclass
@
dataclass
class
KTODataCollatorWithPadding
(
MultiModalDataCollatorForSeq2Seq
):
class
KTODataCollatorWithPadding
(
MultiModalDataCollatorForSeq2Seq
):
r
"""
r
"""Data collator for KTO data."""
Data collator for KTO data.
"""
def
__call__
(
self
,
features
:
Sequence
[
D
ict
[
str
,
Any
]])
->
D
ict
[
str
,
"torch.Tensor"
]:
def
__call__
(
self
,
features
:
list
[
d
ict
[
str
,
Any
]])
->
d
ict
[
str
,
"torch.Tensor"
]:
target_features
=
[]
target_features
=
[]
kl_features
=
[]
kl_features
=
[]
kto_tags
=
[]
kto_tags
=
[]
...
...
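For readers unfamiliar with the block-diagonal masks behind `neat_packing`: when `block_diag_attn` is on and flash-attention is off, the packed 2-D attention mask of segment ids is expanded into a 4-D additive mask so each packed sequence only attends to itself. A minimal standalone sketch of the idea (not the project's `prepare_4d_attention_mask` implementation):

```python
import torch


def block_diag_mask_sketch(segment_ids: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Expand a (batch, seq_len) tensor of packed-segment ids (1, 2, ... per
    sequence, 0 for padding) into a (batch, 1, seq_len, seq_len) additive
    mask where tokens attend causally and only within their own segment."""
    min_value = torch.finfo(dtype).min
    seg = segment_ids.unsqueeze(1)  # (batch, 1, seq_len)
    # same nonzero segment id on both axes -> attention allowed
    allowed = (seg.unsqueeze(-1) == seg.unsqueeze(-2)) & (seg.unsqueeze(-1) != 0)
    # causal constraint: a query may only look at earlier or equal positions
    seq_len = segment_ids.size(-1)
    allowed = allowed & torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return torch.where(
        allowed,
        torch.tensor(0.0, dtype=dtype),
        torch.tensor(min_value, dtype=dtype),
    )


mask = block_diag_mask_sketch(torch.tensor([[1, 1, 2, 2, 0]]), torch.float32)
print(mask.shape)  # torch.Size([1, 1, 5, 5])
```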
src/llamafactory/data/converter.py  View file @ 7ea81099
@@ -15,7 +15,7 @@
 import os
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 from ..extras import logging
 from .data_utils import Role
...
@@ -26,8 +26,12 @@ if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments

     from ..hparams import DataArguments
+    from .mm_plugin import AudioInput, ImageInput, VideoInput
     from .parser import DatasetAttr

+    MediaType = Union[ImageInput, VideoInput, AudioInput]
+

 logger = logging.get_logger(__name__)
...
@@ -36,12 +40,12 @@ class DatasetConverter:
     dataset_attr: "DatasetAttr"
     data_args: "DataArguments"

-    def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
-        r"""
-        Optionally concatenates media path to media dir when loading from local disk.
-        """
-        if medias is None:
-            return None
-        elif not isinstance(medias, list):
-            medias = [medias]
+    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
+        r"""Optionally concatenate media path to media dir when loading from local disk."""
+        if not isinstance(medias, list):
+            medias = [medias] if medias is not None else []
         elif len(medias) == 0:
             return None
         else:
...
@@ -57,16 +61,14 @@ class DatasetConverter:
         return medias

     @abstractmethod
-    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
-        r"""
-        Converts a single example in the dataset to the standard format.
-        """
+    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
+        r"""Convert a single example in the dataset to the standard format."""
         ...


 @dataclass
 class AlpacaDatasetConverter(DatasetConverter):
-    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
+    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
         prompt = []
         if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
             for old_prompt, old_response in example[self.dataset_attr.history]:
...
@@ -116,7 +118,7 @@ class AlpacaDatasetConverter(DatasetConverter):
 @dataclass
 class SharegptDatasetConverter(DatasetConverter):
-    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
+    def __call__(self, example: dict[str, Any]) -> dict[str, Any]:
         tag_mapping = {
             self.dataset_attr.user_tag: Role.USER.value,
             self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
...
@@ -216,10 +218,8 @@ DATASET_CONVERTERS = {
 }


-def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
-    r"""
-    Register a new dataset converter.
-    """
+def register_dataset_converter(name: str, dataset_converter: type["DatasetConverter"]) -> None:
+    r"""Register a new dataset converter."""
     if name in DATASET_CONVERTERS:
         raise ValueError(f"Dataset converter {name} already exists.")
...
@@ -227,9 +227,7 @@ def register_dataset_converter(name: str, dataset_converter: Type["DatasetConver
 def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
-    r"""
-    Gets a dataset converter.
-    """
+    r"""Get a dataset converter."""
     if name not in DATASET_CONVERTERS:
         raise ValueError(f"Dataset converter {name} not found.")
...
@@ -242,17 +240,17 @@ def align_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
-    r"""
+    r"""Align the dataset to a specific format.
+
     Aligned dataset:
     _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
     _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
     _system: "..."
-    _tools: "...",
-    _images: [],
-    _videos: [],
-    _audios: [],
+    _tools: "..."
+    _images: []
+    _videos: []
+    _audios: []
     """
     column_names = list(next(iter(dataset)).keys())
     kwargs = {}
     if not data_args.streaming:
...
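The rewritten `_find_medias` folds the `None` check into the scalar-to-list branch; the observable behavior is unchanged. A standalone restatement of the new control flow (hypothetical helper name, for illustration only):

```python
from typing import Any, Optional


def normalize_medias(medias: Any) -> Optional[list[Any]]:
    # new control flow: None -> [], scalar -> [scalar], list -> unchanged
    if not isinstance(medias, list):
        medias = [medias] if medias is not None else []

    if len(medias) == 0:
        return None
    return medias


assert normalize_medias(None) is None
assert normalize_medias("a.png") == ["a.png"]
assert normalize_medias(["a.png", "b.png"]) == ["a.png", "b.png"]
```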
src/llamafactory/data/data_utils.py  View file @ 7ea81099
@@ -13,7 +13,7 @@
 # limitations under the License.

 from enum import Enum, unique
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union
+from typing import TYPE_CHECKING, Optional, TypedDict, Union

 from datasets import DatasetDict, concatenate_datasets, interleave_datasets
...
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
 logger = logging.get_logger(__name__)

-SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
+SLOTS = list[Union[str, set[str], dict[str, str]]]


 @unique
...
@@ -43,15 +43,13 @@ class Role(str, Enum):
 class DatasetModule(TypedDict):
     train_dataset: Optional[Union["Dataset", "IterableDataset"]]
-    eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]
+    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]


 def merge_dataset(
-    all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
+    all_datasets: list[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
 ) -> Union["Dataset", "IterableDataset"]:
-    r"""
-    Merges multiple datasets to a unified dataset.
-    """
+    r"""Merge multiple datasets to a unified dataset."""
     if len(all_datasets) == 1:
         return all_datasets[0]
...
@@ -78,14 +76,13 @@ def merge_dataset(
 def split_dataset(
     dataset: Optional[Union["Dataset", "IterableDataset"]],
-    eval_dataset: Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]],
+    eval_dataset: Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]],
     data_args: "DataArguments",
     seed: int,
 ) -> "DatasetDict":
-    r"""
-    Splits the dataset and returns a dataset dict containing train set and validation set.
+    r"""Split the dataset and returns a dataset dict containing train set and validation set.

-    Supports both map dataset and iterable dataset.
+    Support both map dataset and iterable dataset.
     """
     if eval_dataset is not None and data_args.val_size > 1e-6:
         raise ValueError("Cannot specify `val_size` if `eval_dataset` is not None.")
...
@@ -120,10 +117,8 @@ def split_dataset(
 def get_dataset_module(dataset: Union["Dataset", "DatasetDict"]) -> "DatasetModule":
-    r"""
-    Converts dataset or dataset dict to dataset module.
-    """
-    dataset_module: DatasetModule = {}
+    r"""Convert dataset or dataset dict to dataset module."""
+    dataset_module: "DatasetModule" = {}
     if isinstance(dataset, DatasetDict):  # dataset dict
         if "train" in dataset:
             dataset_module["train_dataset"] = dataset["train"]
...
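For reference, `merge_dataset` builds on the two `datasets` primitives visible in the import line, concatenation and probabilistic interleaving; the exact strategy dispatch lives in the elided body. A rough sketch of the two modes with toy data:

```python
from datasets import Dataset, concatenate_datasets, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
ds_b = Dataset.from_dict({"text": ["b1", "b2", "b3"]})

# simple concatenation of all datasets
merged = concatenate_datasets([ds_a, ds_b])

# probabilistic interleaving with a fixed seed; stopping_strategy controls
# whether mixing stops at the first or the last exhausted dataset
mixed = interleave_datasets(
    [ds_a, ds_b],
    probabilities=[0.5, 0.5],
    seed=42,
    stopping_strategy="first_exhausted",
)
print(len(merged), mixed[0])
```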
src/llamafactory/data/formatter.py  View file @ 7ea81099
@@ -16,7 +16,7 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import List, Optional, Union
+from typing import Optional, Union

 from typing_extensions import override
...
@@ -31,14 +31,11 @@ class Formatter(ABC):
     @abstractmethod
     def apply(self, **kwargs) -> SLOTS:
-        r"""
-        Forms a list of slots according to the inputs to encode.
-        """
+        r"""Forms a list of slots according to the inputs to encode."""
         ...

-    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
-        r"""
-        Extract a list of tuples from the response message if using tools.
+    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+        r"""Extract a list of tuples from the response message if using tools.

         Each tuple consists of function name and function arguments.
         """
...
@@ -105,7 +102,7 @@ class FunctionFormatter(StringFormatter):
         if thought:
             content = content.replace(thought.group(0), "")

-        functions: List["FunctionCall"] = []
+        functions: list[FunctionCall] = []
         try:
             tool_calls = json.loads(content)
             if not isinstance(tool_calls, list):  # parallel function call
...
@@ -141,5 +138,5 @@ class ToolFormatter(Formatter):
             raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.")  # flat string

     @override
-    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
         return self.tool_utils.tool_extractor(content)
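For context on what `extract` returns: a `FunctionCall` is a (name, arguments) pair, and the formatter falls back to the raw string when the content is not valid tool-call JSON. A simplified sketch of that contract (not the project's exact parsing, which also strips thought tags first, as in the `if thought:` context above):

```python
import json
from typing import NamedTuple, Union


class FunctionCall(NamedTuple):
    name: str
    arguments: str  # JSON-encoded argument dict


def extract_tool_calls(content: str) -> Union[str, list[FunctionCall]]:
    """Parse `content` as one JSON tool call or a list of them
    ("parallel function call"); return the raw string on failure."""
    try:
        tool_calls = json.loads(content)
        if not isinstance(tool_calls, list):  # parallel function call
            tool_calls = [tool_calls]

        return [
            FunctionCall(call["name"], json.dumps(call["arguments"], ensure_ascii=False))
            for call in tool_calls
        ]
    except (json.JSONDecodeError, KeyError, TypeError):
        return content  # not a tool call, keep the plain response


print(extract_tool_calls('{"name": "get_weather", "arguments": {"city": "Paris"}}'))
```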
src/llamafactory/data/loader.py  View file @ 7ea81099
@@ -13,7 +13,7 @@
 # limitations under the License.

 import os
-from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Literal, Optional, Union

 import numpy as np
 from datasets import load_dataset, load_from_disk
...
@@ -54,9 +54,7 @@ def _load_single_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
-    r"""
-    Loads a single dataset and aligns it to the standard format.
-    """
+    r"""Load a single dataset and aligns it to the standard format."""
     logger.info_rank0(f"Loading dataset {dataset_attr}...")
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
...
@@ -133,10 +131,12 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=data_args.streaming,
             num_proc=data_args.preprocessing_num_workers,
             trust_remote_code=model_args.trust_remote_code,
+            streaming=data_args.streaming and dataset_attr.load_from != "file",
         )
+        if data_args.streaming and dataset_attr.load_from == "file":
+            dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)

     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
...
@@ -158,16 +158,14 @@ def _load_single_dataset(
 def _get_merged_dataset(
-    dataset_names: Optional[Sequence[str]],
+    dataset_names: Optional[list[str]],
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     merge: bool = True,
-) -> Optional[Union["Dataset", "IterableDataset", Dict[str, "Dataset"]]]:
-    r"""
-    Returns the merged datasets in the standard format.
-    """
+) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
+    r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
         return None
...
@@ -192,9 +190,7 @@ def _get_dataset_processor(
     processor: Optional["ProcessorMixin"],
     do_generate: bool = False,
 ) -> "DatasetProcessor":
-    r"""
-    Returns the corresponding dataset processor.
-    """
+    r"""Return the corresponding dataset processor."""
     if stage == "pt":
         dataset_processor_class = PretrainDatasetProcessor
     elif stage == "sft" and not do_generate:
...
@@ -236,9 +232,7 @@ def _get_preprocessed_dataset(
     processor: Optional["ProcessorMixin"] = None,
     is_eval: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
-    r"""
-    Preprocesses the dataset, including format checking and tokenization.
-    """
+    r"""Preprocesses the dataset, including format checking and tokenization."""
     if dataset is None:
         return None
...
@@ -284,9 +278,7 @@ def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> "DatasetModule":
-    r"""
-    Gets the train dataset and optionally gets the evaluation dataset.
-    """
+    r"""Get the train dataset and optionally gets the evaluation dataset."""
     # Load tokenized dataset if path exists
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
...
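The streaming change above loads local files eagerly and only then converts them to an iterable dataset, sharded so that multiple dataloader workers can each read a slice; streaming a single local file directly exposes only one shard, which presumably motivated the change. A minimal sketch of the new path, with `data.jsonl` as a placeholder file:

```python
from datasets import load_dataset

# local files are loaded eagerly (streaming=False), then converted into a
# sharded IterableDataset so several dataloader workers can consume it
dataset = load_dataset("json", data_files="data.jsonl", split="train")
iterable_ds = dataset.to_iterable_dataset(num_shards=4)

for example in iterable_ds.take(2):
    print(example)
```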
src/llamafactory/data/mm_plugin.py  View file @ 7ea81099
This diff is collapsed.
src/llamafactory/data/parser.py  View file @ 7ea81099
@@ -15,9 +15,9 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Dict, List, Literal, Optional, Sequence
+from typing import Any, Literal, Optional

-from transformers.utils import cached_file
+from huggingface_hub import hf_hub_download

 from ..extras.constants import DATA_CONFIG
 from ..extras.misc import use_modelscope, use_openmind
...
@@ -25,9 +25,7 @@ from ..extras.misc import use_modelscope, use_openmind
 @dataclass
 class DatasetAttr:
-    r"""
-    Dataset attributes.
-    """
+    r"""Dataset attributes."""

     # basic configs
     load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
...
@@ -68,10 +66,10 @@ class DatasetAttr:
     def __repr__(self) -> str:
         return self.dataset_name

-    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
+    def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
         setattr(self, key, obj.get(key, default))

-    def join(self, attr: Dict[str, Any]) -> None:
+    def join(self, attr: dict[str, Any]) -> None:
         self.set_attr("formatting", attr, default="alpaca")
         self.set_attr("ranking", attr, default=False)
         self.set_attr("subset", attr)
...
@@ -92,10 +90,8 @@ class DatasetAttr:
             self.set_attr(tag, attr["tags"])


-def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
-    r"""
-    Gets the attributes of the datasets.
-    """
+def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: str) -> list["DatasetAttr"]:
+    r"""Get the attributes of the datasets."""
     if dataset_names is None:
         dataset_names = []
...
@@ -103,7 +99,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
         dataset_info = None
     else:
         if dataset_dir.startswith("REMOTE:"):
-            config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
+            config_path = hf_hub_download(repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
         else:
             config_path = os.path.join(dataset_dir, DATA_CONFIG)
...
@@ -116,7 +112,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
             dataset_info = None

-    dataset_list: List["DatasetAttr"] = []
+    dataset_list: list[DatasetAttr] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
             if use_modelscope():
...
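`hf_hub_download` is the public `huggingface_hub` API for fetching a single file from a Hub repo, replacing the internal `transformers.utils.cached_file` helper. A minimal usage sketch, assuming `DATA_CONFIG` resolves to `dataset_info.json` and with `org/name` as a placeholder repo id:

```python
from huggingface_hub import hf_hub_download

# fetch the dataset config straight from a dataset repo on the Hub;
# the file is cached locally and its path is returned
config_path = hf_hub_download(
    repo_id="org/name",
    filename="dataset_info.json",
    repo_type="dataset",
)
print(config_path)  # local path inside the huggingface_hub cache
```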
src/llamafactory/data/processor/__init__.py  View file @ 7ea81099
+# Copyright 2025 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .feedback import FeedbackDatasetProcessor
 from .pairwise import PairwiseDatasetProcessor
 from .pretrain import PretrainDatasetProcessor
...
@@ -9,9 +23,9 @@ from .unsupervised import UnsupervisedDatasetProcessor
 __all__ = [
     "DatasetProcessor",
     "FeedbackDatasetProcessor",
+    "PackedSupervisedDatasetProcessor",
     "PairwiseDatasetProcessor",
     "PretrainDatasetProcessor",
-    "PackedSupervisedDatasetProcessor",
     "SupervisedDatasetProcessor",
     "UnsupervisedDatasetProcessor",
 ]
src/llamafactory/data/processor/feedback.py  View file @ 7ea81099
@@ -13,7 +13,7 @@
 # limitations under the License.

 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
...
@@ -30,15 +30,15 @@ logger = logging.get_logger(__name__)
 class FeedbackDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[Dict[str, str]],
-        response: Sequence[Dict[str, str]],
-        kl_response: Sequence[Dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
+        kl_response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-    ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ) -> tuple[list[int], list[int], list[int], list[int], bool]:
         if response[0]["content"]:  # desired example
             kto_tag = True
             messages = prompt + [response[0]]
...
@@ -82,9 +82,9 @@ class FeedbackDatasetProcessor(DatasetProcessor):
         kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
         return input_ids, labels, kl_input_ids, kl_labels, kto_tag

-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
-        # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
-        kl_response = examples["_response"][::-1]
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+        # Creates mismatched pairs of prompts and completions for the KL dataset by adding a +1 offset to the order of completions.
+        kl_response = [examples["_response"][-1]] + examples["_response"][:-1]
         model_inputs = defaultdict(list)
         for i in range(len(examples["_prompt"])):
             if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
...
@@ -121,7 +121,7 @@ class FeedbackDatasetProcessor(DatasetProcessor):
         return model_inputs

-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
         valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
...
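The `kl_response` change swaps a full reversal for a rotation by one. With a reversal, the center element of an odd-length batch pairs with itself, so its "mismatched" KL example is actually matched; a rotation gives every prompt `i` the response of prompt `i-1`, which is always a different example when the batch has more than one row. A toy comparison:

```python
responses = ["r0", "r1", "r2", "r3", "r4"]

flipped = responses[::-1]                   # old: reverse the list
rotated = [responses[-1]] + responses[:-1]  # new: rotate by one

# with the reversal, index 2 is paired with itself (r2 <-> r2);
# with the rotation, every index i gets response i-1
for i in range(len(responses)):
    print(i, flipped[i], rotated[i])
```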
src/llamafactory/data/processor/pairwise.py  View file @ 7ea81099
@@ -13,7 +13,7 @@
 # limitations under the License.

 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
...
@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
 class PairwiseDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[Dict[str, str]],
-        response: Sequence[Dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-    ) -> Tuple[List[int], List[int], List[int], List[int]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ) -> tuple[list[int], list[int], list[int], list[int]]:
         chosen_messages = self.template.mm_plugin.process_messages(
             prompt + [response[0]], images, videos, audios, self.processor
         )
...
@@ -68,7 +68,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
         rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
         return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels

-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
         # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
         model_inputs = defaultdict(list)
         for i in range(len(examples["_prompt"])):
...
@@ -99,7 +99,7 @@ class PairwiseDatasetProcessor(DatasetProcessor):
         return model_inputs

-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
         valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
         valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
         print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
...
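A toy illustration (not the project's collator) of the `2 * n` layout that the `PairwiseDataCollatorWithPadding` docstring describes, chosen features first and rejected features last, so a reward model can score both halves in a single forward pass:

```python
features = [
    {"chosen_input_ids": [1, 2], "rejected_input_ids": [1, 3]},
    {"chosen_input_ids": [4, 5], "rejected_input_ids": [4, 6]},
]

# flatten n pairwise features into 2 * n plain features:
# all chosen examples first, then all rejected examples
flat = [
    {"input_ids": feature[f"{key}_input_ids"]}
    for key in ("chosen", "rejected")
    for feature in features
]

print([f["input_ids"] for f in flat])
# [[1, 2], [4, 5], [1, 3], [4, 6]] -> first n chosen, last n rejected
```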
src/llamafactory/data/processor/pretrain.py  View file @ 7ea81099
-# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
 #
 # This code is inspired by the HuggingFace's transformers library.
 # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
...
@@ -17,14 +17,14 @@
 from dataclasses import dataclass
 from itertools import chain
-from typing import Any, Dict, List
+from typing import Any

 from .processor_utils import DatasetProcessor


 @dataclass
 class PretrainDatasetProcessor(DatasetProcessor):
-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
         # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
         eos_token = "<|end_of_text|>" if self.data_args.template == "llama3" else self.tokenizer.eos_token
         text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]
...
@@ -52,6 +52,6 @@ class PretrainDatasetProcessor(DatasetProcessor):
         return result

-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
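The `# build grouped texts` comment refers to run_clm-style packing: concatenate all tokenized documents and re-chunk them into fixed-size blocks. A minimal sketch of that grouping, using the same `itertools.chain` the file imports (toy data, not the elided body):

```python
from itertools import chain


def group_texts_sketch(token_lists: list[list[int]], block_size: int) -> list[list[int]]:
    """Concatenate tokenized documents and re-chunk into fixed-size
    blocks, dropping the ragged tail (run_clm-style grouping)."""
    concatenated = list(chain(*token_lists))
    total_length = (len(concatenated) // block_size) * block_size
    return [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]


blocks = group_texts_sketch([[1, 2, 3], [4, 5], [6, 7, 8, 9]], block_size=4)
print(blocks)  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```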
src/llamafactory/data/processor/processor_utils.py  View file @ 7ea81099
@@ -15,7 +15,7 @@
 import bisect
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Optional


 if TYPE_CHECKING:
...
@@ -27,9 +27,7 @@ if TYPE_CHECKING:
 @dataclass
 class DatasetProcessor(ABC):
-    r"""
-    A class for data processors.
-    """
+    r"""A class for data processors."""

     template: "Template"
     tokenizer: "PreTrainedTokenizer"
...
@@ -37,32 +35,24 @@ class DatasetProcessor(ABC):
     data_args: "DataArguments"

     @abstractmethod
-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
-        r"""
-        Builds model inputs from the examples.
-        """
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
+        r"""Build model inputs from the examples."""
         ...

     @abstractmethod
-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
-        r"""
-        Print a data example to stdout.
-        """
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
+        r"""Print a data example to stdout."""
         ...


-def search_for_fit(numbers: Sequence[int], capacity: int) -> int:
-    r"""
-    Finds the index of largest number that fits into the knapsack with the given capacity.
-    """
+def search_for_fit(numbers: list[int], capacity: int) -> int:
+    r"""Find the index of largest number that fits into the knapsack with the given capacity."""
     index = bisect.bisect(numbers, capacity)
     return -1 if index == 0 else (index - 1)


-def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
-    r"""
-    An efficient greedy algorithm with binary search for the knapsack problem.
-    """
+def greedy_knapsack(numbers: list[int], capacity: int) -> list[list[int]]:
+    r"""Implement efficient greedy algorithm with binary search for the knapsack problem."""
     numbers.sort()  # sort numbers in ascending order for binary search
     knapsacks = []
...
@@ -83,10 +73,8 @@ def greedy_knapsack(numbers: List[int], capacity: int) -> List[List[int]]:
     return knapsacks


-def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
-    r"""
-    Computes the real sequence length after truncation by the cutoff_len.
-    """
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> tuple[int, int]:
+    r"""Compute the real sequence length after truncation by the cutoff_len."""
     if target_len * 2 < cutoff_len:  # truncate source
         max_target_len = cutoff_len
     elif source_len * 2 < cutoff_len:  # truncate target
...
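For readers of `greedy_knapsack` / `search_for_fit`: the algorithm keeps the lengths sorted and repeatedly bin-packs the largest remaining length that still fits, using `bisect` for the lookup. A self-contained sketch of the same idea (the oversize-item guard is an addition for safety; upstream this case is presumably avoided by truncating lengths to `cutoff_len` first):

```python
import bisect


def greedy_knapsack_sketch(numbers: list[int], capacity: int) -> list[list[int]]:
    """Sort lengths, then repeatedly take the largest remaining one that
    still fits into the current knapsack (binary search via bisect)."""
    numbers = sorted(numbers)
    knapsacks = []
    while numbers:
        knapsack, remaining = [], capacity
        while True:
            index = bisect.bisect(numbers, remaining) - 1  # largest fit
            if index < 0:
                break
            remaining -= numbers[index]
            knapsack.append(numbers.pop(index))

        if not knapsack:  # item larger than capacity; place it alone
            knapsack.append(numbers.pop())

        knapsacks.append(knapsack)

    return knapsacks


print(greedy_knapsack_sketch([7, 2, 5, 4, 3], capacity=8))
# [[7], [5, 3], [4, 2]]
```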
src/llamafactory/data/processor/supervised.py  View file @ 7ea81099
@@ -14,7 +14,7 @@

 from collections import defaultdict
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
...
@@ -32,14 +32,14 @@ logger = logging.get_logger(__name__)
 class SupervisedDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[Dict[str, str]],
-        response: Sequence[Dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-    ) -> Tuple[List[int], List[int]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ) -> tuple[list[int], list[int]]:
         messages = self.template.mm_plugin.process_messages(prompt + response, images, videos, audios, self.processor)
         input_ids, labels = self.template.mm_plugin.process_token_ids(
             [], [], images, videos, audios, self.tokenizer, self.processor
...
@@ -85,7 +85,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
         return input_ids, labels

-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
         # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
         # for multiturn examples, we only mask the prompt part in each prompt-response pair.
         model_inputs = defaultdict(list)
...
@@ -114,7 +114,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
         return model_inputs

-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
         valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
...
@@ -124,7 +124,7 @@ class SupervisedDatasetProcessor(DatasetProcessor):
 @dataclass
 class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
         # TODO: use `position_ids` to achieve packing
         # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
         # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
...
@@ -165,7 +165,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
         knapsacks = greedy_knapsack(lengths, self.data_args.cutoff_len)
         for knapsack in knapsacks:
             packed_input_ids, packed_attention_masks, packed_labels = [], [], []
-            packed_images, packed_videos, packed_audios = [], [], []
+            packed_images, packed_videos, packed_audios, packed_position_ids = [], [], [], []
             for i, length in enumerate(knapsack):
                 index = length2indexes[length].pop()
                 packed_input_ids += batch_input_ids[index]
...
@@ -175,6 +175,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
                 packed_audios += batch_audios[index]
                 if self.data_args.neat_packing:
                     packed_attention_masks += [i + 1] * len(batch_input_ids[index])  # start from 1
+                    packed_position_ids += list(range(len(batch_input_ids[index])))
                 else:
                     packed_attention_masks += [1] * len(batch_input_ids[index])
...
@@ -184,6 +185,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
                 packed_labels += [IGNORE_INDEX] * pad_length
                 if self.data_args.neat_packing:
                     packed_attention_masks += [0] * pad_length
+                    packed_position_ids += [0] * pad_length
                 else:
                     packed_attention_masks += [1] * pad_length  # more efficient flash_attn
...
@@ -196,5 +198,6 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor):
             model_inputs["images"].append(packed_images or None)
             model_inputs["videos"].append(packed_videos or None)
             model_inputs["audios"].append(packed_audios or None)
+            model_inputs["position_ids"].append(packed_position_ids or None)

         return model_inputs
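What the new `packed_position_ids` adds under `neat_packing`: alongside the per-segment attention ids that start from 1, each packed sequence now carries position ids that restart at 0 at its own beginning. A toy trace of the two loops above:

```python
batch_input_ids = [[11, 12, 13], [21, 22]]  # two sequences packed together

packed_attention_masks, packed_position_ids = [], []
for i, ids in enumerate(batch_input_ids):
    packed_attention_masks += [i + 1] * len(ids)   # segment ids, start from 1
    packed_position_ids += list(range(len(ids)))   # positions restart per segment

print(packed_attention_masks)  # [1, 1, 1, 2, 2]
print(packed_position_ids)     # [0, 1, 2, 0, 1]
```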
src/llamafactory/data/processor/unsupervised.py  View file @ 7ea81099
@@ -13,7 +13,7 @@
 # limitations under the License.

 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, Optional

 from ...extras import logging
 from ..data_utils import Role
...
@@ -30,14 +30,14 @@ logger = logging.get_logger(__name__)
 class UnsupervisedDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
         self,
-        prompt: Sequence[Dict[str, str]],
-        response: Sequence[Dict[str, str]],
+        prompt: list[dict[str, str]],
+        response: list[dict[str, str]],
         system: Optional[str],
         tools: Optional[str],
-        images: Sequence["ImageInput"],
-        videos: Sequence["VideoInput"],
-        audios: Sequence["AudioInput"],
-    ) -> Tuple[List[int], List[int]]:
+        images: list["ImageInput"],
+        videos: list["VideoInput"],
+        audios: list["AudioInput"],
+    ) -> tuple[list[int], list[int]]:
         if len(response) == 1:
             messages = prompt + response
         else:
...
@@ -56,7 +56,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
         labels = labels[:target_len]
         return input_ids, labels

-    def preprocess_dataset(self, examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
+    def preprocess_dataset(self, examples: dict[str, list[Any]]) -> dict[str, list[Any]]:
         # build inputs with format `<bos> X` and labels with format `Y <eos>`
         model_inputs = defaultdict(list)
         for i in range(len(examples["_prompt"])):
...
@@ -84,7 +84,7 @@ class UnsupervisedDatasetProcessor(DatasetProcessor):
         return model_inputs

-    def print_data_example(self, example: Dict[str, List[int]]) -> None:
+    def print_data_example(self, example: dict[str, list[int]]) -> None:
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(self.tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
         print("label_ids:\n{}".format(example["labels"]))
...