OpenDAS / LLaMA-Factory · Commits

Commit 27a7ad86, authored Oct 14, 2024 by luopl
update to v0.9.1
parent 731cf9b8

Changes: 120 · Showing 20 changed files with 1176 additions and 481 deletions (+1176 −481)
src/llamafactory/api/chat.py                          +10  −10
src/llamafactory/chat/base_engine.py                  +31  −7
src/llamafactory/chat/chat_model.py                   +42  −10
src/llamafactory/chat/hf_engine.py                    +47  −47
src/llamafactory/chat/vllm_engine.py                  +41  −53
src/llamafactory/cli.py                               +1   −1
src/llamafactory/data/__init__.py                     +7   −1
src/llamafactory/data/aligner.py                      +162 −143
src/llamafactory/data/collator.py                     +52  −18
src/llamafactory/data/data_utils.py                   +7   −2
src/llamafactory/data/formatter.py                    +32  −24
src/llamafactory/data/loader.py                       +23  −7
src/llamafactory/data/mm_plugin.py                    +627 −0
src/llamafactory/data/parser.py                       +2   −1
src/llamafactory/data/preprocess.py                   +2   −1
src/llamafactory/data/processors/feedback.py          +23  −38
src/llamafactory/data/processors/pairwise.py          +21  −42
src/llamafactory/data/processors/pretrain.py          +3   −3
src/llamafactory/data/processors/processor_utils.py   +1   −31
src/llamafactory/data/processors/supervised.py        +42  −42
src/llamafactory/api/chat.py (+10 −10)

@@ -16,6 +16,7 @@ import base64
 import io
 import json
 import os
+import re
 import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

@@ -51,9 +52,8 @@ if is_requests_available():
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
-
     from ..chat import ChatModel
+    from ..data.mm_plugin import ImageInput
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest

@@ -69,7 +69,7 @@ ROLE_MAPPING = {
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
     logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))

     if len(request.messages) == 0:

@@ -104,15 +104,14 @@ def _process_request(
                 input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
             else:
                 image_url = input_item.image_url.url
-                if image_url.startswith("data:image"):  # base64 image
-                    image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
-                    image_path = io.BytesIO(image_data)
+                if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
+                    image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
                 elif os.path.isfile(image_url):  # local file
-                    image_path = open(image_url, "rb")
+                    image_stream = open(image_url, "rb")
                 else:  # web uri
-                    image_path = requests.get(image_url, stream=True).raw
+                    image_stream = requests.get(image_url, stream=True).raw

-                image = Image.open(image_path).convert("RGB")
+                image = Image.open(image_stream).convert("RGB")
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

@@ -230,8 +229,9 @@ async def create_stream_chat_completion_response(
 async def create_score_evaluation_response(
     request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
+    score_id = "scoreval-{}".format(uuid.uuid4().hex)
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

     scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
-    return ScoreEvaluationResponse(model=request.model, scores=scores)
+    return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
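
The image-handling hunk above replaces the loose startswith("data:image") check with a regex that validates both the media type and the base64 payload, and renames image_path to image_stream to reflect that the object is a byte stream, not a filesystem path. A minimal standalone sketch of the three source branches, mirroring the new code (the helper name and the omission of error handling are ours):

import base64
import io
import os
import re

import requests
from PIL import Image


def load_image(image_url: str) -> "Image.Image":
    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 data URI
        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
    elif os.path.isfile(image_url):  # local file
        image_stream = open(image_url, "rb")
    else:  # web uri, fetched lazily as a raw byte stream
        image_stream = requests.get(image_url, stream=True).raw

    return Image.open(image_stream).convert("RGB")

# e.g. load_image("data:image/png;base64,iVBORw0KGgo...") or load_image("https://example.com/cat.png")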
src/llamafactory/chat/base_engine.py (+31 −7)

@@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Opti
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
     from transformers import PreTrainedModel, PreTrainedTokenizer
     from vllm import AsyncLLMEngine

     from ..data import Template
+    from ..data.mm_plugin import ImageInput, VideoInput
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

@@ -35,6 +35,12 @@ class Response:
 class BaseEngine(ABC):
+    r"""
+    Base class for inference engine of chat models.
+
+    Must implements async methods: chat(), stream_chat() and get_scores().
+    """
+
     model: Union["PreTrainedModel", "AsyncLLMEngine"]
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool

@@ -48,7 +54,11 @@ class BaseEngine(ABC):
         data_args: "DataArguments",
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-    ) -> None: ...
+    ) -> None:
+        r"""
+        Initializes an inference engine.
+        """
+        ...

     @abstractmethod
     async def chat(

@@ -56,9 +66,14 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
-    ) -> List["Response"]: ...
+    ) -> List["Response"]:
+        r"""
+        Gets a list of responses of the chat model.
+        """
+        ...

     @abstractmethod
     async def stream_chat(

@@ -66,13 +81,22 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
-    ) -> AsyncGenerator[str, None]: ...
+    ) -> AsyncGenerator[str, None]:
+        r"""
+        Gets the response token-by-token of the chat model.
+        """
+        ...

     @abstractmethod
     async def get_scores(
         self,
         batch_input: List[str],
         **input_kwargs,
-    ) -> List[float]: ...
+    ) -> List[float]:
+        r"""
+        Gets a list of scores of the reward model.
+        """
+        ...
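
Since BaseEngine now documents its contract and threads image/video through every abstract method, a concrete engine must accept both arguments. A throwaway sketch (ours, not from the repo; assumes LLaMA-Factory is installed and that __init__ may be overridden freely) of a subclass satisfying the updated signatures:

from typing import Any, List

from llamafactory.chat.base_engine import BaseEngine, Response


class DummyEngine(BaseEngine):
    """Hypothetical engine that only illustrates the new image/video parameters."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:  # a real engine loads model/tokenizer here
        self.can_generate = True

    async def chat(self, messages, system=None, tools=None, image=None, video=None, **input_kwargs) -> List["Response"]:
        return []  # a real engine runs generation here

    async def stream_chat(self, messages, system=None, tools=None, image=None, video=None, **input_kwargs):
        yield ""  # token-by-token output

    async def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
        return [0.0] * len(batch_input)  # reward-model scores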
src/llamafactory/chat/chat_model.py (+42 −10)

@@ -27,8 +27,7 @@ from .vllm_engine import VllmEngine
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
-
+    from ..data.mm_plugin import ImageInput, VideoInput
     from .base_engine import BaseEngine, Response

@@ -38,8 +37,17 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
 class ChatModel:
+    r"""
+    General class for chat models. Backed by huggingface or vllm engines.
+
+    Supports both sync and async methods.
+    Sync methods: chat(), stream_chat() and get_scores().
+    Async methods: achat(), astream_chat() and aget_scores().
+    """
+
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+        self.engine_type = model_args.infer_backend
         if model_args.infer_backend == "huggingface":
             self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == "vllm":

@@ -56,10 +64,16 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]:
-        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
+        r"""
+        Gets a list of responses of the chat model.
+        """
+        task = asyncio.run_coroutine_threadsafe(
+            self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+        )
         return task.result()

     async def achat(

@@ -67,20 +81,28 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]:
-        return await self.engine.chat(messages, system, tools, image, **input_kwargs)
+        r"""
+        Asynchronously gets a list of responses of the chat model.
+        """
+        return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)

     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
+        r"""
+        Gets the response token-by-token of the chat model.
+        """
+        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)

@@ -93,10 +115,14 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
+        r"""
+        Asynchronously gets the response token-by-token of the chat model.
+        """
+        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
             yield new_token

     def get_scores(

@@ -104,6 +130,9 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs,
     ) -> List[float]:
+        r"""
+        Gets a list of scores of the reward model.
+        """
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

@@ -112,6 +141,9 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs,
     ) -> List[float]:
+        r"""
+        Asynchronously gets a list of scores of the reward model.
+        """
         return await self.engine.get_scores(batch_input, **input_kwargs)
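
For callers, the visible change is the extra video keyword threaded alongside image on both the sync and async paths. A hypothetical usage sketch (the model path, template name and media files are placeholders, not from this commit):

from llamafactory.chat import ChatModel

chat_model = ChatModel({"model_name_or_path": "path/to/model", "template": "qwen2_vl"})
messages = [{"role": "user", "content": "<image>What is in this picture?"}]

responses = chat_model.chat(messages, image="path/to/picture.jpg")  # sync wrapper around achat()
print(responses[0].response_text)

for token in chat_model.stream_chat(messages, video="path/to/video.mp4"):  # drives astream_chat()
    print(token, end="")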
src/llamafactory/chat/hf_engine.py (+47 −47)

@@ -20,8 +20,10 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
+from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.logging import get_logger
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer

@@ -29,12 +31,11 @@ from .base_engine import BaseEngine, Response
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
     from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
-    from transformers.image_processing_utils import BaseImageProcessor
     from trl import PreTrainedModelWrapper

     from ..data import Template
+    from ..data.mm_plugin import ImageInput, VideoInput
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

@@ -54,7 +55,7 @@ class HuggingfaceEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left" if self.can_generate else "right"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.model = load_model(
             self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )  # must after fixing tokenizer to resize vocab

@@ -78,31 +79,30 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
-        if (
-            processor is not None
-            and image is not None
-            and not hasattr(processor, "image_seq_length")
-            and template.image_token not in messages[0]["content"]
-        ):  # llava-like models
-            messages[0]["content"] = template.image_token + messages[0]["content"]
+        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
+        if image is not None:
+            mm_input_dict.update({"images": [image], "imglens": [1]})
+            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
+                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+
+        if video is not None:
+            mm_input_dict.update({"videos": [video], "vidlens": [1]})
+            if VIDEO_PLACEHOLDER not in messages[0]["content"]:
+                messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]

+        messages = template.mm_plugin.process_messages(messages, mm_input_dict["images"], mm_input_dict["videos"], processor)
         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or generating_args["default_system"]
-        pixel_values = None
-        prompt_ids, _ = template.encode_oneturn(
-            tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
-        )
-        if processor is not None and image is not None:  # add image features
-            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
-            batch_feature = image_processor(image, return_tensors="pt")
-            pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-                prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
+        prompt_ids, _ = template.mm_plugin.process_token_ids(
+            prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
+        )

         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
         attention_mask = torch.ones_like(inputs, dtype=torch.bool)

@@ -164,8 +164,10 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )

-        if pixel_values is not None:
-            gen_kwargs["pixel_values"] = pixel_values
+        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
+        for key, value in mm_inputs.items():
+            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            gen_kwargs[key] = value.to(model.device)

         return gen_kwargs, prompt_length

@@ -180,11 +182,12 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]

@@ -215,11 +218,12 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer

@@ -242,37 +246,28 @@ class HuggingfaceEngine(BaseEngine):
         batch_input: List[str],
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List[float]:
-        max_length = input_kwargs.pop("max_length", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
         device = getattr(model.pretrained_model, "device", "cuda")
-        inputs = tokenizer(
+        inputs: Dict[str, "torch.Tensor"] = tokenizer(
             batch_input,
             padding=True,
             truncation=True,
             max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
             return_tensors="pt",
-            add_special_tokens=True,
+            add_special_tokens=False,
         ).to(device)
-        input_ids: torch.Tensor = inputs["input_ids"]
-        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
-        if getattr(model.config, "model_type", None) == "chatglm":
-            values = torch.transpose(values, 0, 1)
-
-        scores = []
-        for i in range(input_ids.size(0)):
-            end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
-            end_index = end_indexes[-1].item() if len(end_indexes) else 0
-            scores.append(values[i, end_index].nan_to_num().item())
-
+        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
+        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return scores

+    @override
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:

@@ -289,18 +284,21 @@ class HuggingfaceEngine(BaseEngine):
             system,
             tools,
             image,
+            video,
             input_kwargs,
         )
         async with self.semaphore:
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._chat, *input_args)

+    @override
     async def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:

@@ -317,6 +315,7 @@ class HuggingfaceEngine(BaseEngine):
             system,
             tools,
             image,
+            video,
             input_kwargs,
         )
         async with self.semaphore:

@@ -328,6 +327,7 @@ class HuggingfaceEngine(BaseEngine):
             except StopAsyncIteration:
                 break

+    @override
     async def get_scores(
         self,
         batch_input: List[str],
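
The rewritten get_scores drops the per-row Python loop: the model call returns the value head's output for every position, and gather with attention_mask.sum(-1) - 1 picks the value at each sequence's last real token. A toy reproduction of just that indexing step (the tensors are made up for illustration):

import torch

values = torch.tensor([[0.1, 0.2, 0.3],
                       [0.4, 0.5, 0.6]])    # (batch, seq_len) value-head outputs
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 1, 0]])  # second row ends one token early
last_token = attention_mask.sum(dim=-1, keepdim=True) - 1   # [[2], [1]]
scores = values.gather(dim=-1, index=last_token)            # [[0.3], [0.5]]
print(scores.squeeze(-1))  # tensor([0.3000, 0.5000])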
src/llamafactory/chat/vllm_engine.py (+41 −53)

@@ -15,32 +15,31 @@
 import uuid
 from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union

+from typing_extensions import override
+
 from ..data import get_template_and_fix_tokenizer
+from ..extras.constants import IMAGE_PLACEHOLDER
 from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
-from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
+from ..extras.packages import is_pillow_available, is_vllm_available
 from ..model import load_config, load_tokenizer
 from ..model.model_utils.quantization import QuantizationMethod
 from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response

+if is_pillow_available():
+    from PIL import Image
+    from PIL.Image import Image as ImageObject
+
+
 if is_vllm_available():
     from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
     from vllm.lora.request import LoRARequest

-    if is_vllm_version_greater_than_0_5_1():
-        pass
-    elif is_vllm_version_greater_than_0_5():
-        from vllm.multimodal.image import ImagePixelData
-    else:
-        from vllm.sequence import MultiModalData
-

 if TYPE_CHECKING:
-    from numpy.typing import NDArray
-    from transformers.image_processing_utils import BaseImageProcessor
-
+    from ..data.mm_plugin import ImageInput, VideoInput
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

@@ -67,7 +66,7 @@ class VllmEngine(BaseEngine):
         self.tokenizer = tokenizer_module["tokenizer"]
         self.processor = tokenizer_module["processor"]
         self.tokenizer.padding_side = "left"
-        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
+        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
         self.generating_args = generating_args.to_dict()

         engine_args = {

@@ -85,19 +84,11 @@ class VllmEngine(BaseEngine):
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }

-        if model_args.visual_inputs:
-            image_size = config.vision_config.image_size
-            patch_size = config.vision_config.patch_size
-            self.image_feature_size = (image_size // patch_size) ** 2
-            engine_args["image_input_type"] = "pixel_values"
-            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
-            engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
-            engine_args["image_feature_size"] = self.image_feature_size
-            if getattr(config, "is_yi_vl_derived_model", None):
-                import vllm.model_executor.models.llava
-                logger.info("Detected Yi-VL model, applying projector patch.")
-                vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM
+        if getattr(config, "is_yi_vl_derived_model", None):
+            import vllm.model_executor.models.llava
+
+            logger.info("Detected Yi-VL model, applying projector patch.")
+            vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

         self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
         if model_args.adapter_name_or_path is not None:

@@ -110,37 +101,18 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
         request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-        if (
-            self.processor is not None
-            and image is not None
-            and not hasattr(self.processor, "image_seq_length")
-            and self.template.image_token not in messages[0]["content"]
-        ):  # llava-like models (TODO: paligemma models)
-            messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]
+        if image is not None:
+            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
+                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]

         paired_messages = messages + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
-        prompt_ids, _ = self.template.encode_oneturn(
-            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
-        )
-        if self.processor is not None and image is not None:  # add image features
-            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
-            pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
-            if is_vllm_version_greater_than_0_5_1():
-                multi_modal_data = {"image": pixel_values}
-            elif is_vllm_version_greater_than_0_5():
-                multi_modal_data = ImagePixelData(image=pixel_values)
-            else:  # TODO: remove vllm 0.4.3 support
-                multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
-        else:
-            multi_modal_data = None
-
+        prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)

         use_beam_search: bool = self.generating_args["num_beams"] > 1

@@ -185,6 +157,17 @@ class VllmEngine(BaseEngine):
             skip_special_tokens=True,
         )

+        if image is not None:  # add image features
+            if not isinstance(image, (str, ImageObject)):
+                raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
+
+            if isinstance(image, str):
+                image = Image.open(image).convert("RGB")
+
+            multi_modal_data = {"image": image}
+        else:
+            multi_modal_data = None
+
         result_generator = self.model.generate(
             inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
             sampling_params=sampling_params,

@@ -193,16 +176,18 @@ class VllmEngine(BaseEngine):
         )
         return result_generator

+    @override
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
-        generator = await self._generate(messages, system, tools, image, **input_kwargs)
+        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
         async for request_output in generator:
             final_output = request_output

@@ -219,21 +204,24 @@ class VllmEngine(BaseEngine):
         return results

+    @override
     async def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["NDArray"] = None,
+        image: Optional["ImageInput"] = None,
+        video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
-        generator = await self._generate(messages, system, tools, image, **input_kwargs)
+        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
         async for result in generator:
             delta_text = result.outputs[0].text[len(generated_text):]
             generated_text = result.outputs[0].text
             yield delta_text

+    @override
     async def get_scores(
         self,
         batch_input: List[str],
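
Instead of precomputing pixel values and branching on the vLLM version, the engine now hands vLLM a PIL image (or opens one from a path) through multi_modal_data. The new type-check branch, extracted into a standalone helper for illustration (the helper name is ours; the logic mirrors the hunk above):

from PIL import Image
from PIL.Image import Image as ImageObject


def to_multi_modal_data(image):
    # Mirrors the new branch: accept a path or a PIL image, reject anything else.
    if not isinstance(image, (str, ImageObject)):
        raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))

    if isinstance(image, str):
        image = Image.open(image).convert("RGB")

    return {"image": image}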
src/llamafactory/cli.py (+1 −1)

@@ -118,4 +118,4 @@ def main():
     elif command == Command.HELP:
         print(USAGE)
     else:
-        raise NotImplementedError("Unknown command: {}".format(command))
+        raise NotImplementedError("Unknown command: {}.".format(command))
src/llamafactory/data/__init__.py (+7 −1)

@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
+from .collator import (
+    KTODataCollatorWithPadding,
+    MultiModalDataCollatorForSeq2Seq,
+    PairwiseDataCollatorWithPadding,
+    SFTDataCollatorWith4DAttentionMask,
+)
 from .data_utils import Role, split_dataset
 from .loader import get_dataset
 from .template import TEMPLATES, Template, get_template_and_fix_tokenizer

@@ -20,6 +25,7 @@ from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 __all__ = [
     "KTODataCollatorWithPadding",
+    "MultiModalDataCollatorForSeq2Seq",
     "PairwiseDataCollatorWithPadding",
     "SFTDataCollatorWith4DAttentionMask",
     "Role",
src/llamafactory/data/aligner.py
View file @
27a7ad86
...
@@ -14,9 +14,7 @@
...
@@ -14,9 +14,7 @@
import
os
import
os
from
functools
import
partial
from
functools
import
partial
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Sequence
,
Union
from
datasets
import
Features
from
..extras.logging
import
get_logger
from
..extras.logging
import
get_logger
from
.data_utils
import
Role
from
.data_utils
import
Role
...
@@ -27,88 +25,117 @@ if TYPE_CHECKING:
...
@@ -27,88 +25,117 @@ if TYPE_CHECKING:
from
transformers
import
Seq2SeqTrainingArguments
from
transformers
import
Seq2SeqTrainingArguments
from
..hparams
import
DataArguments
from
..hparams
import
DataArguments
from
.mm_plugin
import
ImageInput
,
VideoInput
from
.parser
import
DatasetAttr
from
.parser
import
DatasetAttr
logger
=
get_logger
(
__name__
)
logger
=
get_logger
(
__name__
)
def
_convert_images
(
images
:
List
[
Any
],
dataset_attr
:
"DatasetAttr"
,
data_args
:
"DataArguments"
)
->
List
[
Any
]:
def
_convert_images
(
images
:
Sequence
[
"ImageInput"
],
dataset_attr
:
"DatasetAttr"
,
data_args
:
"DataArguments"
,
)
->
Optional
[
List
[
"ImageInput"
]]:
r
"""
r
"""
Optionally concatenates image path to dataset dir when loading from local disk.
Optionally concatenates image path to dataset dir when loading from local disk.
"""
"""
outputs
=
[]
if
len
(
images
)
==
0
:
return
None
images
=
images
[:]
if
dataset_attr
.
load_from
in
[
"script"
,
"file"
]:
for
i
in
range
(
len
(
images
)):
if
isinstance
(
images
[
i
],
str
)
and
os
.
path
.
isfile
(
os
.
path
.
join
(
data_args
.
dataset_dir
,
images
[
i
])):
images
[
i
]
=
os
.
path
.
join
(
data_args
.
dataset_dir
,
images
[
i
])
return
images
def
_convert_videos
(
videos
:
Sequence
[
"VideoInput"
],
dataset_attr
:
"DatasetAttr"
,
data_args
:
"DataArguments"
,
)
->
Optional
[
List
[
"VideoInput"
]]:
r
"""
Optionally concatenates video path to dataset dir when loading from local disk.
"""
if
len
(
videos
)
==
0
:
return
None
videos
=
videos
[:]
if
dataset_attr
.
load_from
in
[
"script"
,
"file"
]:
if
dataset_attr
.
load_from
in
[
"script"
,
"file"
]:
for
image
in
images
:
for
i
in
range
(
len
(
videos
)):
if
isinstance
(
image
,
str
)
and
os
.
path
.
isfile
(
os
.
path
.
join
(
data_args
.
dataset_dir
,
image
)):
if
isinstance
(
videos
[
i
],
str
)
and
os
.
path
.
isfile
(
os
.
path
.
join
(
data_args
.
dataset_dir
,
videos
[
i
])):
outputs
.
append
(
os
.
path
.
join
(
data_args
.
dataset_dir
,
image
))
videos
[
i
]
=
os
.
path
.
join
(
data_args
.
dataset_dir
,
videos
[
i
])
else
:
outputs
.
append
(
image
)
return
output
s
return
video
s
def
convert_alpaca
(
def
convert_alpaca
(
examples
:
Dict
[
str
,
List
[
Any
]],
dataset_attr
:
"DatasetAttr"
,
data_args
:
"DataArguments"
example
:
Dict
[
str
,
Any
],
)
->
Dict
[
str
,
List
[
Any
]]:
dataset_attr
:
"DatasetAttr"
,
data_args
:
"DataArguments"
,
)
->
Dict
[
str
,
Any
]:
r
"""
r
"""
Converts alpaca format dataset to the standard format.
Converts alpaca format dataset to the standard format.
"""
"""
outputs
=
{
"prompt"
:
[],
"response"
:
[],
"system"
:
[],
"tools"
:
[],
"images"
:
[]}
prompt
=
[]
if
dataset_attr
.
history
and
isinstance
(
example
[
dataset_attr
.
history
],
list
):
for
old_prompt
,
old_response
in
example
[
dataset_attr
.
history
]:
prompt
.
append
({
"role"
:
Role
.
USER
.
value
,
"content"
:
old_prompt
})
prompt
.
append
({
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
old_response
})
query
=
[]
if
dataset_attr
.
prompt
and
example
[
dataset_attr
.
prompt
]:
query
.
append
(
example
[
dataset_attr
.
prompt
])
if
dataset_attr
.
query
and
example
[
dataset_attr
.
query
]:
query
.
append
(
example
[
dataset_attr
.
query
])
prompt
.
append
({
"role"
:
Role
.
USER
.
value
,
"content"
:
"
\n
"
.
join
(
query
)})
# "prompt\nquery"
if
dataset_attr
.
kto_tag
and
isinstance
(
example
[
dataset_attr
.
kto_tag
],
bool
):
# kto example
response
=
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
example
[
dataset_attr
.
response
]}]
if
example
[
dataset_attr
.
kto_tag
]:
response
=
response
+
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
""
}]
else
:
response
=
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
""
}]
+
response
elif
(
dataset_attr
.
ranking
and
isinstance
(
example
[
dataset_attr
.
chosen
],
str
)
and
isinstance
(
example
[
dataset_attr
.
rejected
],
str
)
):
# pairwise example
response
=
[
{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
example
[
dataset_attr
.
chosen
]},
{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
example
[
dataset_attr
.
rejected
]},
]
elif
dataset_attr
.
response
and
isinstance
(
example
[
dataset_attr
.
response
],
str
):
# normal example
response
=
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
example
[
dataset_attr
.
response
]}]
else
:
# unsupervised
response
=
[]
convert_images
=
partial
(
_convert_images
,
dataset_attr
=
dataset_attr
,
data_args
=
data_args
)
convert_images
=
partial
(
_convert_images
,
dataset_attr
=
dataset_attr
,
data_args
=
data_args
)
for
i
in
range
(
len
(
examples
[
dataset_attr
.
prompt
])):
convert_videos
=
partial
(
_convert_videos
,
dataset_attr
=
dataset_attr
,
data_args
=
data_args
)
prompt
=
[]
output
=
{
if
dataset_attr
.
history
and
isinstance
(
examples
[
dataset_attr
.
history
][
i
],
list
):
"_prompt"
:
prompt
,
for
old_prompt
,
old_response
in
examples
[
dataset_attr
.
history
][
i
]:
"_response"
:
response
,
prompt
.
append
({
"role"
:
Role
.
USER
.
value
,
"content"
:
old_prompt
})
"_system"
:
example
[
dataset_attr
.
system
]
if
dataset_attr
.
system
else
""
,
prompt
.
append
({
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
old_response
})
"_tools"
:
example
[
dataset_attr
.
tools
]
if
dataset_attr
.
tools
else
""
,
"_images"
:
convert_images
(
example
[
dataset_attr
.
images
])
if
dataset_attr
.
images
else
None
,
content
=
[]
"_videos"
:
convert_videos
(
example
[
dataset_attr
.
videos
])
if
dataset_attr
.
videos
else
None
,
if
dataset_attr
.
prompt
and
examples
[
dataset_attr
.
prompt
][
i
]:
}
content
.
append
(
examples
[
dataset_attr
.
prompt
][
i
])
return
output
if
dataset_attr
.
query
and
examples
[
dataset_attr
.
query
][
i
]:
content
.
append
(
examples
[
dataset_attr
.
query
][
i
])
prompt
.
append
({
"role"
:
Role
.
USER
.
value
,
"content"
:
"
\n
"
.
join
(
content
)})
# "prompt\nquery"
if
dataset_attr
.
kto_tag
and
isinstance
(
examples
[
dataset_attr
.
kto_tag
][
i
],
bool
):
# kto example
response
=
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
examples
[
dataset_attr
.
response
][
i
]}]
if
examples
[
dataset_attr
.
kto_tag
][
i
]:
response
=
response
+
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
""
}]
else
:
response
=
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
""
}]
+
response
elif
(
dataset_attr
.
ranking
and
isinstance
(
examples
[
dataset_attr
.
chosen
][
i
],
str
)
and
isinstance
(
examples
[
dataset_attr
.
rejected
][
i
],
str
)
):
# pairwise example
response
=
[
{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
examples
[
dataset_attr
.
chosen
][
i
]},
{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
examples
[
dataset_attr
.
rejected
][
i
]},
]
elif
dataset_attr
.
response
and
isinstance
(
examples
[
dataset_attr
.
response
][
i
],
str
):
# normal example
response
=
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
examples
[
dataset_attr
.
response
][
i
]}]
else
:
# unsupervised
response
=
[]
outputs
[
"prompt"
].
append
(
prompt
)
outputs
[
"response"
].
append
(
response
)
outputs
[
"system"
].
append
(
examples
[
dataset_attr
.
system
][
i
]
if
dataset_attr
.
system
else
""
)
outputs
[
"tools"
].
append
(
examples
[
dataset_attr
.
tools
][
i
]
if
dataset_attr
.
tools
else
""
)
outputs
[
"images"
].
append
(
convert_images
(
examples
[
dataset_attr
.
images
][
i
])
if
dataset_attr
.
images
else
[])
return
outputs
def
convert_sharegpt
(
def
convert_sharegpt
(
examples
:
Dict
[
str
,
List
[
Any
]],
dataset_attr
:
"DatasetAttr"
,
data_args
:
"DataArguments"
example
:
Dict
[
str
,
Any
],
)
->
Dict
[
str
,
List
[
Any
]]:
dataset_attr
:
"DatasetAttr"
,
data_args
:
"DataArguments"
,
)
->
Dict
[
str
,
Any
]:
r
"""
r
"""
Converts sharegpt format dataset to the standard format.
Converts sharegpt format dataset to the standard format.
"""
"""
outputs
=
{
"prompt"
:
[],
"response"
:
[],
"system"
:
[],
"tools"
:
[],
"images"
:
[]}
convert_images
=
partial
(
_convert_images
,
dataset_attr
=
dataset_attr
,
data_args
=
data_args
)
tag_mapping
=
{
tag_mapping
=
{
dataset_attr
.
user_tag
:
Role
.
USER
.
value
,
dataset_attr
.
user_tag
:
Role
.
USER
.
value
,
dataset_attr
.
assistant_tag
:
Role
.
ASSISTANT
.
value
,
dataset_attr
.
assistant_tag
:
Role
.
ASSISTANT
.
value
,
...
@@ -119,74 +146,79 @@ def convert_sharegpt(
...
@@ -119,74 +146,79 @@ def convert_sharegpt(
odd_tags
=
(
dataset_attr
.
user_tag
,
dataset_attr
.
observation_tag
)
odd_tags
=
(
dataset_attr
.
user_tag
,
dataset_attr
.
observation_tag
)
even_tags
=
(
dataset_attr
.
assistant_tag
,
dataset_attr
.
function_tag
)
even_tags
=
(
dataset_attr
.
assistant_tag
,
dataset_attr
.
function_tag
)
accept_tags
=
(
odd_tags
,
even_tags
)
accept_tags
=
(
odd_tags
,
even_tags
)
for
i
,
messages
in
enumerate
(
examples
[
dataset_attr
.
messages
]):
messages
=
example
[
dataset_attr
.
messages
]
if
len
(
messages
)
==
0
:
if
(
continue
dataset_attr
.
system_tag
and
len
(
messages
)
!=
0
if
dataset_attr
.
system_tag
and
messages
[
0
][
dataset_attr
.
role_tag
]
==
dataset_attr
.
system_tag
:
and
messages
[
0
][
dataset_attr
.
role_tag
]
==
dataset_attr
.
system_tag
system
=
messages
[
0
][
dataset_attr
.
content_tag
]
):
messages
=
messages
[
1
:]
system
=
messages
[
0
][
dataset_attr
.
content_tag
]
else
:
messages
=
messages
[
1
:]
system
=
examples
[
dataset_attr
.
system
][
i
]
if
dataset_attr
.
system
else
""
else
:
system
=
example
[
dataset_attr
.
system
]
if
dataset_attr
.
system
else
""
aligned_messages
=
[]
aligned_messages
=
[]
broken_data
=
False
broken_data
=
False
for
turn_idx
,
message
in
enumerate
(
messages
):
for
turn_idx
,
message
in
enumerate
(
messages
):
if
message
[
dataset_attr
.
role_tag
]
not
in
accept_tags
[
turn_idx
%
2
]:
if
message
[
dataset_attr
.
role_tag
]
not
in
accept_tags
[
turn_idx
%
2
]:
logger
.
warning
(
"Invalid role tag in {}."
.
format
(
messages
))
logger
.
warning
(
"Invalid role tag in {}."
.
format
(
messages
))
broken_data
=
True
broken_data
=
True
aligned_messages
.
append
(
aligned_messages
.
append
(
{
"role"
:
tag_mapping
[
message
[
dataset_attr
.
role_tag
]],
"content"
:
message
[
dataset_attr
.
content_tag
]}
{
"role"
:
tag_mapping
[
message
[
dataset_attr
.
role_tag
]],
"content"
:
message
[
dataset_attr
.
content_tag
]}
)
)
if
(
not
dataset_attr
.
ranking
and
len
(
aligned_messages
)
%
2
!=
0
)
or
(
if
(
not
dataset_attr
.
ranking
and
len
(
aligned_messages
)
%
2
!=
0
)
or
(
dataset_attr
.
ranking
and
len
(
aligned_messages
)
%
2
==
0
dataset_attr
.
ranking
and
len
(
aligned_messages
)
%
2
==
0
):
logger
.
warning
(
"Invalid message count in {}."
.
format
(
messages
))
broken_data
=
True
if
dataset_attr
.
kto_tag
and
isinstance
(
example
[
dataset_attr
.
kto_tag
],
bool
):
# kto example
prompt
=
aligned_messages
[:
-
1
]
response
=
aligned_messages
[
-
1
:]
if
example
[
dataset_attr
.
kto_tag
]:
response
=
response
+
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
""
}]
else
:
response
=
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
""
}]
+
response
elif
(
dataset_attr
.
ranking
and
isinstance
(
example
[
dataset_attr
.
chosen
],
dict
)
and
isinstance
(
example
[
dataset_attr
.
rejected
],
dict
)
):
# pairwise example
chosen
=
example
[
dataset_attr
.
chosen
]
rejected
=
example
[
dataset_attr
.
rejected
]
if
(
chosen
[
dataset_attr
.
role_tag
]
not
in
accept_tags
[
-
1
]
or
rejected
[
dataset_attr
.
role_tag
]
not
in
accept_tags
[
-
1
]
):
):
logger
.
warning
(
"Invalid
message count
in {}."
.
format
(
messages
))
logger
.
warning
(
"Invalid
role tag
in {}."
.
format
(
[
chosen
,
rejected
]
))
broken_data
=
True
broken_data
=
True
if
dataset_attr
.
kto_tag
and
isinstance
(
examples
[
dataset_attr
.
kto_tag
][
i
],
bool
):
# kto example
prompt
=
aligned_messages
prompt
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]
-            if examples[dataset_attr.kto_tag][i]:
+            if example[dataset_attr.kto_tag]:
                 response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
             else:
                 response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
         elif (
             dataset_attr.ranking
-            and isinstance(examples[dataset_attr.chosen][i], dict)
-            and isinstance(examples[dataset_attr.rejected][i], dict)
+            and isinstance(example[dataset_attr.chosen], dict)
+            and isinstance(example[dataset_attr.rejected], dict)
         ):  # pairwise example
-            chosen = examples[dataset_attr.chosen][i]
-            rejected = examples[dataset_attr.rejected][i]
+            chosen = example[dataset_attr.chosen]
+            rejected = example[dataset_attr.rejected]
             if (
                 chosen[dataset_attr.role_tag] not in accept_tags[-1]
                 or rejected[dataset_attr.role_tag] not in accept_tags[-1]
             ):
                 logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
                 broken_data = True

             prompt = aligned_messages
             response = [
                 {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
                 {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
             ]
         else:  # normal example
             prompt = aligned_messages[:-1]
             response = aligned_messages[-1:]

     if broken_data:
         logger.warning("Skipping this abnormal example.")
-        continue
-
-    outputs["prompt"].append(prompt)
-    outputs["response"].append(response)
-    outputs["system"].append(system)
-    outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
-    outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])
-
-    return outputs
+        prompt, response = [], []
+
+    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
+    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
+    output = {
+        "_prompt": prompt,
+        "_response": response,
+        "_system": system,
+        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
+        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
+        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
+    }
+
+    return output

def align_dataset(
...
@@ -197,11 +229,12 @@ def align_dataset(
 ) -> Union["Dataset", "IterableDataset"]:
     r"""
     Aligned dataset:
-        prompt: [{"role": "user", "content": "..."}] * (2T - 1)
-        response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
-        system: "..."
-        tools: "...",
-        images: [],
+        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
+        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
+        _system: "..."
+        _tools: "...",
+        _images: [],
+        _videos: [],
     """
     if dataset_attr.formatting == "alpaca":
         convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
...
@@ -209,19 +242,6 @@ def align_dataset(
         convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)

     column_names = list(next(iter(dataset)).keys())
-    features = Features.from_dict(
-        {
-            "prompt": [
-                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
-            ],
-            "response": [
-                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
-            ],
-            "system": {"dtype": "string", "_type": "Value"},
-            "tools": {"dtype": "string", "_type": "Value"},
-            "images": [{"_type": "Image"}],
-        }
-    )
     kwargs = {}
     if not data_args.streaming:
         kwargs = dict(
...
@@ -232,8 +252,7 @@ def align_dataset(
     return dataset.map(
         convert_func,
-        batched=True,
+        batched=False,
         remove_columns=column_names,
-        features=features,
         **kwargs,
     )
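Taken together, the aligner hunks switch conversion from batched mapping over column dicts to per-example mapping (`batched=False`) and move the aligned columns behind an underscore prefix. A minimal, self-contained sketch of the new calling pattern — the converter body and the toy rows are illustrative, not the project's real convert_alpaca/convert_sharegpt:

from datasets import Dataset

# hypothetical per-example converter mirroring the new convert_* signature
def convert_func(example):
    return {
        "_prompt": [{"role": "user", "content": example["instruction"]}],
        "_response": [{"role": "assistant", "content": example["output"]}],
        "_system": "",
        "_tools": "",
        "_images": None,
        "_videos": None,
    }

dataset = Dataset.from_list([{"instruction": "Say hi.", "output": "Hi!"}])
aligned = dataset.map(convert_func, batched=False, remove_columns=["instruction", "output"])
print(aligned[0]["_prompt"])  # [{'role': 'user', 'content': 'Say hi.'}]

Per-example conversion lets `datasets` infer the output schema on its own, which is why the hand-written `Features.from_dict(...)` block above could be removed.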
src/llamafactory/data/collator.py View file @ 27a7ad86
...
@@ -16,12 +16,18 @@
 # limitations under the License.

 from dataclasses import dataclass
-from typing import Any, Dict, Literal, Sequence
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

 import torch
 from transformers import DataCollatorForSeq2Seq

+if TYPE_CHECKING:
+    from transformers import ProcessorMixin
+
+    from .template import Template

 def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
     r"""
     Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
...
@@ -62,7 +68,42 @@ def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype
 @dataclass
-class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
+class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    r"""
+    Data collator that supports VLMs.
+
+    Features should contain input_ids, attention_mask, labels and images.
+    """
+
+    template: Optional["Template"] = None
+    processor: Optional["ProcessorMixin"] = None
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
+        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
+        for feature in features:
+            images = feature.pop("images", None) or []
+            videos = feature.pop("videos", None) or []
+            batch_images.extend(images)
+            batch_videos.extend(videos)
+            batch_imglens.append(len(images))
+            batch_vidlens.append(len(videos))
+            batch_seqlens.append(len(feature["input_ids"]))
+
+        mm_inputs = self.template.mm_plugin.get_mm_inputs(
+            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
+        )
+        if "token_type_ids" in mm_inputs:
+            token_type_ids = mm_inputs.pop("token_type_ids")
+            for i, feature in enumerate(features):
+                feature["token_type_ids"] = token_type_ids[i]
+
+        features: Dict[str, "torch.Tensor"] = super().__call__(features)
+        features.update(mm_inputs)
+        return features
+
+
+@dataclass
+class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for 4d attention mask.
     """
...
@@ -80,7 +121,7 @@ class SFTDataCollatorWith4DAttentionMask(DataCollatorForSeq2Seq):
 @dataclass
-class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
+class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for pairwise data.
     """
...
@@ -99,20 +140,16 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
                 "input_ids": feature["{}_input_ids".format(key)],
                 "attention_mask": feature["{}_attention_mask".format(key)],
                 "labels": feature["{}_labels".format(key)],
+                "images": feature["images"],
+                "videos": feature["videos"],
             }
-            if "pixel_values" in feature:
-                target_feature["pixel_values"] = feature["pixel_values"]
-            if "{}_token_type_ids".format(key) in feature:
-                target_feature["token_type_ids"] = feature["{}_token_type_ids".format(key)]
             concatenated_features.append(target_feature)

         return super().__call__(concatenated_features)


 @dataclass
-class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
+class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
     r"""
     Data collator for KTO data.
     """
...
@@ -126,19 +163,16 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
                 "input_ids": feature["input_ids"],
                 "attention_mask": feature["attention_mask"],
                 "labels": feature["labels"],
+                "images": feature["images"],
+                "videos": feature["videos"],
             }
             kl_feature = {
                 "input_ids": feature["kl_input_ids"],
                 "attention_mask": feature["kl_attention_mask"],
                 "labels": feature["kl_labels"],
+                "images": feature["images"],
+                "videos": feature["videos"],
             }
-            if "pixel_values" in feature:
-                target_feature["pixel_values"] = feature["pixel_values"]
-            if "token_type_ids" in feature:
-                target_feature["token_type_ids"] = feature["token_type_ids"]
-                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]
             target_features.append(target_feature)
             kl_features.append(kl_feature)
             kto_tags.append(feature["kto_tags"])
...
@@ -148,7 +182,7 @@ class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
         batch["kl_input_ids"] = kl_batch["input_ids"]
         batch["kl_attention_mask"] = kl_batch["attention_mask"]
         batch["kl_labels"] = kl_batch["labels"]
-        if "token_type_ids" in batch:
+        if "token_type_ids" in kl_batch:
             batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
         batch["kto_tags"] = torch.tensor(kto_tags)
...
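The new base collator pools raw media across the batch and hands the per-example image/video/sequence counts to the template's mm_plugin. A standalone sketch of just that pooling step, with plain dicts in place of tokenized features (the stub return value stands in for the real `get_mm_inputs` tensors):

from typing import Any, Dict, List

def pool_mm_features(features: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    # mirrors the gathering loop in MultiModalDataCollatorForSeq2Seq.__call__
    batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
    for feature in features:
        images = feature.pop("images", None) or []   # None and [] both mean "no media"
        videos = feature.pop("videos", None) or []
        batch_images.extend(images)
        batch_videos.extend(videos)
        batch_imglens.append(len(images))
        batch_vidlens.append(len(videos))
        batch_seqlens.append(len(feature["input_ids"]))
    return {"images": batch_images, "videos": batch_videos,
            "imglens": batch_imglens, "vidlens": batch_vidlens, "seqlens": batch_seqlens}

features = [
    {"input_ids": [1, 2, 3], "images": ["img_a"], "videos": None},
    {"input_ids": [4, 5], "images": None, "videos": None},
]
print(pool_mm_features(features)["imglens"])  # [1, 0]

Keeping the per-example lengths alongside the flattened media lists is what lets the plugin route each pixel tensor back to the right sequence.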
src/llamafactory/data/data_utils.py View file @ 27a7ad86
...
@@ -49,6 +49,9 @@ class DatasetModule(TypedDict):
 def merge_dataset(
     all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
 ) -> Union["Dataset", "IterableDataset"]:
+    r"""
+    Merges multiple datasets to a unified dataset.
+    """
     if len(all_datasets) == 1:
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
...
@@ -67,14 +70,16 @@ def merge_dataset(
             stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
         )
     else:
-        raise ValueError("Unknown mixing strategy.")
+        raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))


 def split_dataset(
     dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
 ) -> "DatasetDict":
     r"""
-    Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
+    Splits the dataset and returns a dataset dict containing train set and validation set.
+
+    Supports both map dataset and iterable dataset.
     """
     if data_args.streaming:
         dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
...
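For the interleave strategies, the `*_under`/`*_over` suffix maps onto the two stopping strategies of `datasets.interleave_datasets`. A toy illustration of the difference (dataset contents invented):

from datasets import Dataset, interleave_datasets

short = Dataset.from_dict({"x": [1, 2]})
long = Dataset.from_dict({"x": [10, 20, 30, 40]})

# "*_under" -> stop as soon as the smallest dataset runs out (undersampling)
under = interleave_datasets([short, long], stopping_strategy="first_exhausted")
# "*_over" -> keep resampling until every dataset has been fully seen (oversampling)
over = interleave_datasets([short, long], stopping_strategy="all_exhausted")
print(len(under), len(over))  # the "over" result is longer because `short` gets repeated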
src/llamafactory/data/formatter.py View file @ 27a7ad86
...
@@ -16,21 +16,36 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+from typing_extensions import override

 from .data_utils import SLOTS
-from .tool_utils import DefaultToolUtils, GLM4ToolUtils
+from .tool_utils import get_tool_utils
+
+if TYPE_CHECKING:
+    from .tool_utils import FunctionCall


 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
-    tool_format: Optional[Literal["default", "glm4"]] = None
+    tool_format: Optional[str] = None

     @abstractmethod
-    def apply(self, **kwargs) -> SLOTS: ...
+    def apply(self, **kwargs) -> SLOTS:
+        r"""
+        Forms a list of slots according to the inputs to encode.
+        """
+        ...

-    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+        r"""
+        Extract a list of tuples from the response message if using tools.
+
+        Each tuple consists of function name and function arguments.
+        """
         raise NotImplementedError
...
@@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
         if has_placeholder:
             raise ValueError("Empty formatter should not contain any placeholder.")

+    @override
     def apply(self, **kwargs) -> SLOTS:
         return self.slots
...
@@ -60,6 +76,7 @@ class StringFormatter(Formatter):
         if not has_placeholder:
             raise ValueError("A placeholder is required in the string formatter.")

+    @override
     def apply(self, **kwargs) -> SLOTS:
         elements = []
         for slot in self.slots:
...
@@ -81,13 +98,9 @@ class StringFormatter(Formatter):
 @dataclass
 class FunctionFormatter(Formatter):
     def __post_init__(self):
-        if self.tool_format == "default":
-            self.slots = DefaultToolUtils.get_function_slots() + self.slots
-        elif self.tool_format == "glm4":
-            self.slots = GLM4ToolUtils.get_function_slots() + self.slots
-        else:
-            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
+        self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots

+    @override
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
         functions: List[Tuple[str, str]] = []
...
@@ -100,7 +113,7 @@ class FunctionFormatter(Formatter):
                 functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
         except json.JSONDecodeError:
-            functions = []
+            raise RuntimeError("Invalid JSON format in function message: {}".format(str([content])))  # flat string

         elements = []
         for name, arguments in functions:
...
@@ -119,22 +132,17 @@ class FunctionFormatter(Formatter):
 @dataclass
 class ToolFormatter(Formatter):
     def __post_init__(self):
-        if self.tool_format == "default":
-            self._tool_formatter = DefaultToolUtils.tool_formatter
-            self._tool_extractor = DefaultToolUtils.tool_extractor
-        elif self.tool_format == "glm4":
-            self._tool_formatter = GLM4ToolUtils.tool_formatter
-            self._tool_extractor = GLM4ToolUtils.tool_extractor
-        else:
-            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
+        self.tool_utils = get_tool_utils(self.tool_format)

+    @override
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
         try:
             tools = json.loads(content)
-            return [self._tool_formatter(tools) if len(tools) != 0 else ""]
+            return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
         except json.JSONDecodeError:
-            return [""]
+            raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content])))  # flat string

-    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
-        return self._tool_extractor(content)
+    @override
+    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+        return self.tool_utils.tool_extractor(content)
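Both formatters now resolve their tool utilities through `get_tool_utils(self.tool_format)` instead of repeating an if/elif chain per class. A minimal sketch of that registry-dispatch pattern — the stub classes below are illustrative stand-ins, not the real implementations in tool_utils.py:

# minimal registry-dispatch sketch of the get_tool_utils pattern
class DefaultToolUtilsSketch:
    @staticmethod
    def tool_formatter(tools):
        return "TOOLS: {}".format(tools)

class GLM4ToolUtilsSketch:
    @staticmethod
    def tool_formatter(tools):
        return "[gMASK] TOOLS: {}".format(tools)

TOOL_UTILS = {"default": DefaultToolUtilsSketch(), "glm4": GLM4ToolUtilsSketch()}

def get_tool_utils(name: str):
    if name not in TOOL_UTILS:
        raise ValueError("Tool utils `{}` not found.".format(name))
    return TOOL_UTILS[name]

print(get_tool_utils("glm4").tool_formatter([]))

Centralizing the lookup also explains the `Literal["default", "glm4"]` annotation relaxing to `Optional[str]`: new formats register in one place without touching the formatters.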
src/llamafactory/data/loader.py View file @ 27a7ad86
...
@@ -27,7 +27,6 @@ from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
-from .template import get_template_and_fix_tokenizer

 if TYPE_CHECKING:
...
@@ -49,6 +48,9 @@ def _load_single_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
+    r"""
+    Loads a single dataset and aligns it to the standard format.
+    """
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
...
@@ -118,7 +120,7 @@ def _load_single_dataset(
     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
-        indexes = np.random.permutation(len(dataset))[:target_num]
+        indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
         target_num -= len(indexes)
         if target_num > 0:
             expand_indexes = np.random.choice(len(dataset), target_num)
...
@@ -142,6 +144,9 @@ def _get_merged_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
+    r"""
+    Gets the merged datasets in the standard format.
+    """
     if dataset_names is None:
         return None
...
@@ -165,6 +170,9 @@ def _get_preprocessed_dataset(
     processor: Optional["ProcessorMixin"] = None,
     is_eval: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
+    r"""
+    Preprocesses the dataset, including format checking and tokenization.
+    """
     if dataset is None:
         return None
...
@@ -180,7 +188,13 @@ def _get_preprocessed_dataset(
             desc="Running tokenizer on dataset",
         )

-    dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
+    dataset = dataset.map(
+        preprocess_func,
+        batched=True,
+        batch_size=data_args.preprocessing_batch_size,
+        remove_columns=column_names,
+        **kwargs,
+    )

     if training_args.should_log:
         try:
...
@@ -196,6 +210,7 @@ def _get_preprocessed_dataset(
 def get_dataset(
+    template: "Template",
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
...
@@ -203,10 +218,9 @@ def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> "DatasetModule":
-    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
-    if data_args.train_on_prompt and template.efficient_eos:
-        raise ValueError("Current template does not support `train_on_prompt`.")
+    r"""
+    Gets the train dataset and optionally gets the evaluation dataset.
+    """
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
...
@@ -217,6 +231,7 @@ def get_dataset(
             dataset_module: Dict[str, "Dataset"] = {}
             if "train" in dataset_dict:
                 dataset_module["train_dataset"] = dataset_dict["train"]
             if "validation" in dataset_dict:
                 dataset_module["eval_dataset"] = dataset_dict["validation"]
...
@@ -270,6 +285,7 @@ def get_dataset(
     dataset_module = {}
     if "train" in dataset_dict:
         dataset_module["train_dataset"] = dataset_dict["train"]
     if "validation" in dataset_dict:
         dataset_module["eval_dataset"] = dataset_dict["validation"]
...
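The `num_samples` logic first draws a permutation (every row at most once) and only tops up with `np.random.choice` (sampling with replacement) when more samples are requested than the dataset holds, which is what the new "all samples should be included" comment records. A runnable sketch with toy sizes:

import numpy as np

# sketch of the num_samples selection in _load_single_dataset
dataset_len, target_num = 5, 8
indexes = np.random.permutation(dataset_len)[:target_num]  # covers each row once
target_num -= len(indexes)
if target_num > 0:  # asked for more rows than exist: resample the remainder
    expand_indexes = np.random.choice(dataset_len, target_num)
    indexes = np.concatenate((indexes, expand_indexes), axis=0)
print(len(indexes))  # 8 — every original row appears at least once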
src/llamafactory/data/mm_plugin.py 0 → 100644 View file @ 27a7ad86
This diff is collapsed.
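The new mm_plugin.py (627 added lines, not expanded here) is the seam the rest of this commit plugs into. Its interface can be inferred from the call sites visible in the other diffs; the sketch below is reconstructed from those call sites only and is not the real BasePlugin, which is larger and model-specific:

from typing import Any, Dict, List, Optional, Tuple

class BasePluginSketch:
    # interface inferred from collator.py, feedback.py and pairwise.py call sites

    def process_messages(self, messages, images, videos, processor) -> List[Dict[str, str]]:
        # rewrite message contents, e.g. expand image placeholders into tokens
        return messages

    def process_token_ids(self, input_ids, labels, images, videos, tokenizer, processor) -> Tuple[List[int], Optional[List[int]]]:
        # adjust already-encoded ids, e.g. prepend image tokens (paligemma-style)
        return input_ids, labels

    def get_mm_inputs(self, images, videos, imglens, vidlens, seqlens, processor) -> Dict[str, Any]:
        # turn raw media into model tensors such as pixel_values / token_type_ids
        return {}

This is the design move of the commit: every scattered `hasattr(processor, "image_seq_length")` special case below is replaced by a call into one of these three hooks.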
src/llamafactory/data/parser.py View file @ 27a7ad86
...
@@ -43,6 +43,7 @@ class DatasetAttr:
     system: Optional[str] = None
     tools: Optional[str] = None
     images: Optional[str] = None
+    videos: Optional[str] = None
     # rlhf columns
     chosen: Optional[str] = None
     rejected: Optional[str] = None
...
@@ -126,7 +127,7 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
         dataset_attr.set_attr("num_samples", dataset_info[name])

     if "columns" in dataset_info[name]:
-        column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
+        column_names = ["system", "tools", "images", "videos", "chosen", "rejected", "kto_tag"]
         if dataset_attr.formatting == "alpaca":
             column_names.extend(["prompt", "query", "response", "history"])
         else:
...
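With the new `videos` attribute, a dataset description can name a video column alongside images. An illustrative `dataset_info.json` entry rendered as a Python dict (the dataset name, file name, and column names are hypothetical):

# illustrative dataset_info.json entry using the newly supported "videos" column
dataset_info_entry = {
    "my_video_dataset": {
        "file_name": "my_video_data.json",  # hypothetical local file
        "formatting": "sharegpt",
        "columns": {
            "messages": "conversations",
            "images": "images",
            "videos": "videos",  # added by this commit
        },
    }
}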
src/llamafactory/data/preprocess.py View file @ 27a7ad86
...
@@ -50,7 +50,7 @@ def get_preprocess_and_print_func(
         print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
     elif stage == "sft" and not do_generate:
         if data_args.packing:
             if data_args.neat_packing:  # hack datasets to have int32 attention mask
                 from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                 def __init__(self, data, **kwargs):
...
@@ -67,6 +67,7 @@ def get_preprocess_and_print_func(
                 preprocess_packed_supervised_dataset,
                 template=template,
                 tokenizer=tokenizer,
+                processor=processor,
                 data_args=data_args,
             )
         else:
...
src/llamafactory/data/processors/feedback.py View file @ 27a7ad86
...
@@ -12,17 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template
...
@@ -35,14 +37,13 @@ def _encode_feedback_example(
     kl_response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
 ) -> Tuple[List[int], List[int], List[int], List[int], bool]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
     if response[0]["content"]:  # desired example
         kto_tag = True
         messages = prompt + [response[0]]
...
@@ -55,6 +56,8 @@ def _encode_feedback_example(
     else:
         kl_messages = prompt + [kl_response[1]]

+    messages = template.mm_plugin.process_messages(messages, images, videos, processor)
+    kl_messages = template.mm_plugin.process_messages(kl_messages, images, videos, processor)
     prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
     kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
...
@@ -62,10 +65,8 @@ def _encode_feedback_example(
         response_ids += [tokenizer.eos_token_id]
         kl_response_ids += [tokenizer.eos_token_id]

-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
-        kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
+    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
+    kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, images, videos, tokenizer, processor)
     source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
     prompt_ids = prompt_ids[:source_len]
...
@@ -78,7 +79,6 @@ def _encode_feedback_example(
     labels = [IGNORE_INDEX] * source_len + response_ids
     kl_input_ids = kl_prompt_ids + kl_response_ids
     kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids
     return input_ids, labels, kl_input_ids, kl_labels, kto_tag
...
@@ -88,35 +88,23 @@ def preprocess_feedback_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
-    kl_response = examples["response"][::-1]
-    model_inputs = {
-        "input_ids": [],
-        "attention_mask": [],
-        "labels": [],
-        "kl_input_ids": [],
-        "kl_attention_mask": [],
-        "kl_labels": [],
-        "kto_tags": [],
-    }
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["token_type_ids"] = []
-            model_inputs["kl_token_type_ids"] = []
-    for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+    kl_response = examples["_response"][::-1]
+    model_inputs = defaultdict(list)
+    for i in range(len(examples["_prompt"])):
+        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
+            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
             continue

         input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
-            prompt=examples["prompt"][i],
-            response=examples["response"][i],
+            prompt=examples["_prompt"][i],
+            response=examples["_response"][i],
             kl_response=kl_response[i],
-            system=examples["system"][i],
-            tools=examples["tools"][i],
+            system=examples["_system"][i],
+            tools=examples["_tools"][i],
+            images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
...
@@ -129,11 +117,8 @@ def preprocess_feedback_dataset(
         model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
         model_inputs["kl_labels"].append(kl_labels)
         model_inputs["kto_tags"].append(kto_tag)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
-                model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
+        model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])

     desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
     undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
...
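The KL term for KTO needs responses that are unrelated to their prompts, and the preprocessor gets them cheaply by reversing the response column, as the comment above notes. A toy illustration of the pairing (strings invented):

# toy illustration of the KL-pair construction in preprocess_feedback_dataset
prompts = ["p0", "p1", "p2"]
responses = ["r0", "r1", "r2"]
kl_responses = responses[::-1]  # ["r2", "r1", "r0"] — each prompt now faces a foreign response
for p, r, kl in zip(prompts, responses, kl_responses):
    print(p, r, kl)  # (matched pair, mismatched KL pair)

Note that with an odd batch length the middle element pairs with itself; over a large batch this has negligible effect on the KL estimate.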
src/llamafactory/data/processors/pairwise.py View file @ 27a7ad86
...
@@ -12,17 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
+from .processor_utils import infer_seqlen

 if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer, ProcessorMixin

     from ...hparams import DataArguments
+    from ..mm_plugin import ImageInput, VideoInput
     from ..template import Template
...
@@ -34,16 +36,15 @@ def _encode_pairwise_example(
     response: Sequence[Dict[str, str]],
     system: Optional[str],
     tools: Optional[str],
+    images: Sequence["ImageInput"],
+    videos: Sequence["VideoInput"],
     template: "Template",
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     cutoff_len: int,
 ) -> Tuple[List[int], List[int], List[int], List[int]]:
-    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
-        prompt[0]["content"] = template.image_token + prompt[0]["content"]
-    chosen_messages = prompt + [response[0]]
-    rejected_messages = prompt + [response[1]]
+    chosen_messages = template.mm_plugin.process_messages(prompt + [response[0]], images, videos, processor)
+    rejected_messages = template.mm_plugin.process_messages(prompt + [response[1]], images, videos, processor)
     prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
     _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)
...
@@ -51,10 +52,7 @@ def _encode_pairwise_example(
         chosen_ids += [tokenizer.eos_token_id]
         rejected_ids += [tokenizer.eos_token_id]

-    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
-        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
-        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
+    prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, images, videos, tokenizer, processor)
     # consider the response is more important
     source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
     prompt_ids = prompt_ids[:source_len]
...
@@ -65,7 +63,6 @@ def _encode_pairwise_example(
     chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
     rejected_input_ids = prompt_ids + rejected_ids
     rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
     return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
...
@@ -75,32 +72,21 @@ def preprocess_pairwise_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"],
     data_args: "DataArguments",
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
-    model_inputs = {
-        "chosen_input_ids": [],
-        "chosen_attention_mask": [],
-        "chosen_labels": [],
-        "rejected_input_ids": [],
-        "rejected_attention_mask": [],
-        "rejected_labels": [],
-    }
-    if processor is not None:
-        model_inputs["pixel_values"] = []
-        if hasattr(processor, "image_seq_length"):  # paligemma models
-            model_inputs["chosen_token_type_ids"] = []
-            model_inputs["rejected_token_type_ids"] = []
-    for i in range(len(examples["prompt"])):
-        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
+    model_inputs = defaultdict(list)
+    for i in range(len(examples["_prompt"])):
+        if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
+            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
             continue

         chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
-            prompt=examples["prompt"][i],
-            response=examples["response"][i],
-            system=examples["system"][i],
-            tools=examples["tools"][i],
+            prompt=examples["_prompt"][i],
+            response=examples["_response"][i],
+            system=examples["_system"][i],
+            tools=examples["_tools"][i],
+            images=examples["_images"][i] or [],
+            videos=examples["_videos"][i] or [],
             template=template,
             tokenizer=tokenizer,
             processor=processor,
...
@@ -112,15 +98,8 @@ def preprocess_pairwise_dataset(
         model_inputs["rejected_input_ids"].append(rejected_input_ids)
         model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
         model_inputs["rejected_labels"].append(rejected_labels)
-        if processor is not None:
-            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
-            if hasattr(processor, "image_seq_length"):  # paligemma models
-                model_inputs["chosen_token_type_ids"].append(
-                    get_paligemma_token_type_ids(len(chosen_input_ids), processor)
-                )
-                model_inputs["rejected_token_type_ids"].append(
-                    get_paligemma_token_type_ids(len(rejected_input_ids), processor)
-                )
+        model_inputs["images"].append(examples["_images"][i])
+        model_inputs["videos"].append(examples["_videos"][i])

     return model_inputs
...
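Note the truncation call in `_encode_pairwise_example`: the prompt budget is computed against the longer of the two responses, so chosen and rejected keep the identical prompt prefix. The helper below is a simplified, illustrative re-implementation of the proportional split (the real `infer_seqlen` in processor_utils.py additionally special-cases very short sources and targets):

# toy illustration of the truncation budgeting in _encode_pairwise_example
def infer_seqlen_sketch(source_len: int, target_len: int, cutoff_len: int):
    # simplified proportional split; illustrative only
    target_share = target_len / (source_len + target_len)
    new_target_len = min(target_len, int(cutoff_len * target_share))
    new_source_len = min(source_len, cutoff_len - new_target_len)
    return new_source_len, new_target_len

prompt_len, chosen_len, rejected_len, cutoff = 900, 300, 500, 1024
src, tgt = infer_seqlen_sketch(prompt_len, max(chosen_len, rejected_len), cutoff)
print(src, tgt)  # the same src-token prompt prefix is reused for both chosen and rejected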
src/llamafactory/data/processors/pretrain.py View file @ 27a7ad86
...
@@ -27,16 +27,16 @@ if TYPE_CHECKING:
 def preprocess_pretrain_dataset(
     examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
-) -> Dict[str, List[List[int]]]:
+) -> Dict[str, List[Any]]:
     # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
     eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
-    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]
+    text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

     if not data_args.packing:
         if data_args.template == "gemma":
             text_examples = [tokenizer.bos_token + example for example in text_examples]

-        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
+        result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
     else:
         tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
...
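In packing mode, the texts are tokenized, flattened with `chain`, and then sliced into `cutoff_len`-sized blocks; the slicing step falls past the elided lines above, so the block layout in this sketch is an assumption based on the "grouped texts" comment:

from itertools import chain

# sketch of packed pretraining preprocessing on toy token ids
tokenized = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
concatenated = {k: list(chain(*tokenized[k])) for k in tokenized}
block_size = 4  # stands in for data_args.cutoff_len
total_len = (len(concatenated["input_ids"]) // block_size) * block_size  # drop the ragged tail
blocks = [concatenated["input_ids"][i : i + block_size] for i in range(0, total_len, block_size)]
print(blocks)  # [[1, 2, 3, 4], [5, 6, 7, 8]]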
src/llamafactory/data/processors/processor_utils.py View file @ 27a7ad86
This diff is collapsed.
src/llamafactory/data/processors/supervised.py View file @ 27a7ad86
This diff is collapsed.