OpenDAS / LLaMA-Factory

Commit 2778a3d0, authored Jan 16, 2025 by luopl
Parent commit: e92143e3

    updata to v0.9.1_stable

Changes: 172 files in total; showing 20 changed files with 498 additions and 259 deletions (+498, -259).
setup.py                                          +6   -5
src/api.py                                        +3   -3
src/llamafactory/__init__.py                      +6   -5
src/llamafactory/api/app.py                       +6   -6
src/llamafactory/api/chat.py                      +15  -15
src/llamafactory/chat/base_engine.py              +4   -4
src/llamafactory/chat/chat_model.py               +13  -13
src/llamafactory/chat/hf_engine.py                +55  -31
src/llamafactory/chat/vllm_engine.py              +40  -26
src/llamafactory/cli.py                           +15  -14
src/llamafactory/data/aligner.py                  +22  -16
src/llamafactory/data/collator.py                 +9   -6
src/llamafactory/data/data_utils.py               +5   -5
src/llamafactory/data/formatter.py                +5   -5
src/llamafactory/data/loader.py                   +36  -23
src/llamafactory/data/mm_plugin.py                +219 -59
src/llamafactory/data/parser.py                   +16  -8
src/llamafactory/data/processors/feedback.py      +6   -4
src/llamafactory/data/processors/pairwise.py      +7   -5
src/llamafactory/data/processors/supervised.py    +10  -6
setup.py

@@ -20,7 +20,7 @@ from setuptools import find_packages, setup

 def get_version() -> str:
-    with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
+    with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
         file_content = f.read()
         pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
         (version,) = re.findall(pattern, file_content)

@@ -28,7 +28,7 @@ def get_version() -> str:

 def get_requires() -> List[str]:
-    with open("requirements.txt", "r", encoding="utf-8") as f:
+    with open("requirements.txt", encoding="utf-8") as f:
         file_content = f.read()
         lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
         return lines

@@ -54,13 +54,14 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.6.2"],
+    "vllm": ["vllm>=0.4.3,<0.6.4"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
-    "dev": ["ruff", "pytest"],
+    "openmind": ["openmind"],
+    "dev": ["pre-commit", "ruff", "pytest"],
 }

@@ -71,7 +72,7 @@ def main():
        author="hiyouga",
        author_email="hiyouga" "@" "buaa.edu.cn",
        description="Easy-to-use LLM fine-tuning framework",
-       long_description=open("README.md", "r", encoding="utf-8").read(),
+       long_description=open("README.md", encoding="utf-8").read(),
        long_description_content_type="text/markdown",
        keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
        license="Apache 2.0 License",
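The version string is still scraped out of env.py with a regular expression; only the redundant "r" open mode was dropped. A minimal sketch of that pattern in isolation, using a stand-in string rather than the real file:

import re

# Stand-in for the contents of src/llamafactory/extras/env.py.
file_content = 'VERSION = "0.9.1"\n'

# Same pattern as setup.py: match `VERSION = "<value>"` and capture the value.
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
print(version)  # -> 0.9.1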
src/api.py

@@ -23,9 +23,9 @@ from llamafactory.chat import ChatModel

 def main():
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
-    print("Visit http://localhost:{}/docs for API document.".format(api_port))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)
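This hunk is behavior-preserving: os.getenv is an alias for os.environ.get, and the f-strings replace equivalent str.format calls. A quick sketch demonstrating the equivalence:

import os

os.environ["API_PORT"] = "8123"  # simulate the environment the server reads

# Both calls resolve the same way; the diff merely prefers the shorter alias.
assert os.getenv("API_PORT", "8000") == os.environ.get("API_PORT", "8000")

api_port = int(os.getenv("API_PORT", "8000"))
print(f"Visit http://localhost:{api_port}/docs for API document.")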
src/llamafactory/__init__.py

@@ -20,17 +20,17 @@ Level:

 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.45.2
-    datasets>=2.16.0,<=2.21.0
-    accelerate>=0.30.1,<=0.34.2
+    transformers>=4.41.2,<=4.46.1
+    datasets>=2.16.0,<=3.1.0
+    accelerate>=0.34.0,<=1.0.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.45.2
+    transformers>=4.41.2,<=4.46.1
   packing:
-    transformers>=4.41.2,<=4.45.2
+    transformers>=4.41.2,<=4.46.1

 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1

@@ -38,6 +38,7 @@ Force check imports: FORCE_CHECK_IMPORTS=1
 Force using torchrun: FORCE_TORCHRUN=1
 Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
 Use modelscope: USE_MODELSCOPE_HUB=1
+Use openmind: USE_OPENMIND_HUB=1
 """

 from .extras.env import VERSION
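These docstring pins are informational, but such a range can be checked mechanically. A hedged sketch of one way to do it (this helper is illustrative only, not part of the package, and assumes the `packaging` library is installed):

from importlib.metadata import version

from packaging.version import Version

def in_range(pkg: str, low: str, high: str) -> bool:
    # Compare the installed version of `pkg` against an inclusive pin range.
    v = Version(version(pkg))
    return Version(low) <= v <= Version(high)

print(in_range("transformers", "4.41.2", "4.46.1"))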
src/llamafactory/api/app.py

@@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU mem

 def create_app(chat_model: "ChatModel") -> "FastAPI":
-    root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
+    root_path = os.getenv("FASTAPI_ROOT_PATH", "")
     app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
     app.add_middleware(
         CORSMiddleware,

@@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    api_key = os.environ.get("API_KEY", None)
+    api_key = os.getenv("API_KEY")
     security = HTTPBearer(auto_error=False)

     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):

@@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         dependencies=[Depends(verify_api_key)],
     )
     async def list_models():
-        model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
+        model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
         return ModelList(data=[model_card])

     @app.post(

@@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 def run_api() -> None:
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
-    print("Visit http://localhost:{}/docs for API document.".format(api_port))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
+    print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)
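The API surface itself is unchanged by this hunk. A sketch of exercising the resulting OpenAI-compatible server from a client; the host, port, and token value are assumptions matching the environment defaults read above:

import requests

# The Authorization header is only required when API_KEY is set on the server;
# the token value here is a placeholder.
headers = {"Authorization": "Bearer sk-placeholder"}
resp = requests.get("http://localhost:8000/v1/models", headers=headers)
print(resp.json())  # contains the model card named by API_MODEL_NAME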
src/llamafactory/api/chat.py

@@ -21,7 +21,7 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

 from ..data import Role as DataRole
-from ..extras.logging import get_logger
+from ..extras import logging
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (

@@ -57,7 +57,7 @@ if TYPE_CHECKING:
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 ROLE_MAPPING = {
     Role.USER: DataRole.USER.value,
     Role.ASSISTANT: DataRole.ASSISTANT.value,

@@ -69,8 +69,8 @@ ROLE_MAPPING = {
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
-    logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
+    logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

@@ -84,7 +84,7 @@ def _process_request(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

     input_messages = []
-    image = None
+    images = []
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

@@ -111,7 +111,7 @@ def _process_request(
                 else:  # web uri
                     image_stream = requests.get(image_url, stream=True).raw

-                image = Image.open(image_stream).convert("RGB")
+                images.append(Image.open(image_stream).convert("RGB"))
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

@@ -124,7 +124,7 @@ def _process_request(
     else:
         tools = None

-    return input_messages, system, tools, image
+    return input_messages, system, tools, images or None


 def _create_stream_chat_completion_chunk(

@@ -142,13 +142,13 @@ def _create_stream_chat_completion_chunk(
 async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
-    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools, image = _process_request(request)
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,

@@ -169,7 +169,7 @@ async def create_chat_completion_response(
             tool_calls = []
             for tool in result:
                 function = Function(name=tool[0], arguments=tool[1])
-                tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))
+                tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))

             response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
             finish_reason = Finish.TOOL

@@ -193,8 +193,8 @@ async def create_chat_completion_response(
 async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
-    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-    input_messages, system, tools, image = _process_request(request)
+    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
+    input_messages, system, tools, images = _process_request(request)
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")

@@ -208,7 +208,7 @@ async def create_stream_chat_completion_response(
         input_messages,
         system,
         tools,
-        image,
+        images,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,

@@ -229,7 +229,7 @@ async def create_stream_chat_completion_response(
 async def create_score_evaluation_response(
     request: "ScoreEvaluationRequest", chat_model: "ChatModel"
 ) -> "ScoreEvaluationResponse":
-    score_id = "scoreval-{}".format(uuid.uuid4().hex)
+    score_id = f"scoreval-{uuid.uuid4().hex}"
     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
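Since _process_request() now collects every image part into a list, a single chat completion request can carry multiple images. A sketch of such a payload (model name and URLs are placeholders):

payload = {
    "model": "gpt-3.5-turbo",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Compare these two pictures."},
                {"type": "image_url", "image_url": {"url": "https://example.com/a.png"}},
                {"type": "image_url", "image_url": {"url": "https://example.com/b.png"}},
            ],
        }
    ],
}
# Each image_url part is fetched and appended to `images`, which is then
# forwarded to chat_model.achat(...) as shown in the hunks above.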
src/llamafactory/chat/base_engine.py

@@ -66,8 +66,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""

@@ -81,8 +81,8 @@ class BaseEngine(ABC):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
src/llamafactory/chat/chat_model.py

@@ -53,7 +53,7 @@ class ChatModel:
         elif model_args.infer_backend == "vllm":
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
-            raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
+            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)

@@ -64,15 +64,15 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Gets a list of responses of the chat model.
         """
         task = asyncio.run_coroutine_threadsafe(
-            self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
+            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
         )
         return task.result()

@@ -81,28 +81,28 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         r"""
         Asynchronously gets a list of responses of the chat model.
         """
-        return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)

     def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
         r"""
         Gets the response token-by-token of the chat model.
         """
-        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
         while True:
             try:
                 task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)

@@ -115,14 +115,14 @@ class ChatModel:
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         r"""
         Asynchronously gets the response token-by-token of the chat model.
         """
-        async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
             yield new_token

     def get_scores(
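The surrounding machinery, unchanged by this commit, runs a private event loop on a daemon thread so that the synchronous chat()/stream_chat() wrappers can drive the async engine. A self-contained sketch of that pattern with a stand-in coroutine:

import asyncio
from threading import Thread

def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
    # The thread owns the loop and runs it forever, as ChatModel.__init__ does.
    asyncio.set_event_loop(loop)
    loop.run_forever()

loop = asyncio.new_event_loop()
Thread(target=_start_background_loop, args=(loop,), daemon=True).start()

async def achat(prompt: str) -> str:  # stand-in for ChatModel.achat
    await asyncio.sleep(0.1)
    return f"echo: {prompt}"

# Synchronous callers submit coroutines to the background loop and block
# on the returned concurrent.futures.Future.
task = asyncio.run_coroutine_threadsafe(achat("hello"), loop)
print(task.result())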
src/llamafactory/chat/hf_engine.py

@@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
 from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
 from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.logging import get_logger
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response

@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class HuggingfaceEngine(BaseEngine):

@@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
         try:
             asyncio.get_event_loop()
         except RuntimeError:
-            logger.warning("There is no current event loop, creating a new one.")
+            logger.warning_once("There is no current event loop, creating a new one.")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)

-        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
+        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

     @staticmethod
     def _process_args(

@@ -79,20 +79,20 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Tuple[Dict[str, Any], int]:
         mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
-        if image is not None:
-            mm_input_dict.update({"images": [image], "imglens": [1]})
-            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
-
-        if video is not None:
-            mm_input_dict.update({"videos": [video], "vidlens": [1]})
-            if VIDEO_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = VIDEO_PLACEHOLDER + messages[0]["content"]
+        if images is not None:
+            mm_input_dict.update({"images": images, "imglens": [len(images)]})
+            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+
+        if videos is not None:
+            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
+            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

         messages = template.mm_plugin.process_messages(
             messages, mm_input_dict["images"], mm_input_dict["videos"], processor

@@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
         if stop is not None:
-            logger.warning("Stop parameter is not supported by the huggingface engine yet.")
+            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")

         generating_args = generating_args.copy()
         generating_args.update(

@@ -164,9 +164,13 @@ class HuggingfaceEngine(BaseEngine):
             logits_processor=get_logits_processor(),
         )

-        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, seqlens=[prompt_length], processor=processor)
+        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
         for key, value in mm_inputs.items():
-            value = value if isinstance(value, torch.Tensor) else torch.tensor(value)
+            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
+                value = torch.stack(value)  # assume they have same sizes
+            elif not isinstance(value, torch.Tensor):
+                value = torch.tensor(value)
+
             gen_kwargs[key] = value.to(model.device)

         return gen_kwargs, prompt_length

@@ -182,12 +186,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> List["Response"]:
         gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]

@@ -218,12 +232,22 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         input_kwargs: Optional[Dict[str, Any]] = {},
     ) -> Callable[[], str]:
         gen_kwargs, _ = HuggingfaceEngine._process_args(
-            model, tokenizer, processor, template, generating_args, messages, system, tools, image, video, input_kwargs
+            model,
+            tokenizer,
+            processor,
+            template,
+            generating_args,
+            messages,
+            system,
+            tools,
+            images,
+            videos,
+            input_kwargs,
         )
         streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer

@@ -266,8 +290,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         if not self.can_generate:

@@ -283,8 +307,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:

@@ -297,8 +321,8 @@ class HuggingfaceEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         if not self.can_generate:

@@ -314,8 +338,8 @@ class HuggingfaceEngine(BaseEngine):
             messages,
             system,
             tools,
-            image,
-            video,
+            images,
+            videos,
             input_kwargs,
         )
         async with self.semaphore:
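The new branch in _process_args() accommodates multimodal processors that return a list of tensors instead of a single array. A sketch of that normalization on toy values:

import torch

def normalize(value):
    # Lists of equally sized tensors (e.g. pixtral pixel values) are stacked;
    # any other non-tensor value falls back to torch.tensor().
    if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):
        return torch.stack(value)  # assumes the tensors share a shape
    elif not isinstance(value, torch.Tensor):
        return torch.tensor(value)
    return value

print(normalize([torch.zeros(2, 3), torch.ones(2, 3)]).shape)  # torch.Size([2, 2, 3])
print(normalize([[1, 2], [3, 4]]).shape)                       # torch.Size([2, 2])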
src/llamafactory/chat/vllm_engine.py

@@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
 from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
 from ..extras.constants import IMAGE_PLACEHOLDER
-from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
 from ..extras.packages import is_pillow_available, is_vllm_available
 from ..model import load_config, load_tokenizer

@@ -43,7 +43,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class VllmEngine(BaseEngine):

@@ -83,11 +83,13 @@ class VllmEngine(BaseEngine):
             "enable_lora": model_args.adapter_name_or_path is not None,
             "max_lora_rank": model_args.vllm_max_lora_rank,
         }
+        if isinstance(model_args.vllm_config, dict):
+            engine_args.update(model_args.vllm_config)

         if getattr(config, "is_yi_vl_derived_model", None):
             import vllm.model_executor.models.llava

-            logger.info("Detected Yi-VL model, applying projector patch.")
+            logger.info_rank0("Detected Yi-VL model, applying projector patch.")
             vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

         self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))

@@ -101,21 +103,28 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncIterator["RequestOutput"]:
-        request_id = "chatcmpl-{}".format(uuid.uuid4().hex)
-        if image is not None:
-            if IMAGE_PLACEHOLDER not in messages[0]["content"]:
-                messages[0]["content"] = IMAGE_PLACEHOLDER + messages[0]["content"]
-
-        paired_messages = messages + [{"role": "assistant", "content": ""}]
+        request_id = f"chatcmpl-{uuid.uuid4().hex}"
+        if images is not None:
+            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]
+
+        if self.template.mm_plugin.__class__.__name__ == "Qwen2vlPlugin":  # temporary solution
+            image_str = f"<|vision_start|>{self.template.mm_plugin.image_token}<|vision_end|>"
+        else:
+            image_str = self.template.mm_plugin.image_token or ""
+
+        paired_messages = [
+            {"role": message["role"], "content": message["content"].replace(IMAGE_PLACEHOLDER, image_str)}
+            for message in messages
+        ] + [{"role": "assistant", "content": ""}]
         system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(self.tokenizer, paired_messages, system, tools)
         prompt_length = len(prompt_ids)

-        use_beam_search: bool = self.generating_args["num_beams"] > 1
         temperature: Optional[float] = input_kwargs.pop("temperature", None)
         top_p: Optional[float] = input_kwargs.pop("top_p", None)
         top_k: Optional[float] = input_kwargs.pop("top_k", None)

@@ -126,6 +135,9 @@ class VllmEngine(BaseEngine):
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

+        if length_penalty is not None:
+            logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
+
         if "max_new_tokens" in self.generating_args:
             max_tokens = self.generating_args["max_new_tokens"]
         elif "max_length" in self.generating_args:

@@ -149,27 +161,29 @@ class VllmEngine(BaseEngine):
             temperature=temperature if temperature is not None else self.generating_args["temperature"],
             top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
             top_k=top_k if top_k is not None else self.generating_args["top_k"],
-            use_beam_search=use_beam_search,
-            length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
             skip_special_tokens=True,
         )

-        if image is not None:  # add image features
-            if not isinstance(image, (str, ImageObject)):
-                raise ValueError("Expected image input is a path or PIL.Image, but got {}.".format(type(image)))
-
-            if isinstance(image, str):
-                image = Image.open(image).convert("RGB")
-
-            multi_modal_data = {"image": image}
+        if images is not None:  # add image features
+            image_data = []
+            for image in images:
+                if not isinstance(image, (str, ImageObject)):
+                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
+
+                if isinstance(image, str):
+                    image = Image.open(image).convert("RGB")
+
+                image_data.append(image)
+
+            multi_modal_data = {"image": image_data}
         else:
             multi_modal_data = None

         result_generator = self.model.generate(
-            inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
+            {"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
             sampling_params=sampling_params,
             request_id=request_id,
             lora_request=self.lora_request,

@@ -182,12 +196,12 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> List["Response"]:
         final_output = None
-        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
         async for request_output in generator:
             final_output = request_output

@@ -210,12 +224,12 @@ class VllmEngine(BaseEngine):
         messages: Sequence[Dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        image: Optional["ImageInput"] = None,
-        video: Optional["VideoInput"] = None,
+        images: Optional[Sequence["ImageInput"]] = None,
+        videos: Optional[Sequence["VideoInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
         generated_text = ""
-        generator = await self._generate(messages, system, tools, image, video, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
         async for result in generator:
             delta_text = result.outputs[0].text[len(generated_text) :]
             generated_text = result.outputs[0].text
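A sketch of the placeholder rewriting added to _generate(): one placeholder per image is prepended when none is present, and the Qwen2-VL branch wraps the model's image token in vision markers. The placeholder and vision markers mirror the diff; the image token literal is an assumption for illustration:

IMAGE_PLACEHOLDER = "<image>"
image_token = "<|image_pad|>"  # assumed stand-in for template.mm_plugin.image_token

messages = [{"role": "user", "content": "Describe the photos."}]
images = ["a.png", "b.png"]

# Prepend one placeholder per image if the prompt contains none.
if not any(IMAGE_PLACEHOLDER in m["content"] for m in messages):
    messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

# Qwen2vlPlugin branch: wrap the image token in vision markers.
image_str = f"<|vision_start|>{image_token}<|vision_end|>"
rewritten = [
    {"role": m["role"], "content": m["content"].replace(IMAGE_PLACEHOLDER, image_str)}
    for m in messages
]
print(rewritten[0]["content"])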
src/llamafactory/cli.py

@@ -22,8 +22,8 @@ from . import launcher
 from .api.app import run_api
 from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
+from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.logging import get_logger
 from .extras.misc import get_device_count
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui

@@ -47,7 +47,7 @@ USAGE = (
 WELCOME = (
     "-" * 58
     + "\n"
-    + "| Welcome to LLaMA Factory, version {}".format(VERSION)
+    + f"| Welcome to LLaMA Factory, version {VERSION}"
     + " " * (21 - len(VERSION))
     + "|\n|"
     + " " * 56

@@ -56,7 +56,7 @@ WELCOME = (
     + "-" * 58
 )

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @unique

@@ -86,25 +86,26 @@ def main():
     elif command == Command.EXPORT:
         export_model()
     elif command == Command.TRAIN:
-        force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
+        force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
         if force_torchrun or get_device_count() > 1:
-            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
-            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
-            logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
+            master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
+            master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
+            logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
             process = subprocess.run(
                 (
                     "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                     "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
-                ).format(
-                    nnodes=os.environ.get("NNODES", "1"),
-                    node_rank=os.environ.get("RANK", "0"),
-                    nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
+                )
+                .format(
+                    nnodes=os.getenv("NNODES", "1"),
+                    node_rank=os.getenv("NODE_RANK", "0"),
+                    nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
                     master_addr=master_addr,
                     master_port=master_port,
                     file_name=launcher.__file__,
                     args=" ".join(sys.argv[1:]),
                 )
-                shell=True,
+                .split()
             )
             sys.exit(process.returncode)
         else:

@@ -118,4 +119,4 @@ def main():
     elif command == Command.HELP:
         print(USAGE)
     else:
-        raise NotImplementedError("Unknown command: {}.".format(command))
+        raise NotImplementedError(f"Unknown command: {command}.")
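The torchrun launch now passes an argument list built with str.split() instead of handing a command string to a shell. A minimal sketch of the difference; note that splitting on whitespace assumes no argument itself contains spaces:

import subprocess

cmd = "echo --nnodes 1 --node_rank 0"

subprocess.run(cmd, shell=True)  # old style: a shell parses the string
subprocess.run(cmd.split())      # new style: arguments passed directly, no shell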
src/llamafactory/data/aligner.py

@@ -16,7 +16,7 @@ import os
 from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

-from ..extras.logging import get_logger
+from ..extras import logging
 from .data_utils import Role

@@ -29,45 +29,51 @@ if TYPE_CHECKING:
     from .parser import DatasetAttr


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _convert_images(
-    images: Sequence["ImageInput"],
+    images: Union["ImageInput", Sequence["ImageInput"]],
     dataset_attr: "DatasetAttr",
     data_args: "DataArguments",
 ) -> Optional[List["ImageInput"]]:
     r"""
     Optionally concatenates image path to dataset dir when loading from local disk.
     """
-    if len(images) == 0:
+    if not isinstance(images, list):
+        images = [images]
+    elif len(images) == 0:
         return None
+    else:
+        images = images[:]

-    images = images[:]
     if dataset_attr.load_from in ["script", "file"]:
         for i in range(len(images)):
-            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, images[i])):
-                images[i] = os.path.join(data_args.dataset_dir, images[i])
+            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
+                images[i] = os.path.join(data_args.image_dir, images[i])

     return images


 def _convert_videos(
-    videos: Sequence["VideoInput"],
+    videos: Union["VideoInput", Sequence["VideoInput"]],
     dataset_attr: "DatasetAttr",
     data_args: "DataArguments",
 ) -> Optional[List["VideoInput"]]:
     r"""
     Optionally concatenates video path to dataset dir when loading from local disk.
     """
-    if len(videos) == 0:
+    if not isinstance(videos, list):
+        videos = [videos]
+    elif len(videos) == 0:
         return None
+    else:
+        videos = videos[:]

-    videos = videos[:]
     if dataset_attr.load_from in ["script", "file"]:
         for i in range(len(videos)):
-            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.dataset_dir, videos[i])):
-                videos[i] = os.path.join(data_args.dataset_dir, videos[i])
+            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
+                videos[i] = os.path.join(data_args.image_dir, videos[i])

     return videos

@@ -161,7 +167,7 @@ def convert_sharegpt(
     broken_data = False
     for turn_idx, message in enumerate(messages):
         if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
-            logger.warning("Invalid role tag in {}.".format(messages))
+            logger.warning_rank0(f"Invalid role tag in {messages}.")
             broken_data = True

         aligned_messages.append(

@@ -171,7 +177,7 @@ def convert_sharegpt(
     if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
         dataset_attr.ranking and len(aligned_messages) % 2 == 0
     ):
-        logger.warning("Invalid message count in {}.".format(messages))
+        logger.warning_rank0(f"Invalid message count in {messages}.")
         broken_data = True

     if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example

@@ -192,7 +198,7 @@ def convert_sharegpt(
             chosen[dataset_attr.role_tag] not in accept_tags[-1]
             or rejected[dataset_attr.role_tag] not in accept_tags[-1]
         ):
-            logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
+            logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
             broken_data = True

         prompt = aligned_messages

@@ -205,7 +211,7 @@ def convert_sharegpt(
         response = aligned_messages[-1:]

     if broken_data:
-        logger.warning("Skipping this abnormal example.")
+        logger.warning_rank0("Skipping this abnormal example.")
         prompt, response = [], []

     convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
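A sketch of the input normalization _convert_images() and _convert_videos() now perform, so datasets that provide a single image rather than a list are also accepted:

def normalize_media(item):
    if not isinstance(item, list):
        return [item]    # a bare image/video becomes a one-element list
    elif len(item) == 0:
        return None      # an empty list means no media
    else:
        return item[:]   # copy so the original example is not mutated

print(normalize_media("photo.png"))         # ['photo.png']
print(normalize_media([]))                  # None
print(normalize_media(["a.png", "b.png"]))  # ['a.png', 'b.png']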
src/llamafactory/data/collator.py
View file @
2778a3d0
@@ -79,7 +79,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
     processor: Optional["ProcessorMixin"] = None

     def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
+        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
         for feature in features:
             images = feature.pop("images", None) or []
             videos = feature.pop("videos", None) or []
@@ -87,10 +87,10 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             batch_videos.extend(videos)
             batch_imglens.append(len(images))
             batch_vidlens.append(len(videos))
-            batch_seqlens.append(len(feature["input_ids"]))
+            batch_input_ids.append(feature["input_ids"])

         mm_inputs = self.template.mm_plugin.get_mm_inputs(
-            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
+            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
         )
         if "token_type_ids" in mm_inputs:
             token_type_ids = mm_inputs.pop("token_type_ids")
@@ -99,6 +99,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         features: Dict[str, "torch.Tensor"] = super().__call__(features)
         features.update(mm_inputs)
+        if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
+            features = features.data  # use default_collate() instead of BatchEncoding.to()
+
         return features
@@ -137,9 +140,9 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
         for key in ("chosen", "rejected"):
             for feature in features:
                 target_feature = {
-                    "input_ids": feature["{}_input_ids".format(key)],
-                    "attention_mask": feature["{}_attention_mask".format(key)],
-                    "labels": feature["{}_labels".format(key)],
+                    "input_ids": feature[f"{key}_input_ids"],
+                    "attention_mask": feature[f"{key}_attention_mask"],
+                    "labels": feature[f"{key}_labels"],
                     "images": feature["images"],
                     "videos": feature["videos"],
                 }
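The new `pixel_values` branch exists because Pixtral-style processors return a list of differently sized image tensors that cannot be stacked into one batch tensor. A sketch of the failure mode it avoids, with toy shapes chosen only for illustration:

    import torch
    from transformers import BatchEncoding

    # BatchEncoding.to(device) assumes every value is a tensor; a list of
    # variable-resolution pixel_values would break that. Unwrapping to the
    # plain dict (.data) lets the Trainer move tensors one by one and keep
    # the list untouched for the model's image tower.
    enc = BatchEncoding({
        "input_ids": torch.ones(2, 4, dtype=torch.long),
        "pixel_values": [torch.rand(3, 336, 336), torch.rand(3, 224, 224)],
    })
    features = enc.data if isinstance(enc.get("pixel_values"), list) else enc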
src/llamafactory/data/data_utils.py View file @ 2778a3d0
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict
 from datasets import DatasetDict, concatenate_datasets, interleave_datasets

-from ..extras.logging import get_logger
+from ..extras import logging

 if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@@ -56,12 +56,12 @@ def merge_dataset(
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
-            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+            logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")

         return concatenate_datasets(all_datasets)
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
-            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+            logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")

         return interleave_datasets(
             datasets=all_datasets,
@@ -70,7 +70,7 @@ def merge_dataset(
             stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
         )
     else:
-        raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
+        raise ValueError(f"Unknown mixing strategy: {data_args.mix_strategy}.")

 def split_dataset(
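`warning_once` differs from `warning_rank0` in that it deduplicates by message, which matters in `merge_dataset` because streaming pipelines can re-enter the function repeatedly. A rough sketch of how such a method can be built with message-level caching (an assumption; the project's actual helper may differ):

    import functools
    import logging

    @functools.lru_cache(maxsize=None)
    def _warn_once(logger_name: str, msg: str) -> None:
        # The cache keys on (logger_name, msg), so repeated calls with the
        # same message become no-ops after the first emission.
        logging.getLogger(logger_name).warning(msg)

    _warn_once(__name__, "The samples between different datasets will not be mixed in streaming mode.")
    _warn_once(__name__, "The samples between different datasets will not be mixed in streaming mode.")  # suppressed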
src/llamafactory/data/formatter.py View file @ 2778a3d0
@@ -83,14 +83,14 @@ class StringFormatter(Formatter):
             if isinstance(slot, str):
                 for name, value in kwargs.items():
                     if not isinstance(value, str):
-                        raise RuntimeError("Expected a string, got {}".format(value))
+                        raise RuntimeError(f"Expected a string, got {value}")

                     slot = slot.replace("{{" + name + "}}", value, 1)
                 elements.append(slot)
             elif isinstance(slot, (dict, set)):
                 elements.append(slot)
             else:
-                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")

         return elements
@@ -113,7 +113,7 @@ class FunctionFormatter(Formatter):
                 functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
         except json.JSONDecodeError:
-            raise RuntimeError("Invalid JSON format in function message: {}".format(str([content])))  # flat string
+            raise RuntimeError(f"Invalid JSON format in function message: {str([content])}")  # flat string

         elements = []
         for name, arguments in functions:
@@ -124,7 +124,7 @@ class FunctionFormatter(Formatter):
             elif isinstance(slot, (dict, set)):
                 elements.append(slot)
             else:
-                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
+                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}")

         return elements
@@ -141,7 +141,7 @@ class ToolFormatter(Formatter):
             tools = json.loads(content)
             return [self.tool_utils.tool_formatter(tools) if len(tools) != 0 else ""]
         except json.JSONDecodeError:
-            raise RuntimeError("Invalid JSON format in tool description: {}".format(str([content])))  # flat string
+            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}")  # flat string

     @override
     def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
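For readers unfamiliar with `StringFormatter`, the `{{name}}` placeholders above are substituted one occurrence at a time, which the f-string cleanup does not change. A standalone rendition of that loop (the slot template and kwargs values here are made up for illustration):

    slot = "<|user|>\n{{content}}\n<|assistant|>\n"
    kwargs = {"content": "How do I merge two datasets?"}

    for name, value in kwargs.items():
        if not isinstance(value, str):
            raise RuntimeError(f"Expected a string, got {value}")

        # Only the first occurrence is substituted, matching count=1 in the source.
        slot = slot.replace("{{" + name + "}}", value, 1)

    print(slot)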
src/llamafactory/data/loader.py View file @ 2778a3d0
@@ -20,8 +20,8 @@ import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
 from transformers.utils.versions import require_version

+from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
-from ..extras.logging import get_logger
 from ..extras.misc import has_tokenized_data
 from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from .template import Template

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def _load_single_dataset(
@@ -51,9 +51,9 @@ def _load_single_dataset(
     r"""
     Loads a single dataset and aligns it to the standard format.
     """
-    logger.info("Loading dataset {}...".format(dataset_attr))
+    logger.info_rank0(f"Loading dataset {dataset_attr}...")
     data_path, data_name, data_dir, data_files = None, None, None, None
-    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
+    if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
         data_path = dataset_attr.dataset_name
         data_name = dataset_attr.subset
         data_dir = dataset_attr.folder
@@ -69,25 +69,24 @@ def _load_single_dataset(
         if os.path.isdir(local_path):  # is directory
             for file_name in os.listdir(local_path):
                 data_files.append(os.path.join(local_path, file_name))
-                if data_path is None:
-                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
-                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
-                    raise ValueError("File types should be identical.")
         elif os.path.isfile(local_path):  # is file
             data_files.append(local_path)
-            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
         else:
-            raise ValueError("File {} not found.".format(local_path))
+            raise ValueError(f"File {local_path} not found.")

+        data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
         if data_path is None:
             raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
+
+        if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
+            raise ValueError("File types should be identical.")
     else:
-        raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))
+        raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")

     if dataset_attr.load_from == "ms_hub":
         require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
-        from modelscope import MsDataset
-        from modelscope.utils.config_ds import MS_DATASETS_CACHE
+        from modelscope import MsDataset  # type: ignore
+        from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore

         cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
         dataset = MsDataset.load(
@@ -98,10 +97,27 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
-            use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            use_streaming=data_args.streaming,
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()

+    elif dataset_attr.load_from == "om_hub":
+        require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
+        from openmind import OmDataset  # type: ignore
+        from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore
+
+        cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
+        dataset = OmDataset.load_dataset(
+            path=data_path,
+            name=data_name,
+            data_dir=data_dir,
+            data_files=data_files,
+            split=dataset_attr.split,
+            cache_dir=cache_dir,
+            token=model_args.om_hub_token,
+            streaming=data_args.streaming,
+        )
     else:
         dataset = load_dataset(
             path=data_path,
@@ -111,13 +127,10 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            streaming=data_args.streaming,
             trust_remote_code=True,
         )

-    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
-        dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
-
     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
         indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
@@ -128,7 +141,7 @@ def _load_single_dataset(
         assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
         dataset = dataset.select(indexes)
-        logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))
+        logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")

     if data_args.max_samples is not None:  # truncate dataset
         max_samples = min(data_args.max_samples, len(dataset))
@@ -224,9 +237,9 @@ def get_dataset(
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
             dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
-            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
+            logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")

             dataset_module: Dict[str, "Dataset"] = {}
             if "train" in dataset_dict:
@@ -277,8 +290,8 @@ def get_dataset(
         if data_args.tokenized_path is not None:
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
-                logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
-                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
+                logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+                logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")

             sys.exit(0)
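The rewritten local-file branch infers the dataset type once from the first file and then verifies every other file agrees, instead of checking inside the listing loop. The same logic in isolation, with an abbreviated stand-in for `FILEEXT2TYPE` (the real mapping lives in `extras/constants.py`):

    import os

    FILEEXT2TYPE = {"csv": "csv", "json": "json", "jsonl": "json", "txt": "text"}  # abbreviated stand-in

    data_files = ["data/train.jsonl", "data/extra.jsonl"]  # hypothetical paths
    data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
    if data_path is None:
        raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))

    # One pass over all collected files, rather than a check per loop iteration.
    if any(data_path != FILEEXT2TYPE.get(os.path.splitext(f)[-1][1:], None) for f in data_files):
        raise ValueError("File types should be identical.")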
src/llamafactory/data/mm_plugin.py View file @ 2778a3d0
This diff is collapsed.
src/llamafactory/data/parser.py View file @ 2778a3d0
@@ -20,7 +20,7 @@ from typing import Any, Dict, List, Literal, Optional, Sequence
 from transformers.utils import cached_file

 from ..extras.constants import DATA_CONFIG
-from ..extras.misc import use_modelscope
+from ..extras.misc import use_modelscope, use_openmind

 @dataclass
@@ -30,7 +30,7 @@ class DatasetAttr:
     """

     # basic configs
-    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
+    load_from: Literal["hf_hub", "ms_hub", "om_hub", "script", "file"]
     dataset_name: str
     formatting: Literal["alpaca", "sharegpt"] = "alpaca"
     ranking: bool = False
@@ -87,31 +87,39 @@ def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -
     config_path = os.path.join(dataset_dir, DATA_CONFIG)

     try:
-        with open(config_path, "r") as f:
+        with open(config_path) as f:
             dataset_info = json.load(f)
     except Exception as err:
         if len(dataset_names) != 0:
-            raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))
+            raise ValueError(f"Cannot open {config_path} due to {str(err)}.")

         dataset_info = None

     dataset_list: List["DatasetAttr"] = []
     for name in dataset_names:
         if dataset_info is None:  # dataset_dir is ONLINE
-            load_from = "ms_hub" if use_modelscope() else "hf_hub"
+            if use_modelscope():
+                load_from = "ms_hub"
+            elif use_openmind():
+                load_from = "om_hub"
+            else:
+                load_from = "hf_hub"
+
             dataset_attr = DatasetAttr(load_from, dataset_name=name)
             dataset_list.append(dataset_attr)
             continue

         if name not in dataset_info:
-            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
+            raise ValueError(f"Undefined dataset {name} in {DATA_CONFIG}.")

         has_hf_url = "hf_hub_url" in dataset_info[name]
         has_ms_url = "ms_hub_url" in dataset_info[name]
-        if has_hf_url or has_ms_url:
-            if (use_modelscope() and has_ms_url) or (not has_hf_url):
+        has_om_url = "om_hub_url" in dataset_info[name]
+        if has_hf_url or has_ms_url or has_om_url:
+            if has_ms_url and (use_modelscope() or not has_hf_url):
                 dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
+            elif has_om_url and (use_openmind() or not has_hf_url):
+                dataset_attr = DatasetAttr("om_hub", dataset_name=dataset_info[name]["om_hub_url"])
             else:
                 dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
         elif "script_url" in dataset_info[name]:
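The new branch order is easiest to read as a priority rule: a ModelScope URL wins when ModelScope is enabled or no Hugging Face URL exists, then an openMind URL under the same rule, and the Hugging Face Hub is the fallback. A condensed restatement for illustration, where `use_ms`/`use_om` stand in for `use_modelscope()`/`use_openmind()` and at least one hub URL is assumed present:

    def resolve_hub(info: dict, use_ms: bool, use_om: bool) -> str:
        # Mirrors the selection logic added in this hunk; assumes the caller
        # already checked that info contains at least one *_hub_url key.
        has_hf = "hf_hub_url" in info
        has_ms = "ms_hub_url" in info
        has_om = "om_hub_url" in info
        if has_ms and (use_ms or not has_hf):
            return "ms_hub"
        if has_om and (use_om or not has_hf):
            return "om_hub"
        return "hf_hub"

    assert resolve_hub({"hf_hub_url": "x", "om_hub_url": "y"}, use_ms=False, use_om=True) == "om_hub"
    assert resolve_hub({"hf_hub_url": "x", "om_hub_url": "y"}, use_ms=False, use_om=False) == "hf_hub"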
src/llamafactory/data/processors/feedback.py View file @ 2778a3d0
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def _encode_feedback_example(
@@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
@@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
     desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
     undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
     if desirable_num == 0 or undesirable_num == 0:
-        logger.warning("Your dataset only has one preference type.")
+        logger.warning_rank0("Your dataset only has one preference type.")

     return model_inputs
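The final check matters because KTO training needs both desirable and undesirable examples to estimate its reference point. A worked miniature of the tag counting above (the tags themselves are invented):

    kto_tags = [True, True, False, True]  # hypothetical per-example labels

    desirable_num = sum([1 for tag in kto_tags if tag])  # 3
    undesirable_num = len(kto_tags) - desirable_num      # 1

    if desirable_num == 0 or undesirable_num == 0:
        print("Your dataset only has one preference type.")  # fires on an all-True or all-False set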
src/llamafactory/data/processors/pairwise.py View file @ 2778a3d0
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def _encode_pairwise_example(
@@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
@@ -110,8 +112,8 @@ def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "Pr
     print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
     print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
     print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
-    print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
+    print(f"chosen_labels:\n{tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)}")
     print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
     print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
     print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
-    print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
+    print(f"rejected_labels:\n{tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)}")
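`valid_chosen_labels` and `valid_rejected_labels` are defined outside this hunk; they are presumably the label sequences with the loss-mask sentinel stripped, since `tokenizer.decode` cannot handle negative ids. A sketch of that filtering, assuming the usual `IGNORE_INDEX = -100` from `...extras.constants`:

    IGNORE_INDEX = -100  # loss-mask sentinel

    chosen_labels = [IGNORE_INDEX, IGNORE_INDEX, 319, 1781, 2]  # prompt tokens masked out
    valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, chosen_labels))
    # valid_chosen_labels == [319, 1781, 2], now safe to pass to tokenizer.decode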
src/llamafactory/data/processors/supervised.py View file @ 2778a3d0
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import greedy_knapsack, infer_seqlen
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)

 def _encode_supervised_example(
@@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
@@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
     length2indexes = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
@@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:
-            logger.warning("Dropped lengthy example with length {} > {}.".format(length, data_args.cutoff_len))
+            logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)
@@ -212,4 +216,4 @@ def print_supervised_dataset_example(example: Dict[str, List[int]], tokenizer: "
     print("input_ids:\n{}".format(example["input_ids"]))
     print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
     print("label_ids:\n{}".format(example["labels"]))
-    print("labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)))
+    print(f"labels:\n{tokenizer.decode(valid_labels, skip_special_tokens=False)}")
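`preprocess_packed_supervised_dataset` feeds the surviving lengths into `greedy_knapsack` so that several short examples can be packed into each `cutoff_len` window. A hedged sketch of the packing idea, not the project's exact implementation in `processor_utils`:

    def greedy_knapsack_sketch(lengths: list, capacity: int) -> list:
        # Assumption: sort lengths descending, then greedily fill each
        # knapsack up to the capacity before starting a new one.
        lengths = sorted(lengths, reverse=True)
        knapsacks = []
        while lengths:
            pack, rest, room = [], [], capacity
            for length in lengths:
                if length <= room:
                    pack.append(length)
                    room -= length
                else:
                    rest.append(length)
            knapsacks.append(pack)
            lengths = rest
        return knapsacks

    print(greedy_knapsack_sketch([7, 5, 4, 3, 1], capacity=8))  # [[7, 1], [5, 3], [4]]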