OpenDAS / LLaMA-Factory · Commits

Commit 2778a3d0, authored Jan 16, 2025 by luopl: "updata to v0.9.1_stable"
Parent: e92143e3
Changes: 172
Showing 20 changed files with 498 additions and 259 deletions (+498 -259)
Changed files:

    setup.py                                          +6   -5
    src/api.py                                        +3   -3
    src/llamafactory/__init__.py                      +6   -5
    src/llamafactory/api/app.py                       +6   -6
    src/llamafactory/api/chat.py                      +15  -15
    src/llamafactory/chat/base_engine.py              +4   -4
    src/llamafactory/chat/chat_model.py               +13  -13
    src/llamafactory/chat/hf_engine.py                +55  -31
    src/llamafactory/chat/vllm_engine.py              +40  -26
    src/llamafactory/cli.py                           +15  -14
    src/llamafactory/data/aligner.py                  +22  -16
    src/llamafactory/data/collator.py                 +9   -6
    src/llamafactory/data/data_utils.py               +5   -5
    src/llamafactory/data/formatter.py                +5   -5
    src/llamafactory/data/loader.py                   +36  -23
    src/llamafactory/data/mm_plugin.py                +219 -59
    src/llamafactory/data/parser.py                   +16  -8
    src/llamafactory/data/processors/feedback.py      +6   -4
    src/llamafactory/data/processors/pairwise.py      +7   -5
    src/llamafactory/data/processors/supervised.py    +10  -6
setup.py

@@ -20,7 +20,7 @@ from setuptools import find_packages, setup
 def get_version() -> str:
-    with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
+    with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
         file_content = f.read()
         pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
         (version,) = re.findall(pattern, file_content)
@@ -28,7 +28,7 @@ def get_version() -> str:
 def get_requires() -> List[str]:
-    with open("requirements.txt", "r", encoding="utf-8") as f:
+    with open("requirements.txt", encoding="utf-8") as f:
         file_content = f.read()
         lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
         return lines
@@ -54,13 +54,14 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.6.2"],
+    "vllm": ["vllm>=0.4.3,<0.6.4"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
-    "dev": ["ruff", "pytest"],
+    "openmind": ["openmind"],
+    "dev": ["pre-commit", "ruff", "pytest"],
 }
@@ -71,7 +72,7 @@ def main():
         author="hiyouga",
         author_email="hiyouga" "@" "buaa.edu.cn",
         description="Easy-to-use LLM fine-tuning framework",
-        long_description=open("README.md", "r", encoding="utf-8").read(),
+        long_description=open("README.md", encoding="utf-8").read(),
         long_description_content_type="text/markdown",
         keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
         license="Apache 2.0 License",
src/api.py

@@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel
 def main():
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
-    print("Visit http://localhost:{}/docs for API document.".format(api_port))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)
src/llamafactory/__init__.py

@@ -20,17 +20,17 @@ Level:
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.45.2
-    datasets>=2.16.0,<=2.21.0
-    accelerate>=0.30.1,<=0.34.2
+    transformers>=4.41.2,<=4.46.1
+    datasets>=2.16.0,<=3.1.0
+    accelerate>=0.34.0,<=1.0.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.45.2
+    transformers>=4.41.2,<=4.46.1
   packing:
-    transformers>=4.41.2,<=4.45.2
+    transformers>=4.41.2,<=4.46.1

 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
@@ -38,6 +38,7 @@ Force check imports: FORCE_CHECK_IMPORTS=1
 Force using torchrun: FORCE_TORCHRUN=1
 Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
 Use modelscope: USE_MODELSCOPE_HUB=1
+Use openmind: USE_OPENMIND_HUB=1
 """

 from .extras.env import VERSION
src/llamafactory/api/app.py

@@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
 def create_app(chat_model: "ChatModel") -> "FastAPI":
-    root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
+    root_path = os.getenv("FASTAPI_ROOT_PATH", "")
     app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
     app.add_middleware(
         CORSMiddleware,
@@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    api_key = os.environ.get("API_KEY", None)
+    api_key = os.getenv("API_KEY")
     security = HTTPBearer(auto_error=False)

     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         dependencies=[Depends(verify_api_key)],
     )
     async def list_models():
-        model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
+        model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
         return ModelList(data=[model_card])

     @app.post(
@@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 def run_api() -> None:
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
-    print("Visit http://localhost:{}/docs for API document.".format(api_port))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)
src/llamafactory/api/chat.py

@@ -21,7 +21,7 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

 from ..data import Role as DataRole
-from ..extras.logging import get_logger
+from ..extras import logging
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 ROLE_MAPPING = {
     Role.USER: DataRole.USER.value,
     Role.ASSISTANT: DataRole.ASSISTANT.value,
@@ -69,8 +69,8 @@ ROLE_MAPPING = {
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
-    logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
+    logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")

     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@@ -84,7 +84,7 @@ def _process_request(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

     input_messages = []
-    image = None
+    images = []
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
@@ -111,7 +111,7 @@ def _process_request(
                 else:  # web uri
                     image_stream = requests.get(image_url, stream=True).raw

-                image = Image.open(image_stream).convert("RGB")
+                images.append(Image.open(image_stream).convert("RGB"))
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
@@ -124,7 +124,7 @@ def _process_request(
     else:
         tools = None

-    return input_messages, system, tools, image
+    return input_messages, system, tools, images or None


 def _create_stream_chat_completion_chunk(
@@ -142,13 +142,13 @@ def _create_stream_chat_completion_chunk(
 async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
-    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools, image = _process_request(request)
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -169,7 +169,7 @@ async def create_chat_completion_response(
             tool_calls = []
             for tool in result:
                 function = Function(name=tool[0], arguments=tool[1])
-                tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+                tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))

             response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
             finish_reason = Finish.TOOL
@@ -193,8 +193,8 @@ async def create_chat_completion_response(
 async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
-    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools, image = _process_request(request)
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images = _process_request(request)
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")
@@ -208,7 +208,7 @@ async def create_stream_chat_completion_response(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
@@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
 async def create_score_evaluation_response(
     request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
-    score_id = "scoreval-{}".format(uuid.uuid4().hex)
+    score_id = f"scoreval-{uuid.uuid4().hex}"
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
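With _process_request now accumulating every image_url part into an images list, a single chat completion request can carry more than one image. The sketch below shows such a payload sent to the local OpenAI-style API from api.py; the URLs, the exact content-part schema, and the running server on port 8000 are illustrative assumptions, not part of this commit.

import requests

payload = {
    "model": "gpt-3.5-turbo",  # default API_MODEL_NAME from app.py
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What differs between these two images?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/a.jpg"}},  # placeholder URL
                {"type": "image_url", "image_url": {"url": "https://example.com/b.jpg"}},  # placeholder URL
            ],
        }
    ],
}

# Both image_url entries end up in the `images` list passed to chat_model.achat().
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())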
src/llamafactory/chat/base_engine.py

@@ -66,8 +66,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
@@ -81,8 +81,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
src/llamafactory/chat/chat_model.py

@@ -53,7 +53,7 @@ class ChatModel:
         elif model_args.infer_backend == "vllm":
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
-            raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
+            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
@@ -64,15 +64,15 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Gets a list of responses of the chat model.
         """
         task = asyncio.run_coroutine_threadsafe(
-            self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
         )
         return task.result()
@@ -81,28 +81,28 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Asynchronously gets a list of responses of the chat model.
         """
-        return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)

     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         r"""
         Gets the response token-by-token of the chat model.
         """
-        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
@@ -115,14 +115,14 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
         Asynchronously gets the response token-by-token of the chat model.
         """
-        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
             yield new_token

     def get_scores(
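The chat_model.py hunks above replace the scalar image/video arguments with images/videos sequences throughout chat, achat, stream_chat, and astream_chat. A minimal usage sketch of the new call shape follows; the model path, template name, and image file are placeholders for illustration and are not part of this commit.

from llamafactory.chat import ChatModel

# Hypothetical arguments; any local multimodal model with a matching template works.
chat_model = ChatModel(dict(
    model_name_or_path="path/to/your/model",  # placeholder
    template="qwen2_vl",                      # placeholder
    infer_backend="huggingface",
))

messages = [{"role": "user", "content": "<image>Describe this picture."}]

# v0.9.1 signature: a sequence of images instead of a single `image` argument.
responses = chat_model.chat(messages, images=["path/to/image.jpg"])
print(responses[0].response_text)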
src/llamafactory/chat/hf_engine.py

@@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
 from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
 from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.logging import get_logger
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class HuggingfaceEngine(BaseEngine):
@@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
         try:
             asyncio.get_event_loop()
         except RuntimeError:
-            logger.warning("There is no current event loop, creating a new one.")
+            logger.warning_once("There is no current event loop, creating a new one.")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)

-        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
+        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

     @staticmethod
     def _process_args(
@@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
-        if image is not None:
-            mm_input_dict.update({"images": [image], "imglens": [1]})
-            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+        if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
+            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

-        if video is not None:
-            mm_input_dict.update({"videos": [video], "vidlens": [1]})
-            if VIDEO_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

         messages = template.mm_plugin.process_messages(messages, mm_input_dict["images"], mm_input_dict["videos"], processor)
@@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

         if stop is not None:
-            logger.warning("Stop parameter is not supported by the huggingface engine yet.")
+            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")

         generating_args = generating_args.copy()
         generating_args.update(
@@ -164,9 +164,13 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )

-        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
+        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
         for key, value in mm_inputs.items():
-            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
+                value = torch.stack(value)  # assume they have same sizes
+            elif not isinstance(value, torch.Tensor):
+                value = torch.tensor(value)
+
             gen_kwargs[key] = value.to(model.device)

         return gen_kwargs, prompt_length
@@ -182,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, images, videos, input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
@@ -218,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model, tokenizer, processor, template, generating_args, messages, system, tools, images, videos, input_kwargs,
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
@@ -266,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:
@@ -283,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:
@@ -297,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:
@@ -314,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:
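In _process_args above, the single-image branch is generalized: when a list of images is supplied but no image placeholder appears anywhere in the conversation, one placeholder per image is prepended to the first message. A standalone sketch of that logic, assuming the usual "<image>" value of IMAGE_PLACEHOLDER from extras/constants.py:

from typing import Dict, List

IMAGE_PLACEHOLDER = "<image>"  # assumed value of llamafactory.extras.constants.IMAGE_PLACEHOLDER


def prepend_image_placeholders(messages: List[Dict[str, str]], num_images: int) -> List[Dict[str, str]]:
    """Mirror of the new hf_engine behavior: add one placeholder per image
    to the first message unless the user already placed placeholders."""
    if num_images and not any(IMAGE_PLACEHOLDER in m["content"] for m in messages):
        messages[0]["content"] = IMAGE_PLACEHOLDER * num_images + messages[0]["content"]
    return messages


# Example: two images, no explicit placeholders in the prompt.
msgs = [{"role": "user", "content": "Compare these two photos."}]
print(prepend_image_placeholders(msgs, 2)[0]["content"])
# -> "<image><image>Compare these two photos."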
src/llamafactory/chat/vllm_engine.py

@@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
 from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
 from ..extras.constants import IMAGE_PLACEHOLDER
-from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
 from ..extras.packages import is_pillow_available, is_vllm_available
 from ..model import load_config, load_tokenizer
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class VllmEngine(BaseEngine):
@@ -83,11 +83,13 @@ class VllmEngine(BaseEngine):
             "enable_lora": model_args.adapter_name_or_path is not None,
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
+        if isinstance(model_args.vllm_config, dict):
+            engine_args.update(model_args.vllm_config)
+
         if getattr(config, "is_yi_vl_derived_model", None):
             import vllm.model_executor.models.llava

-            logger.info("Detected Yi-VL model, applying projector patch.")
+            logger.info_rank0("Detected Yi-VL model, applying projector patch.")
             vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

         self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
@@ -101,21 +103,28 @@
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
-        request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-        if image is not None:
-            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
+        request_id = f"chatcmpl-{uuid.uuid4().hex}"
+        if images is not None:
+            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

-        paired_messages = messages + [{"role": "assistant", "content": ""}]
+        if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin":  # temporary solution
+            image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
+        else:
+            image_str = self.template.mm_plugin.image_token or ""
+
+        paired_messages = [
+            {"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
+            for message in messages
+        ] + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)

         use_beam_search: bool = self.generating_args["num_beams"] > 1
         temperature: Optional[float] = input_kwargs.pop("temperature", None)
         top_p: Optional[float] = input_kwargs.pop("top_p", None)
         top_k: Optional[float] = input_kwargs.pop("top_k", None)
@@ -126,6 +135,9 @@ class VllmEngine(BaseEngine):
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

+        if length_penalty is not None:
+            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
+
         if "max_new_tokens" in self.generating_args:
             max_tokens = self.generating_args["max_new_tokens"]
         elif "max_length" in self.generating_args:
@@ -149,27 +161,29 @@ class VllmEngine(BaseEngine):
             temperature=temperature if temperature is not None else self.generating_args["temperature"],
             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
             top_k=top_k if top_k is not None else self.generating_args["top_k"],
             use_beam_search=use_beam_search,
             length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
             skip_special_tokens=True,
         )

-        if image is not None:  # add image features
-            if not isinstance(image, (str, ImageObject)):
-                raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
+        if images is not None:  # add image features
+            image_data = []
+            for image in images:
+                if not isinstance(image, (str, ImageObject)):
+                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")

-            if isinstance(image, str):
-                image = Image.open(image).convert("RGB")
+                if isinstance(image, str):
+                    image = Image.open(image).convert("RGB")
+
+                image_data.append(image)

-            multi_modal_data = {"image": image}
+            multi_modal_data = {"image": image_data}
         else:
             multi_modal_data = None

         result_generator = self.model.generate(
-            inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
             sampling_params=sampling_params,
             request_id=request_id,
             lora_request=self.lora_request,
@@ -182,12 +196,12 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
-        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
         async for request_output in generator:
             final_output = request_output
@@ -210,12 +224,12 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
-        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
         async for result in generator:
             delta_text = result.outputs[0].text[len(generated_text):]
             generated_text = result.outputs[0].text
src/llamafactory/cli.py

@@ -22,8 +22,8 @@ from . import launcher
 from .api.app import run_api
 from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
+from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.logging import get_logger
 from .extras.misc import get_device_count
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
@@ -47,7 +47,7 @@ USAGE = (
 WELCOME = (
     "-" * 58
     + "\n"
-    + "| Welcome to LLaMA Factory, version {}".format(VERSION)
+    + f"| Welcome to LLaMA Factory, version {VERSION}"
     + " " * (21 - len(VERSION))
     + "|\n|"
     + " " * 56
@@ -56,7 +56,7 @@ WELCOME = (
     + "-" * 58
 )

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @unique
@@ -86,25 +86,26 @@ def main():
     elif command == Command.EXPORT:
         export_model()
     elif command == Command.TRAIN:
-        force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
+        force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
         if force_torchrun or get_device_count() > 1:
-            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
-            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
-            logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
+            master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
+            master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
+            logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
             process = subprocess.run(
                 (
                     "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                     "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
-                ).format(
-                    nnodes=os.environ.get("NNODES", "1"),
-                    node_rank=os.environ.get("RANK", "0"),
-                    nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
+                )
+                .format(
+                    nnodes=os.getenv("NNODES", "1"),
+                    node_rank=os.getenv("NODE_RANK", "0"),
+                    nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
                     master_addr=master_addr,
                     master_port=master_port,
                     file_name=launcher.__file__,
                     args=" ".join(sys.argv[1:]),
-                ),
-                shell=True,
+                )
+                .split()
             )
             sys.exit(process.returncode)
         else:
@@ -118,4 +119,4 @@ def main():
     elif command == Command.HELP:
         print(USAGE)
     else:
-        raise NotImplementedError("Unknown command: {}.".format(command))
+        raise NotImplementedError(f"Unknown command: {command}.")
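The cli.py hunk drops shell=True and instead splits the formatted torchrun command into an argument list for subprocess.run, while switching the node rank source from RANK to NODE_RANK. A small standalone sketch of the resulting command construction, using placeholder values for the fields that cli.py derives from the environment and runtime state:

import os

# Placeholder values standing in for env vars and runtime state used in cli.py.
master_addr, master_port = "127.0.0.1", "23456"
launcher_file, train_args = "launcher.py", "examples/train_lora/llama3_lora_sft.yaml"

command = (
    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
    nnodes=os.getenv("NNODES", "1"),
    node_rank=os.getenv("NODE_RANK", "0"),
    nproc_per_node=os.getenv("NPROC_PER_NODE", "1"),
    master_addr=master_addr,
    master_port=master_port,
    file_name=launcher_file,
    args=train_args,
)

# Splitting the string yields an argv list, so no shell is needed to run it.
print(command.split())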
src/llamafactory/data/aligner.py

@@ -16,7 +16,7 @@ import os
 from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

-from ..extras.logging import get_logger
+from ..extras import logging
 from .data_utils import Role
@@ -29,45 +29,51 @@ if TYPE_CHECKING:
     from .parser import DatasetAttr


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _convert_images(
-    images: Sequence["ImageInput"],
+    images: Union["ImageInput", Sequence["ImageInput"]],
     dataset_attr: "DatasetAttr",
     data_args: "DataArguments",
 ) -> Optional[List["ImageInput"]]:
     r"""
     Optionally concatenates image path to dataset dir when loading from local disk.
     """
-    if len(images) == 0:
+    if not isinstance(images, list):
+        images = [images]
+    elif len(images) == 0:
         return None
-
-    images = images[:]
+    else:
+        images = images[:]

     if dataset_attr.load_from in ["script", "file"]:
         for i in range(len(images)):
-            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
-                images[i] = os.path.join(data_args.dataset_dir, images[i])
+            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
+                images[i] = os.path.join(data_args.image_dir, images[i])

     return images


 def _convert_videos(
-    videos: Sequence["VideoInput"],
+    videos: Union["VideoInput", Sequence["VideoInput"]],
     dataset_attr: "DatasetAttr",
     data_args: "DataArguments",
 ) -> Optional[List["VideoInput"]]:
     r"""
     Optionally concatenates video path to dataset dir when loading from local disk.
     """
-    if len(videos) == 0:
+    if not isinstance(videos, list):
+        videos = [videos]
+    elif len(videos) == 0:
         return None
-
-    videos = videos[:]
+    else:
+        videos = videos[:]

     if dataset_attr.load_from in ["script", "file"]:
         for i in range(len(videos)):
-            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
-                videos[i] = os.path.join(data_args.dataset_dir, videos[i])
+            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
+                videos[i] = os.path.join(data_args.image_dir, videos[i])

     return videos
@@ -161,7 +167,7 @@ def convert_sharegpt(
     broken_data = False
     for turn_idx, message in enumerate(messages):
         if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
-            logger.warning("Invalid role tag in {}.".format(messages))
+            logger.warning_rank0(f"Invalid role tag in {messages}.")
             broken_data = True

         aligned_messages.append(
@@ -171,7 +177,7 @@ def convert_sharegpt(
     if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
         dataset_attr.ranking and len(aligned_messages) % 2 == 0
     ):
-        logger.warning("Invalid message count in {}.".format(messages))
+        logger.warning_rank0(f"Invalid message count in {messages}.")
         broken_data = True

     if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
@@ -192,7 +198,7 @@ def convert_sharegpt(
             chosen[dataset_attr.role_tag] not in accept_tags[-1]
             or rejected[dataset_attr.role_tag] not in accept_tags[-1]
         ):
-            logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
+            logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
             broken_data = True

         prompt = aligned_messages
@@ -205,7 +211,7 @@ def convert_sharegpt(
         response = aligned_messages[-1:]

     if broken_data:
-        logger.warning("Skipping this abnormal example.")
+        logger.warning_rank0("Skipping this abnormal example.")
         prompt, response = [], []

     convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
src/llamafactory/data/collator.py

@@ -79,7 +79,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     processor: Optional["ProcessorMixin"] = None

     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
+        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
         for feature in features:
             images = feature.pop("images", None) or []
             videos = feature.pop("videos", None) or []
@@ -87,10 +87,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_videos.extend(videos)
             batch_imglens.append(len(images))
             batch_vidlens.append(len(videos))
-            batch_seqlens.append(len(feature["input_ids"]))
+            batch_input_ids.append(feature["input_ids"])

         mm_inputs = self.template.mm_plugin.get_mm_inputs(
-            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
+            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
         )
         if "token_type_ids" in mm_inputs:
             token_type_ids = mm_inputs.pop("token_type_ids")
@@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
         features.update(mm_inputs)
+        if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
+            features = features.data  # use default_collate() instead of BatchEncoding.to()
+
         return features
@@ -137,9 +140,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
         for key in ("chosen", "rejected"):
             for feature in features:
                 target_feature = {
-                    "input_ids": feature["{}_input_ids".format(key)],
-                    "attention_mask": feature["{}_attention_mask".format(key)],
-                    "labels": feature["{}_labels".format(key)],
+                    "input_ids": feature[f"{key}_input_ids"],
+                    "attention_mask": feature[f"{key}_attention_mask"],
+                    "labels": feature[f"{key}_labels"],
                     "images": feature["images"],
                     "videos": feature["videos"],
                 }
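The collator now forwards the raw input_ids of every sample (batch_ids) to get_mm_inputs instead of just their lengths, matching the mm_plugin signature change later in this commit. A toy sketch of the new call contract, with a stub plugin standing in for the real template.mm_plugin (the stub and its return value are illustrative only):

from typing import Dict, List, Sequence


class StubPlugin:
    """Stand-in for template.mm_plugin; only demonstrates the v0.9.1 argument shape."""

    def get_mm_inputs(
        self,
        images: Sequence,
        videos: Sequence,
        imglens: Sequence[int],
        vidlens: Sequence[int],
        batch_ids: Sequence[List[int]],  # was `seqlens: Sequence[int]` before this commit
        processor=None,
    ) -> Dict[str, list]:
        # A real plugin builds pixel values and friends; here we just echo the per-sample lengths.
        return {"seq_lengths": [len(ids) for ids in batch_ids]}


features = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
batch_input_ids = [f["input_ids"] for f in features]  # collected per sample, as in the collator
print(StubPlugin().get_mm_inputs([], [], [0, 0], [0, 0], batch_input_ids))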
src/llamafactory/data/data_utils.py

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets

-from ..extras.logging import get_logger
+from ..extras import logging


 if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@@ -56,12 +56,12 @@ def merge_dataset(
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
-            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+            logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")

         return concatenate_datasets(all_datasets)
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
-            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+            logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")

         return interleave_datasets(
             datasets=all_datasets,
@@ -70,7 +70,7 @@ def merge_dataset(
             stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
         )
     else:
-        raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
+        raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")


 def split_dataset(
src/llamafactory/data/formatter.py

@@ -83,14 +83,14 @@ class StringFormatter(Formatter):
             if isinstance(slot, str):
                 for name, value in kwargs.items():
                     if not isinstance(value, str):
-                        raise RuntimeError("Expected a string, got {}".format(value))
+                        raise RuntimeError(f"Expected a string, got {value}")

                     slot = slot.replace("{{" + name + "}}", value, 1)
                 elements.append(slot)
             elif isinstance(slot, (dict, set)):
                 elements.append(slot)
             else:
-                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")

         return elements
@@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
                 functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))

         except json.JSONDecodeError:
-            raise RuntimeError("Invalid JSON format in function message: {}".format(str([content])))  # flat string
+            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}")  # flat string

         elements = []
         for name, arguments in functions:
@@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
             elif isinstance(slot, (dict, set)):
                 elements.append(slot)
             else:
-                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")

         return elements
@@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
             tools = json.loads(content)
             return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
         except json.JSONDecodeError:
-            raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content])))  # flat string
+            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}")  # flat string

     @override
     def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
src/llamafactory/data/loader.py

@@ -20,8 +20,8 @@ import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
 from transformers.utils.versions import require_version

+from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
-from ..extras.logging import get_logger
 from ..extras.misc import has_tokenized_data
 from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from .template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _load_single_dataset(
@@ -51,9 +51,9 @@ def _load_single_dataset(
     r"""
     Loads a single dataset and aligns it to the standard format.
     """
-    logger.info("Loading dataset {}...".format(dataset_attr))
+    logger.info_rank0(f"Loading dataset {dataset_attr}...")
     data_path, data_name, data_dir, data_files = None, None, None, None
-    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
+    if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
         data_path = dataset_attr.dataset_name
         data_name = dataset_attr.subset
         data_dir = dataset_attr.folder
@@ -69,25 +69,24 @@ def _load_single_dataset(
         if os.path.isdir(local_path):  # is directory
             for file_name in os.listdir(local_path):
                 data_files.append(os.path.join(local_path, file_name))
-                if data_path is None:
-                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
-                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
-                    raise ValueError("File types should be identical.")
         elif os.path.isfile(local_path):  # is file
             data_files.append(local_path)
-            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
         else:
-            raise ValueError("File {} not found.".format(local_path))
+            raise ValueError(f"File {local_path} not found.")

+        data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
         if data_path is None:
             raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
+
+        if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
+            raise ValueError("File types should be identical.")
     else:
-        raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
+        raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

     if dataset_attr.load_from == "ms_hub":
         require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
-        from modelscope import MsDataset
-        from modelscope.utils.config_ds import MS_DATASETS_CACHE
+        from modelscope import MsDataset  # type: ignore
+        from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore

         cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
         dataset = MsDataset.load(
@@ -98,10 +97,27 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
-            use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            use_streaming=data_args.streaming,
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()

+    elif dataset_attr.load_from == "om_hub":
+        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        from openmind import OmDataset  # type: ignore
+        from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore
+
+        cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
+        dataset = OmDataset.load_dataset(
+            path=data_path,
+            name=data_name,
+            data_dir=data_dir,
+            data_files=data_files,
+            split=dataset_attr.split,
+            cache_dir=cache_dir,
+            token=model_args.om_hub_token,
+            streaming=data_args.streaming,
+        )
     else:
         dataset = load_dataset(
             path=data_path,
@@ -111,13 +127,10 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            streaming=data_args.streaming,
             trust_remote_code=True,
         )

-    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
-        dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
-
     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
         indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
@@ -128,7 +141,7 @@ def _load_single_dataset(
         assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
         dataset = dataset.select(indexes)
-        logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
+        logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")

     if data_args.max_samples is not None:  # truncate dataset
         max_samples = min(data_args.max_samples, len(dataset))
@@ -224,9 +237,9 @@ def get_dataset(
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
             dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
+            logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")

             dataset_module: Dict[str, "Dataset"] = {}
             if "train" in dataset_dict:
@@ -277,8 +290,8 @@ def get_dataset(
         if data_args.tokenized_path is not None:
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
-                logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
-                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
+                logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+                logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")

             sys.exit(0)
src/llamafactory/data/mm_plugin.py
View file @
2778a3d0
...
...
@@ -4,11 +4,12 @@ from io import BytesIO
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
TypedDict
,
Union
import
numpy
as
np
import
torch
from
transformers.image_utils
import
get_image_size
,
to_numpy_array
from
typing_extensions
import
override
from
..extras.constants
import
IGNORE_INDEX
,
IMAGE_PLACEHOLDER
,
VIDEO_PLACEHOLDER
from
..extras.packages
import
is_pillow_available
,
is_pyav_available
from
..extras.packages
import
is_pillow_available
,
is_pyav_available
,
is_transformers_version_greater_than
if
is_pillow_available
():
...
...
@@ -20,8 +21,14 @@ if is_pyav_available():
import
av
if
is_transformers_version_greater_than
(
"4.45.0"
):
from
transformers.models.mllama.processing_mllama
import
(
convert_sparse_cross_attention_mask_to_dense
,
get_cross_attention_token_mask
,
)
if
TYPE_CHECKING
:
import
torch
from
av.stream
import
Stream
from
transformers
import
PreTrainedTokenizer
,
ProcessorMixin
from
transformers.image_processing_utils
import
BaseImageProcessor
...
...
@@ -30,7 +37,7 @@ if TYPE_CHECKING:
path
:
Optional
[
str
]
bytes
:
Optional
[
bytes
]
ImageInput
=
Union
[
str
,
EncodedImage
,
ImageObject
]
ImageInput
=
Union
[
str
,
bytes
,
EncodedImage
,
ImageObject
]
VideoInput
=
str
...
...
@@ -75,8 +82,8 @@ class BasePlugin:
Pre-processes a single image.
"""
image_resolution
:
int
=
kwargs
.
get
(
"image_resolution"
)
if
max
(
image
.
width
,
image
.
height
)
>
image_resolution
:
resize_factor
=
image_resolution
/
max
(
image
.
width
,
image
.
height
)
if
(
image
.
width
*
image
.
height
)
>
image_resolution
:
resize_factor
=
math
.
sqrt
(
image_resolution
/
(
image
.
width
*
image
.
height
)
)
width
,
height
=
int
(
image
.
width
*
resize_factor
),
int
(
image
.
height
*
resize_factor
)
image
=
image
.
resize
((
width
,
height
),
resample
=
Image
.
NEAREST
)
...
...
@@ -104,6 +111,8 @@ class BasePlugin:
for
image
in
images
:
if
isinstance
(
image
,
str
):
image
=
Image
.
open
(
image
)
elif
isinstance
(
image
,
bytes
):
image
=
Image
.
open
(
BytesIO
(
image
))
elif
isinstance
(
image
,
dict
):
if
image
[
"bytes"
]
is
not
None
:
image
=
Image
.
open
(
BytesIO
(
image
[
"bytes"
]))
...
...
@@ -111,7 +120,7 @@ class BasePlugin:
image
=
Image
.
open
(
image
[
"path"
])
if
not
isinstance
(
image
,
ImageObject
):
raise
ValueError
(
"Expect input is a list of Images, but got {
}."
.
format
(
type
(
image
)
)
)
raise
ValueError
(
f
"Expect input is a list of Images, but got
{
type
(
image
)
}
."
)
results
.
append
(
self
.
_preprocess_image
(
image
,
**
kwargs
))
...
...
@@ -163,15 +172,15 @@ class BasePlugin:
if
len
(
images
)
!=
0
:
images
=
self
.
_regularize_images
(
images
,
image_resolution
=
getattr
(
processor
,
"image_resolution"
,
512
),
image_resolution
=
getattr
(
processor
,
"image_resolution"
,
512
*
512
),
)
input_dict
[
"images"
]
=
images
if
len
(
videos
)
!=
0
:
videos
=
self
.
_regularize_videos
(
videos
,
image_resolution
=
getattr
(
processor
,
"video_resolution"
,
128
),
video_fps
=
getattr
(
processor
,
"video_fps"
,
1
.0
),
image_resolution
=
getattr
(
processor
,
"video_resolution"
,
128
*
128
),
video_fps
=
getattr
(
processor
,
"video_fps"
,
2
.0
),
video_maxlen
=
getattr
(
processor
,
"video_maxlen"
,
64
),
)
input_dict
[
"videos"
]
=
videos
...
...
@@ -221,11 +230,19 @@ class BasePlugin:
videos
:
Sequence
[
"VideoInput"
],
imglens
:
Sequence
[
int
],
vidlens
:
Sequence
[
int
],
seqlen
s
:
Sequence
[
int
],
batch_id
s
:
Sequence
[
List
[
int
]
]
,
processor
:
Optional
[
"ProcessorMixin"
],
)
->
Dict
[
str
,
Union
[
List
[
int
],
"torch.Tensor"
]]:
r
"""
Builds batched multimodal inputs for VLMs.
Arguments:
images: a list of image inputs, shape (num_images,)
videos: a list of video inputs, shape (num_videos,)
imglens: number of images in each sample, shape (batch_size,)
vidlens: number of videos in each sample, shape (batch_size,)
batch_ids: input ids of samples, shape (batch_size, seq_len)
processor: a processor for pre-processing images and videos
"""
self
.
_validate_input
(
images
,
videos
)
return
{}
...
...
@@ -248,12 +265,12 @@ class LlavaPlugin(BasePlugin):
content
=
message
[
"content"
]
while
IMAGE_PLACEHOLDER
in
content
:
num_image_tokens
+=
1
content
=
content
.
replace
(
IMAGE_PLACEHOLDER
,
"{{image}}"
,
1
)
content
=
content
.
replace
(
IMAGE_PLACEHOLDER
,
"{{image}}"
*
image_seqlen
,
1
)
message
[
"content"
]
=
content
.
replace
(
"{{image}}"
,
self
.
image_token
*
image_seqlen
)
message
[
"content"
]
=
content
.
replace
(
"{{image}}"
,
self
.
image_token
)
if
len
(
images
)
!=
num_image_tokens
:
raise
ValueError
(
"The number of images does not match the number of {
} tokens"
.
format
(
IMAGE_PLACEHOLDER
)
)
raise
ValueError
(
f
"The number of images does not match the number of
{
IMAGE_PLACEHOLDER
}
tokens."
)
return
messages
...
...
@@ -264,7 +281,7 @@ class LlavaPlugin(BasePlugin):
videos
:
Sequence
[
"VideoInput"
],
imglens
:
Sequence
[
int
],
vidlens
:
Sequence
[
int
],
seqlen
s
:
Sequence
[
int
],
batch_id
s
:
Sequence
[
List
[
int
]
]
,
processor
:
Optional
[
"ProcessorMixin"
],
)
->
Dict
[
str
,
Union
[
List
[
int
],
"torch.Tensor"
]]:
self
.
_validate_input
(
images
,
videos
)
...
...
@@ -286,23 +303,27 @@ class LlavaNextPlugin(BasePlugin):
mm_inputs
=
self
.
_get_mm_inputs
(
images
,
videos
,
processor
)
if
"image_sizes"
in
mm_inputs
:
image_sizes
=
iter
(
mm_inputs
[
"image_sizes"
])
if
"pixel_values"
in
mm_inputs
:
height
,
width
=
get_image_size
(
to_numpy_array
(
mm_inputs
[
"pixel_values"
][
0
][
0
]))
for
message
in
messages
:
content
=
message
[
"content"
]
while
self
.
image_token
in
content
:
while
IMAGE_PLACEHOLDER
in
content
:
image_size
=
next
(
image_sizes
)
orig_height
,
orig_width
=
image_size
image_seqlen
=
processor
.
_get_number_of_features
(
orig_height
,
orig_width
,
height
,
width
)
if
processor
.
vision_feature_select_strategy
==
"default"
:
if
getattr
(
processor
,
"
vision_feature_select_strategy
"
)
==
"default"
:
image_seqlen
-=
1
num_image_tokens
+=
1
content
=
content
.
replace
(
self
.
image_token
,
"{{image}}"
*
image_seqlen
,
1
)
content
=
content
.
replace
(
IMAGE_PLACEHOLDER
,
"{{image}}"
*
image_seqlen
,
1
)
message
[
"content"
]
=
content
.
replace
(
"{{image}}"
,
self
.
image_token
)
if
len
(
images
)
!=
num_image_tokens
:
raise
ValueError
(
"The number of images does not match the number of {} tokens"
.
format
(
IMAGE_PLACEHOLDER
))
raise
ValueError
(
f
"The number of images does not match the number of
{
IMAGE_PLACEHOLDER
}
tokens."
)
return
messages
@
override
...
...
@@ -312,12 +333,11 @@ class LlavaNextPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
-        res = self._get_mm_inputs(images, videos, processor)
-        return res
+        return self._get_mm_inputs(images, videos, processor)


 class LlavaNextVideoPlugin(BasePlugin):
...
@@ -330,8 +350,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         if "pixel_values" in mm_inputs:
...
@@ -339,15 +358,15 @@ class LlavaNextVideoPlugin(BasePlugin):
             height, width = get_image_size(to_numpy_array(mm_inputs["pixel_values"][0][0]))
             for message in messages:
                 content = message["content"]
-                while self.image_token in content:
+                while IMAGE_PLACEHOLDER in content:
                     image_size = next(image_sizes)
                     orig_height, orig_width = image_size
                     image_seqlen = processor._get_number_of_features(orig_height, orig_width, height, width)
-                    if processor.vision_feature_select_strategy == "default":
+                    if getattr(processor, "vision_feature_select_strategy") == "default":
                         image_seqlen -= 1
                     num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}" * image_seqlen, 1)
+                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

                 message["content"] = content.replace("{{image}}", self.image_token)
...
@@ -357,19 +376,19 @@ class LlavaNextVideoPlugin(BasePlugin):
             num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size)
             video_seqlen = image_seqlen // 4 * num_frames  # divide by 4 needed for avg pooling layer
             for message in messages:
                 content = message["content"]
-                while self.video_token in content:
+                while VIDEO_PLACEHOLDER in content:
                     num_video_tokens += 1
-                    content = content.replace(self.video_token, "{{video}}", 1)
+                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)

-                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
+                message["content"] = content.replace("{{video}}", self.video_token)

         if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         if len(videos) != num_video_tokens:
-            raise ValueError("The number of videos does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

         return messages
...
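As a worked example of the sequence-length arithmetic above (all numbers are assumed for illustration, not taken from this commit):

# Assumed example values:
height = width = 336
patch_size = 14
num_frames = 8
image_seqlen = (height // patch_size) * (width // patch_size)  # 24 * 24 = 576
video_seqlen = image_seqlen // 4 * num_frames                   # 144 * 8 = 1152 (avg pooling divides by 4)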
@@ -380,7 +399,7 @@ class LlavaNextVideoPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
...
@@ -408,7 +427,7 @@ class PaliGemmaPlugin(BasePlugin):
             message["content"] = content.replace("{{image}}", "")

         if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         return messages
...
@@ -439,15 +458,78 @@ class PaliGemmaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
+        seqlens = [len(input_ids) for input_ids in batch_ids]
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         mm_inputs["token_type_ids"] = _get_paligemma_token_type_ids(imglens, seqlens, processor)
         return mm_inputs


+class PixtralPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        patch_size = getattr(processor, "patch_size")
+        image_token = getattr(processor, "image_token")
+        image_break_token = getattr(processor, "image_break_token")
+        image_end_token = getattr(processor, "image_end_token")
+
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        image_input_sizes = mm_inputs.get("image_sizes", None)
+        for message in messages:
+            content = message["content"]
+            while IMAGE_PLACEHOLDER in content:
+                if image_input_sizes is None:
+                    raise ValueError("Cannot get image input sizes.")
+
+                image_size = image_input_sizes[0][num_image_tokens]
+                height, width = image_size
+                num_height_tokens = height // patch_size
+                num_width_tokens = width // patch_size
+                replace_tokens = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
+                replace_tokens = [item for sublist in replace_tokens for item in sublist]  # flatten list
+                replace_tokens[-1] = image_end_token
+                replace_str = "".join(replace_tokens)
+                content = content.replace(IMAGE_PLACEHOLDER, replace_str, 1)
+                num_image_tokens += 1
+
+            message["content"] = content
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        if mm_inputs.get("pixel_values"):
+            mm_inputs["pixel_values"] = mm_inputs["pixel_values"][0]
+
+        mm_inputs.pop("image_sizes", None)
+        return mm_inputs
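A small self-contained sketch of the row-wise token layout built in PixtralPlugin.process_messages above; the token strings and grid size are made up for illustration.

# Illustrative only: a 2x3 patch grid yields three image tokens per row,
# an image-break token after each row, and an image-end token at the very end.
image_token, image_break_token, image_end_token = "[IMG]", "[IMG_BREAK]", "[IMG_END]"
num_height_tokens, num_width_tokens = 2, 3

rows = [[image_token] * num_width_tokens + [image_break_token]] * num_height_tokens
tokens = [tok for row in rows for tok in row]  # flatten list
tokens[-1] = image_end_token
print("".join(tokens))
# [IMG][IMG][IMG][IMG_BREAK][IMG][IMG][IMG][IMG_END]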

class Qwen2vlPlugin(BasePlugin):
    @override
    def _preprocess_image(self, image: "ImageObject", **kwargs) -> "ImageObject":
...
@@ -493,7 +575,7 @@ class Qwen2vlPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if num_image_tokens >= len(image_grid_thw):
-                    raise ValueError("`len(images)` is less than the number of {} tokens.".format(IMAGE_PLACEHOLDER))
+                    raise ValueError(f"`len(images)` is less than the number of {IMAGE_PLACEHOLDER} tokens.")

                 content = content.replace(
                     IMAGE_PLACEHOLDER,
...
@@ -506,7 +588,7 @@ class Qwen2vlPlugin(BasePlugin):
             while VIDEO_PLACEHOLDER in content:
                 if num_video_tokens >= len(video_grid_thw):
-                    raise ValueError("`len(videos)` is less than the number of {} tokens.".format(VIDEO_PLACEHOLDER))
+                    raise ValueError(f"`len(videos)` is less than the number of {VIDEO_PLACEHOLDER} tokens.")

                 content = content.replace(
                     VIDEO_PLACEHOLDER,
...
@@ -520,10 +602,10 @@ class Qwen2vlPlugin(BasePlugin):
             message["content"] = content

         if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         if len(videos) != num_video_tokens:
-            raise ValueError("The number of videos does not match the number of {} tokens".format(VIDEO_PLACEHOLDER))
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

         return messages
...
@@ -534,7 +616,7 @@ class Qwen2vlPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
...
@@ -551,42 +633,45 @@ class VideoLlavaPlugin(BasePlugin):
         processor: Optional["ProcessorMixin"],
     ) -> List[Dict[str, str]]:
         self._validate_input(images, videos)
-        num_image_tokens = 0
-        num_video_tokens = 0
+        num_image_tokens, num_video_tokens = 0, 0
         messages = deepcopy(messages)
         mm_inputs = self._get_mm_inputs(images, videos, processor)
         num_frames = 0
-        exist_images = "pixel_values_images" in mm_inputs
-        exist_videos = "pixel_values_videos" in mm_inputs
-        if exist_videos or exist_images:
-            if exist_images:
+        has_images = "pixel_values_images" in mm_inputs
+        has_videos = "pixel_values_videos" in mm_inputs
+        if has_images or has_videos:
+            if has_images:
                 height, width = get_image_size(to_numpy_array(mm_inputs.get("pixel_values_images")[0]))
                 num_frames = 1
-            if exist_videos:
+            if has_videos:
                 pixel_values_video = to_numpy_array(mm_inputs.get("pixel_values_videos")[0])
                 height, width = get_image_size(pixel_values_video[0])
                 num_frames = pixel_values_video.shape[0]  # frame dim is always after batch dim
             image_seqlen = (height // processor.patch_size) * (width // processor.patch_size) + 1
             video_seqlen = image_seqlen * num_frames
-            if processor.vision_feature_select_strategy == "default":
+            if getattr(processor, "vision_feature_select_strategy") == "default":
                 image_seqlen -= 1
             for message in messages:
                 content = message["content"]
-                while self.image_token in content:
+                while IMAGE_PLACEHOLDER in content:
                     num_image_tokens += 1
-                    content = content.replace(self.image_token, "{{image}}", 1)
+                    content = content.replace(IMAGE_PLACEHOLDER, "{{image}}" * image_seqlen, 1)

-                while self.video_token in content:
+                while VIDEO_PLACEHOLDER in content:
                     num_video_tokens += 1
-                    content = content.replace(self.video_token, "{{video}}", 1)
+                    content = content.replace(VIDEO_PLACEHOLDER, "{{video}}" * video_seqlen, 1)

-                content = content.replace("{{image}}", self.image_token * image_seqlen)
-                message["content"] = content.replace("{{video}}", self.video_token * video_seqlen)
+                content = content.replace("{{image}}", self.image_token)
+                message["content"] = content.replace("{{video}}", self.video_token)

         if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(self.image_token))
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")

         if len(videos) != num_video_tokens:
-            raise ValueError("The number of videos does not match the number of {} tokens".format(self.video_token))
+            raise ValueError(f"The number of videos does not match the number of {VIDEO_PLACEHOLDER} tokens.")

         return messages
...
@@ -597,21 +682,96 @@ class VideoLlavaPlugin(BasePlugin):
         videos: Sequence["VideoInput"],
         imglens: Sequence[int],
         vidlens: Sequence[int],
-        seqlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
         processor: Optional["ProcessorMixin"],
     ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
         self._validate_input(images, videos)
         return self._get_mm_inputs(images, videos, processor)

+class MllamaPlugin(BasePlugin):
+    @override
+    def process_messages(
+        self,
+        messages: Sequence[Dict[str, str]],
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: Optional["ProcessorMixin"],
+    ) -> List[Dict[str, str]]:
+        self._validate_input(images, videos)
+        num_image_tokens = 0
+        messages = deepcopy(messages)
+        for message in messages:
+            content = message["content"]
+            num_image_tokens += content.count(IMAGE_PLACEHOLDER)
+            message["content"] = content.replace(IMAGE_PLACEHOLDER, self.image_token)
+
+        if len(images) != num_image_tokens:
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens.")
+
+        return messages
+
+    @override
+    def _get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        processor: "ProcessorMixin",
+    ) -> Dict[str, "torch.Tensor"]:
+        r"""
+        Processes visual inputs for mllama because its image processor only accepts List[List[ImageInput]].
+
+        Returns:
+            pixel_values: tensor with shape
+                          (batch_size, max_num_images, max_image_tiles, channels, tile_height, tile_width)
+                          For example, (2, 1, 4, 3, 560, 560).
+            aspect_ratio_ids: tensor with shape (batch_size, max_num_images). For example, (2, 1).
+            aspect_ratio_mask: tensor with shape (batch_size, max_num_images, max_image_tiles). For example, (2, 1, 4).
+            num_tiles: List[List[int]] with shape (batch_size, num_images_in_batch). For example, (2, 1).
+        """
+        image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
+        images = self._regularize_images(images, image_resolution=getattr(processor, "image_resolution", 512 * 512))
+        return image_processor([[image] for image in images], return_tensors="pt")
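Since the mllama image processor expects one list of images per sample, the wrapping above amounts to something like the following sketch (the PIL images and sizes are illustrative assumptions):

# Illustrative only: wrap each image in its own inner list before calling the image processor.
from PIL import Image

images = [Image.new("RGB", (448, 448)), Image.new("RGB", (224, 672))]
batched = [[image] for image in images]  # List[List[ImageInput]], one image per sample
# image_processor(batched, return_tensors="pt") would then return pixel_values shaped roughly
# (batch_size=2, max_num_images=1, max_image_tiles, 3, tile_height, tile_width), as the docstring above describes.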
+    def get_mm_inputs(
+        self,
+        images: Sequence["ImageInput"],
+        videos: Sequence["VideoInput"],
+        imglens: Sequence[int],
+        vidlens: Sequence[int],
+        batch_ids: Sequence[List[int]],
+        processor: Optional["ProcessorMixin"],
+    ) -> Dict[str, Union[List[int], "torch.Tensor"]]:
+        self._validate_input(images, videos)
+        if len(images) != len(batch_ids):
+            raise ValueError("Mllama only supports one image per sample.")
+
+        mm_inputs = self._get_mm_inputs(images, videos, processor)
+        num_tiles = mm_inputs.pop("num_tiles")
+        image_token_id = getattr(processor, "image_token_id")
+        max_image_tiles = getattr(processor.image_processor, "max_image_tiles")
+        cross_attention_token_mask = [
+            get_cross_attention_token_mask(input_ids, image_token_id) for input_ids in batch_ids
+        ]
+        mm_inputs["cross_attention_mask"] = convert_sparse_cross_attention_mask_to_dense(
+            cross_attention_token_mask,
+            num_tiles=num_tiles,
+            max_num_tiles=max_image_tiles,
+            length=max(len(input_ids) for input_ids in batch_ids),
+        )
+        return mm_inputs

PLUGINS = {
    "base": BasePlugin,
    "llava": LlavaPlugin,
    "llava_next": LlavaNextPlugin,
    "llava_next_video": LlavaNextVideoPlugin,
    "paligemma": PaliGemmaPlugin,
    "pixtral": PixtralPlugin,
    "qwen2_vl": Qwen2vlPlugin,
    "video_llava": VideoLlavaPlugin,
    "mllama": MllamaPlugin,
}
...
@@ -622,6 +782,6 @@ def get_mm_plugin(
 ) -> "BasePlugin":
     plugin_class = PLUGINS.get(name, None)
     if plugin_class is None:
-        raise ValueError("Multimodal plugin `{}` not found.".format(name))
+        raise ValueError(f"Multimodal plugin `{name}` not found.")

     return plugin_class(image_token, video_token)
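As a quick usage sketch of the registry lookup above (the token strings are illustrative assumptions, not taken from this commit):

# Illustrative only: fetch a registered plugin by name; unknown names raise ValueError.
plugin = get_mm_plugin(name="pixtral", image_token="[IMG]", video_token="<video>")
try:
    get_mm_plugin(name="not_a_plugin", image_token="<image>", video_token="<video>")
except ValueError as err:
    print(err)  # Multimodal plugin `not_a_plugin` not found.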
src/llamafactory/data/parser.py  View file @ 2778a3d0
...
@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
 from transformers.utils import cached_file

 from ..extras.constants import DATA_CONFIG
-from ..extras.misc import use_modelscope
+from ..extras.misc import use_modelscope, use_openmind


 @dataclass
...
@@ -30,7 +30,7 @@ class DatasetAttr:
     """

     # basic configs
-    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
+    load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
     dataset_name: str
     formatting: Literal["alpaca", "sharegpt"] = "alpaca"
     ranking: bool = False
...
@@ -87,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
         config_path = os.path.join(dataset_dir, DATA_CONFIG)
         try:
-            with open(config_path, "r") as f:
+            with open(config_path) as f:
                 dataset_info = json.load(f)
         except Exception as err:
             if len(dataset_names) != 0:
-                raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
+                raise ValueError(f"Cannot open {config_path} due to {str(err)}.")

             dataset_info = None

     dataset_list: List["DatasetAttr"] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            load_from = "ms_hub" if use_modelscope() else "hf_hub"
+            if use_modelscope():
+                load_from = "ms_hub"
+            elif use_openmind():
+                load_from = "om_hub"
+            else:
+                load_from = "hf_hub"
             dataset_attr = DatasetAttr(load_from, dataset_name=name)
             dataset_list.append(dataset_attr)
             continue

         if name not in dataset_info:
-            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
+            raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")

         has_hf_url = "hf_hub_url" in dataset_info[name]
         has_ms_url = "ms_hub_url" in dataset_info[name]
+        has_om_url = "om_hub_url" in dataset_info[name]
-        if has_hf_url or has_ms_url:
-            if (use_modelscope() and has_ms_url) or (not has_hf_url):
+        if has_hf_url or has_ms_url or has_om_url:
+            if has_ms_url and (use_modelscope() or not has_hf_url):
                 dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
+            elif has_om_url and (use_openmind() or not has_hf_url):
+                dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
             else:
                 dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
         elif "script_url" in dataset_info[name]:
...
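For context, the selection logic above reads per-dataset entries from dataset_info.json; a sketch of such an entry, with placeholder dataset ids, might look like this:

# Illustrative dataset_info.json entry (dataset ids are placeholders, not real repositories):
dataset_info = {
    "alpaca_demo": {
        "hf_hub_url": "someorg/alpaca-demo",      # default source
        "ms_hub_url": "someorg/alpaca-demo-ms",   # chosen when use_modelscope() is true or no hf_hub_url exists
        "om_hub_url": "someorg/alpaca-demo-om",   # chosen when use_openmind() is true or no hf_hub_url exists
    }
}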
src/llamafactory/data/processors/feedback.py  View file @ 2778a3d0
...
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen
...
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_feedback_example(
...
@@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
...
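The recurring switch from logger.warning to logger.warning_rank0 in this commit points at a helper that logs only on the main process. A minimal sketch of that idea, assuming torch.distributed as the process-group backend (this is not the project's actual implementation), could look like:

# Sketch of a rank-0-only warning to avoid duplicated logs in multi-GPU runs.
import logging
import torch.distributed as dist

logger = logging.getLogger(__name__)

def warning_rank0(msg: str) -> None:
    # Emit the warning only on the main process (or when not running distributed).
    if not dist.is_initialized() or dist.get_rank() == 0:
        logger.warning(msg)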
@@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
     desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
     undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
     if desirable_num == 0 or undesirable_num == 0:
-        logger.warning("Your dataset only has one preference type.")
+        logger.warning_rank0("Your dataset only has one preference type.")

     return model_inputs
src/llamafactory/data/processors/pairwise.py  View file @ 2778a3d0
...
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen
...
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_pairwise_example(
...
@@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
...
@@ -110,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
     print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
     print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
     print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
-    print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
+    print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
     print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
     print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
     print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
-    print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
+    print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
src/llamafactory/data/processors/supervised.py  View file @ 2778a3d0
...
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import greedy_knapsack, infer_seqlen
...
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_supervised_example(
...
@@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
...
@@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
     length2indexes = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
...
@@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:
-            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+            logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)
...
@@ -212,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
     print("input_ids:\n{}".format(example["input_ids"]))
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
     print("label_ids:\n{}".format(example["labels"]))
-    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
+    print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")