OpenDAS / LLaMA-Factory
Commit 317a82e2, authored Mar 07, 2025 by chenych
Add QWQ-32B
parent 37b0ad9f
Showing 20 changed files with 465 additions and 434 deletions (+465, -434)
scripts/test_toolcall.py                  +0    -78
scripts/vllm_infer.py                     +10   -4
setup.py                                  +7    -8
src/api.py                                +1    -1
src/llamafactory/__init__.py              +6    -6
src/llamafactory/api/app.py               +1    -1
src/llamafactory/api/chat.py              +10   -3
src/llamafactory/api/common.py            +1    -1
src/llamafactory/api/protocol.py          +1    -1
src/llamafactory/chat/__init__.py         +1    -1
src/llamafactory/chat/base_engine.py      +4    -2
src/llamafactory/chat/chat_model.py       +11   -5
src/llamafactory/chat/hf_engine.py        +48   -12
src/llamafactory/chat/vllm_engine.py      +28   -23
src/llamafactory/cli.py                   +3    -3
src/llamafactory/data/__init__.py         +1    -1
src/llamafactory/data/aligner.py          +0    -264
src/llamafactory/data/collator.py         +60   -19
src/llamafactory/data/converter.py        +271  -0
src/llamafactory/data/data_utils.py       +1    -1
scripts/test_toolcall.py (deleted, 100644 → 0)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import Sequence

from openai import OpenAI
from transformers.utils.versions import require_version


require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
    grade_to_score = {"A": 4, "B": 3, "C": 2}
    total_score, total_hour = 0, 0
    for grade, hour in zip(grades, hours):
        total_score += grade_to_score[grade] * hour
        total_hour += hour

    return round(total_score / total_hour, 2)


def main():
    client = OpenAI(
        api_key="{}".format(os.environ.get("API_KEY", "0")),
        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
    )
    tools = [
        {
            "type": "function",
            "function": {
                "name": "calculate_gpa",
                "description": "Calculate the Grade Point Average (GPA) based on grades and credit hours",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "grades": {"type": "array", "items": {"type": "string"}, "description": "The grades"},
                        "hours": {"type": "array", "items": {"type": "integer"}, "description": "The credit hours"},
                    },
                    "required": ["grades", "hours"],
                },
            },
        }
    ]
    tool_map = {"calculate_gpa": calculate_gpa}

    messages = []
    messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    if result.choices[0].message.tool_calls is None:
        raise ValueError("Cannot retrieve function call from the response.")

    messages.append(result.choices[0].message)
    tool_call = result.choices[0].message.tool_calls[0].function
    print(tool_call)
    # Function(arguments='{"grades": ["A", "A", "B", "C"], "hours": [3, 4, 3, 2]}', name='calculate_gpa')
    name, arguments = tool_call.name, json.loads(tool_call.arguments)
    tool_result = tool_map[name](**arguments)
    messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    print(result.choices[0].message.content)
    # Based on the grades and credit hours you provided, your Grade Point Average (GPA) is 3.42.


if __name__ == "__main__":
    main()
scripts/vllm_infer.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import json
+from typing import Optional

import fire
from transformers import Seq2SeqTrainingArguments
...
...
@@ -45,14 +46,16 @@ def vllm_infer(
    top_k: int = 50,
    max_new_tokens: int = 1024,
    repetition_penalty: float = 1.0,
+    seed: Optional[int] = None,
    pipeline_parallel_size: int = 1,
-    image_resolution: int = 512 * 512,
+    image_max_pixels: int = 768 * 768,
+    image_min_pixels: int = 32 * 32,
):
    r"""
    Performs batch generation using vLLM engine, which supports tensor parallelism.
    Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
    """
-    check_version("vllm>=0.4.3,<=0.6.5")
+    check_version("vllm>=0.4.3,<=0.7.3")
    if pipeline_parallel_size > get_device_count():
        raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")
...
...
@@ -86,7 +89,9 @@ def vllm_infer(
    for sample in dataset_module["train_dataset"]:
        if sample["images"]:
            multi_modal_data = {
-                "image": template_obj.mm_plugin._regularize_images(sample["images"], image_resolution=image_resolution)
+                "image": template_obj.mm_plugin._regularize_images(
+                    sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )
            }
        else:
            multi_modal_data = None
...
...
@@ -105,6 +110,7 @@ def vllm_infer(
        stop_token_ids=template_obj.get_stop_token_ids(tokenizer),
        max_tokens=generating_args.max_new_tokens,
        skip_special_tokens=False,
+        seed=seed,
    )

    if model_args.adapter_name_or_path is not None:
        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
...
...
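For reference, a minimal invocation sketch (not part of the commit) showing how the new pixel-bound flags replace the old --image_resolution cap. The model, template and dataset values are taken from the usage string above; the flag-to-argument mapping assumes the script's fire-based CLI and a local LLaMA-Factory checkout.

import subprocess

# Hypothetical run of the updated batch-generation script; the pixel bounds shown
# are the new defaults from the signature above (768*768 and 32*32), and --seed
# exercises the newly added optional argument.
subprocess.run(
    [
        "python", "scripts/vllm_infer.py",
        "--model_name_or_path", "meta-llama/Llama-2-7b-hf",
        "--template", "llama",
        "--dataset", "alpaca_en_demo",
        "--image_max_pixels", str(768 * 768),  # replaces the old --image_resolution cap
        "--image_min_pixels", str(32 * 32),
        "--seed", "42",
    ],
    check=True,
)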
setup.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -36,7 +36,7 @@ def get_requires() -> List[str]:
def get_console_scripts() -> List[str]:
    console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
-    if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
+    if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
        console_scripts.append("lmf = llamafactory.cli:main")

    return console_scripts
...
...
@@ -44,9 +44,9 @@ def get_console_scripts() -> List[str]:
extra_require = {
    "torch": ["torch>=1.13.1"],
-    "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
+    "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
    "metrics": ["nltk", "jieba", "rouge-chinese"],
-    "deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
+    "deepspeed": ["deepspeed>=0.10.0,<=0.16.2"],
    "liger-kernel": ["liger-kernel"],
    "bitsandbytes": ["bitsandbytes>=0.39.0"],
    "hqq": ["hqq"],
...
...
@@ -54,7 +54,7 @@ extra_require = {
    "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
    "awq": ["autoawq"],
    "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.6.5"],
+    "vllm": ["vllm>=0.4.3,<=0.7.3"],
    "galore": ["galore-torch"],
    "apollo": ["apollo-torch"],
    "badam": ["badam>=1.2.1"],
...
...
@@ -69,7 +69,6 @@ extra_require = {
        "msgpack",
        "referencing",
        "jsonschema_specifications",
-        "librosa",
    ],
    "modelscope": ["modelscope"],
    "openmind": ["openmind"],
...
...
@@ -92,7 +91,7 @@ def main():
        url="https://github.com/hiyouga/LLaMA-Factory",
        package_dir={"": "src"},
        packages=find_packages("src"),
-        python_requires=">=3.8.0",
+        python_requires=">=3.9.0",
        install_requires=get_requires(),
        extras_require=extra_require,
        entry_points={"console_scripts": get_console_scripts()},
...
...
@@ -104,10 +103,10 @@ def main():
            "License :: OSI Approved :: Apache Software License",
            "Operating System :: OS Independent",
            "Programming Language :: Python :: 3",
-            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Programming Language :: Python :: 3.12",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
        ],
    )
...
...
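As a standalone sketch (an assumption, not code from the repository), the relaxed gate that now controls the short lmf console script can be reproduced like this; the change moves from os.environ.get to os.getenv and additionally accepts "y".

import os

def short_console_enabled() -> bool:
    # mirrors the changed line in get_console_scripts(): "y" is now accepted
    # alongside "true" and "1"; the default remains enabled ("1")
    return os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]

print(short_console_enabled())  # True unless ENABLE_SHORT_CONSOLE is set to e.g. "0"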
src/api.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
src/llamafactory/__init__.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -20,17 +20,17 @@ Level:
Dependency graph:
  main:
-    transformers>=4.41.2,<=4.46.1
-    datasets>=2.16.0,<=3.1.0
-    accelerate>=0.34.0,<=1.0.1
+    transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0
+    datasets>=2.16.0,<=3.2.0
+    accelerate>=0.34.0,<=1.2.1
    peft>=0.11.1,<=0.12.0
    trl>=0.8.6,<=0.9.6
  attention:
    transformers>=4.42.4 (gemma+fa2)
  longlora:
-    transformers>=4.41.2,<=4.46.1
+    transformers>=4.41.2,<4.48.0
  packing:
-    transformers>=4.43.0,<=4.46.1
+    transformers>=4.43.0
Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1
...
...
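If one wants to verify the bumped pins at runtime, a hedged sketch using the same transformers helper that scripts/test_toolcall.py relied on above; the hint strings are only suggestions, not text from the commit.

from transformers.utils.versions import require_version

# Check two of the updated upper bounds from the dependency graph above.
require_version("datasets>=2.16.0,<=3.2.0", "To fix: pip install 'datasets>=2.16.0,<=3.2.0'")
require_version("accelerate>=0.34.0,<=1.2.1", "To fix: pip install 'accelerate>=0.34.0,<=1.2.1'")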
src/llamafactory/api/app.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
src/llamafactory/api/chat.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -22,6 +22,8 @@ from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
from ..data import Role as DataRole
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.misc import is_env_enabled
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
from .protocol import (
...
...
@@ -70,7 +72,8 @@ ROLE_MAPPING = {
def _process_request(
    request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
-    logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
+    if is_env_enabled("API_VERBOSE", "1"):
+        logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")

    if len(request.messages) == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
...
...
@@ -99,10 +102,12 @@ def _process_request(
            content = json.dumps(tool_calls, ensure_ascii=False)
            input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
        elif isinstance(message.content, list):
+            text_content = ""
            for input_item in message.content:
                if input_item.type == "text":
-                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
+                    text_content += input_item.text
                else:
+                    text_content += IMAGE_PLACEHOLDER
                    image_url = input_item.image_url.url
                    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
                        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
...
...
@@ -112,6 +117,8 @@ def _process_request(
                        image_stream = requests.get(image_url, stream=True).raw

                    images.append(Image.open(image_stream).convert("RGB"))

+            input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
        else:
            input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
...
...
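For context, a hedged example (not from the commit) of the OpenAI-style multi-part message the rewritten branch now folds into a single entry: text parts are concatenated into text_content and each image part contributes one IMAGE_PLACEHOLDER, instead of emitting a separate message per text item.

# Illustrative request payload only; the base64 data is elided.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is shown in this picture?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
    ],
}
# After _process_request, this becomes one input message whose content is the text
# followed by the image placeholder token, with the decoded image collected separately.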
src/llamafactory/api/common.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
src/llamafactory/api/protocol.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
src/llamafactory/chat/__init__.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
src/llamafactory/chat/base_engine.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -22,7 +22,7 @@ if TYPE_CHECKING:
    from vllm import AsyncLLMEngine

    from ..data import Template
-    from ..data.mm_plugin import ImageInput, VideoInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
...
...
@@ -68,6 +68,7 @@ class BaseEngine(ABC):
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
...
...
@@ -83,6 +84,7 @@ class BaseEngine(ABC):
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""
...
...
src/llamafactory/chat/chat_model.py
...
...
@@ -27,7 +27,7 @@ from .vllm_engine import VllmEngine
if TYPE_CHECKING:
-    from ..data.mm_plugin import ImageInput, VideoInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from .base_engine import BaseEngine, Response
...
@@ -66,13 +66,14 @@ class ChatModel:
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Gets a list of responses of the chat model.
        """
        task = asyncio.run_coroutine_threadsafe(
-            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
+            self.achat(messages, system, tools, images, videos, audios, **input_kwargs), self._loop
        )
        return task.result()
...
...
@@ -83,12 +84,13 @@ class ChatModel:
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Asynchronously gets a list of responses of the chat model.
        """
-        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)
+        return await self.engine.chat(messages, system, tools, images, videos, audios, **input_kwargs)

    def stream_chat(
        self,
...
...
@@ -97,12 +99,13 @@ class ChatModel:
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
        r"""
        Gets the response token-by-token of the chat model.
        """
-        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
+        generator = self.astream_chat(messages, system, tools, images, videos, audios, **input_kwargs)
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
...
...
@@ -117,12 +120,15 @@ class ChatModel:
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""
        Asynchronously gets the response token-by-token of the chat model.
        """
-        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
+        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, audios, **input_kwargs):
            yield new_token

    def get_scores(
...
...
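A hedged usage sketch of the new parameter: every ChatModel entry point now forwards an audios sequence to the underlying engine alongside images and videos. The model path, template name and audio file below are placeholders, not values from the commit; per the engine diffs, the audio placeholder is prepended automatically when the prompt does not contain one.

from llamafactory.chat import ChatModel

# Assumes a local LLaMA-Factory install and an audio-capable model/template (illustrative names).
chat_model = ChatModel(dict(model_name_or_path="path/to/audio-model", template="qwen2_audio"))
messages = [{"role": "user", "content": "What is said in this clip?"}]
responses = chat_model.chat(messages, audios=["path/to/clip.wav"])
print(responses[0].response_text)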
src/llamafactory/chat/hf_engine.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -24,7 +24,7 @@ from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response
...
...
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
    from trl import PreTrainedModelWrapper

    from ..data import Template
-    from ..data.mm_plugin import ImageInput, VideoInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
...
...
@@ -81,9 +81,10 @@ class HuggingfaceEngine(BaseEngine):
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Tuple[Dict[str, Any], int]:
-        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
+        mm_input_dict = {"images": [], "videos": [], "audios": [], "imglens": [0], "vidlens": [0], "audlens": [0]}
        if images is not None:
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
...
...
@@ -94,14 +95,25 @@ class HuggingfaceEngine(BaseEngine):
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

+        if audios is not None:
+            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
+            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
        messages = template.mm_plugin.process_messages(
-            messages, mm_input_dict["images"], mm_input_dict["videos"], processor
+            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
        prompt_ids, _ = template.mm_plugin.process_token_ids(
-            prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
+            prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], tokenizer, processor,
        )
        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
...
...
@@ -114,6 +126,7 @@ class HuggingfaceEngine(BaseEngine):
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
...
...
@@ -133,6 +146,9 @@ class HuggingfaceEngine(BaseEngine):
                repetition_penalty=repetition_penalty if repetition_penalty is not None else generating_args["repetition_penalty"],
                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
+                skip_special_tokens=skip_special_tokens if skip_special_tokens is not None else generating_args["skip_special_tokens"],
                eos_token_id=template.get_stop_token_ids(tokenizer),
                pad_token_id=tokenizer.pad_token_id,
            )
...
...
@@ -166,9 +182,11 @@ class HuggingfaceEngine(BaseEngine):
        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
        for key, value in mm_inputs.items():
-            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
+            if isinstance(value, list) and isinstance(value[0], torch.Tensor):  # for pixtral inputs
                value = torch.stack(value)  # assume they have same sizes
-            elif isinstance(value, list) and all(isinstance(v, list) for v in value):  # for minicpmv inputs
+            elif (
+                isinstance(value, list) and isinstance(value[0], list) and isinstance(value[0][0], torch.Tensor)
+            ):  # for minicpmv inputs
                value = torch.stack([torch.stack(v) for v in value])
            elif not isinstance(value, torch.Tensor):
                value = torch.tensor(value)
...
...
@@ -176,12 +194,18 @@ class HuggingfaceEngine(BaseEngine):
            if torch.is_floating_point(value):  # cast data dtype for paligemma
                value = value.to(model.dtype)

-            gen_kwargs[key] = value.to(model.device)
+            if key == "second_per_grid_ts":  # qwen2.5vl special case
+                gen_kwargs[key] = value.tolist()
+            else:
+                gen_kwargs[key] = value.to(model.device)

        if getattr(model.config, "model_type", None) in ["minicpmv", "minicpmo"]:
            gen_kwargs["input_ids"] = inputs
-            del gen_kwargs["image_sizes"]
            gen_kwargs["tokenizer"] = tokenizer
+            if "audio_feature_lens" in mm_inputs:
+                gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"]
+
+            gen_kwargs.pop("image_sizes", None)

        return gen_kwargs, prompt_length
...
...
@@ -198,6 +222,7 @@ class HuggingfaceEngine(BaseEngine):
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List["Response"]:
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
...
...
@@ -211,6 +236,7 @@ class HuggingfaceEngine(BaseEngine):
            tools,
            images,
            videos,
+            audios,
            input_kwargs,
        )
        generate_output = model.generate(**gen_kwargs)
...
...
@@ -219,7 +245,9 @@ class HuggingfaceEngine(BaseEngine):
        response_ids = generate_output[:, prompt_length:]
        response = tokenizer.batch_decode(
-            response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
+            response_ids,
+            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
+            clean_up_tokenization_spaces=True,
        )
        results = []
        for i in range(len(response)):
...
...
@@ -249,6 +277,7 @@ class HuggingfaceEngine(BaseEngine):
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Callable[[], str]:
        gen_kwargs, _ = HuggingfaceEngine._process_args(
...
...
@@ -262,10 +291,13 @@ class HuggingfaceEngine(BaseEngine):
            tools,
            images,
            videos,
+            audios,
            input_kwargs,
        )
        streamer = TextIteratorStreamer(
-            tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
+            tokenizer,
+            skip_prompt=True,
+            skip_special_tokens=getattr(gen_kwargs["generation_config"], "skip_special_tokens", True),
        )
        gen_kwargs["streamer"] = streamer
        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
...
...
@@ -309,6 +341,7 @@ class HuggingfaceEngine(BaseEngine):
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        if not self.can_generate:
...
...
@@ -326,6 +359,7 @@ class HuggingfaceEngine(BaseEngine):
            tools,
            images,
            videos,
+            audios,
            input_kwargs,
        )
        async with self.semaphore:
...
...
@@ -340,6 +374,7 @@ class HuggingfaceEngine(BaseEngine):
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
...
...
@@ -357,6 +392,7 @@ class HuggingfaceEngine(BaseEngine):
            tools,
            images,
            videos,
+            audios,
            input_kwargs,
        )
        async with self.semaphore:
...
...
src/llamafactory/chat/vllm_engine.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -19,27 +19,22 @@ from typing_extensions import override
from ..data import get_template_and_fix_tokenizer
from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import get_device_count
-from ..extras.packages import is_pillow_available, is_vllm_available
+from ..extras.packages import is_vllm_available
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response

-if is_pillow_available():
-    from PIL import Image
-    from PIL.Image import Image as ImageObject

if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
    from vllm.lora.request import LoRARequest

if TYPE_CHECKING:
-    from ..data.mm_plugin import ImageInput, VideoInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
...
...
@@ -54,6 +49,7 @@ class VllmEngine(BaseEngine):
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
+        self.model_args = model_args
        config = load_config(model_args)  # may download model from ms hub
        if getattr(config, "quantization_config", None):  # gptq models should use float16
            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
...
@@ -109,10 +105,11 @@ class VllmEngine(BaseEngine):
tools
:
Optional
[
str
]
=
None
,
images
:
Optional
[
Sequence
[
"ImageInput"
]]
=
None
,
videos
:
Optional
[
Sequence
[
"VideoInput"
]]
=
None
,
audios
:
Optional
[
Sequence
[
"AudioInput"
]]
=
None
,
**
input_kwargs
,
)
->
AsyncIterator
[
"RequestOutput"
]:
request_id
=
f
"chatcmpl-
{
uuid
.
uuid4
().
hex
}
"
mm_input_dict
=
{
"images"
:
[],
"videos"
:
[],
"imglens"
:
[
0
],
"vidlens"
:
[
0
]}
mm_input_dict
=
{
"images"
:
[],
"videos"
:
[],
"audios"
:
[],
"imglens"
:
[
0
],
"vidlens"
:
[
0
],
"audlens"
:
[
0
]}
if
images
is
not
None
:
mm_input_dict
.
update
({
"images"
:
images
,
"imglens"
:
[
len
(
images
)]})
if
not
any
(
IMAGE_PLACEHOLDER
in
message
[
"content"
]
for
message
in
messages
):
...
...
@@ -123,8 +120,13 @@ class VllmEngine(BaseEngine):
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

+        if audios is not None:
+            mm_input_dict.update({"audios": audios, "audlens": [len(audios)]})
+            if not any(AUDIO_PLACEHOLDER in message["content"] for message in messages):
+                messages[0]["content"] = AUDIO_PLACEHOLDER * len(audios) + messages[0]["content"]
+
        messages = self.template.mm_plugin.process_messages(
-            messages, mm_input_dict["images"], mm_input_dict["videos"], self.processor
+            messages, mm_input_dict["images"], mm_input_dict["videos"], mm_input_dict["audios"], self.processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
...
...
@@ -137,6 +139,7 @@ class VllmEngine(BaseEngine):
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        skip_special_tokens: Optional[bool] = input_kwargs.pop("skip_special_tokens", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
...
...
@@ -170,19 +173,19 @@ class VllmEngine(BaseEngine):
            stop=stop,
            stop_token_ids=self.template.get_stop_token_ids(self.tokenizer),
            max_tokens=max_tokens,
-            skip_special_tokens=self.generating_args["skip_special_tokens"],
+            skip_special_tokens=skip_special_tokens if skip_special_tokens is not None else self.generating_args["skip_special_tokens"],
        )

        if images is not None:  # add image features
-            multi_modal_data = {"image": []}
-            for image in images:
-                if not isinstance(image, (str, ImageObject)):
-                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")
-
-                if isinstance(image, str):
-                    image = Image.open(image).convert("RGB")
-
-                multi_modal_data["image"].append(image)
+            multi_modal_data = {
+                "image": self.template.mm_plugin._regularize_images(
+                    images,
+                    image_max_pixels=self.model_args.image_max_pixels,
+                    image_min_pixels=self.model_args.image_min_pixels,
+                )
+            }
        else:
            multi_modal_data = None
...
...
@@ -202,10 +205,11 @@ class VllmEngine(BaseEngine):
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        final_output = None
-        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        async for request_output in generator:
            final_output = request_output
...
...
@@ -230,10 +234,11 @@ class VllmEngine(BaseEngine):
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
+        audios: Optional[Sequence["AudioInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
-        generator = await self._generate(messages, system, tools, images, videos, **input_kwargs)
+        generator = await self._generate(messages, system, tools, images, videos, audios, **input_kwargs)
        async for result in generator:
            delta_text = result.outputs[0].text[len(generated_text):]
            generated_text = result.outputs[0].text
...
...
src/llamafactory/cli.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -24,7 +24,7 @@ from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
-from .extras.misc import get_device_count, use_ray
+from .extras.misc import get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
...
...
@@ -86,7 +86,7 @@ def main():
    elif command == Command.EXPORT:
        export_model()
    elif command == Command.TRAIN:
-        force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
+        force_torchrun = is_env_enabled("FORCE_TORCHRUN")
        if force_torchrun or (get_device_count() > 1 and not use_ray()):
            master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
            master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
...
...
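The is_env_enabled helper now used here lives in extras.misc and is not shown in this diff. As a hedged sketch (not the repository's implementation), it presumably generalizes the inline check it replaces:

import os

def is_env_enabled(env_var: str, default: str = "0") -> bool:
    # accepts the usual truthy spellings, matching the old inline test
    # os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
    return os.getenv(env_var, default).lower() in ["true", "y", "1"]

force_torchrun = is_env_enabled("FORCE_TORCHRUN")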
src/llamafactory/data/__init__.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
src/llamafactory/data/aligner.py (deleted, 100644 → 0)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

from ..extras import logging
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .mm_plugin import ImageInput, VideoInput
    from .parser import DatasetAttr


logger = logging.get_logger(__name__)


def _convert_images(
    images: Union["ImageInput", Sequence["ImageInput"]],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Optional[List["ImageInput"]]:
    r"""
    Optionally concatenates image path to dataset dir when loading from local disk.
    """
    if not isinstance(images, list):
        images = [images]
    elif len(images) == 0:
        return None
    else:
        images = images[:]

    if dataset_attr.load_from in ["script", "file"]:
        for i in range(len(images)):
            if isinstance(images[i], str) and os.path.isfile(os.path.join(data_args.image_dir, images[i])):
                images[i] = os.path.join(data_args.image_dir, images[i])

    return images


def _convert_videos(
    videos: Union["VideoInput", Sequence["VideoInput"]],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Optional[List["VideoInput"]]:
    r"""
    Optionally concatenates video path to dataset dir when loading from local disk.
    """
    if not isinstance(videos, list):
        videos = [videos]
    elif len(videos) == 0:
        return None
    else:
        videos = videos[:]

    if dataset_attr.load_from in ["script", "file"]:
        for i in range(len(videos)):
            if isinstance(videos[i], str) and os.path.isfile(os.path.join(data_args.image_dir, videos[i])):
                videos[i] = os.path.join(data_args.image_dir, videos[i])

    return videos


def convert_alpaca(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts alpaca format dataset to the standard format.
    """
    prompt = []
    if dataset_attr.history and isinstance(example[dataset_attr.history], list):
        for old_prompt, old_response in example[dataset_attr.history]:
            prompt.append({"role": Role.USER.value, "content": old_prompt})
            prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

    query = []
    if dataset_attr.prompt and example[dataset_attr.prompt]:
        query.append(example[dataset_attr.prompt])

    if dataset_attr.query and example[dataset_attr.query]:
        query.append(example[dataset_attr.query])

    prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"

    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
        if example[dataset_attr.kto_tag]:
            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
        else:
            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], str)
        and isinstance(example[dataset_attr.rejected], str)
    ):  # pairwise example
        response = [
            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.chosen]},
            {"role": Role.ASSISTANT.value, "content": example[dataset_attr.rejected]},
        ]
    elif dataset_attr.response and isinstance(example[dataset_attr.response], str):  # normal example
        response = [{"role": Role.ASSISTANT.value, "content": example[dataset_attr.response]}]
    else:  # unsupervised
        response = []

    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
    output = {
        "_prompt": prompt,
        "_response": response,
        "_system": example[dataset_attr.system] if dataset_attr.system else "",
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
    }
    return output


def convert_sharegpt(
    example: Dict[str, Any],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
) -> Dict[str, Any]:
    r"""
    Converts sharegpt format dataset to the standard format.
    """
    tag_mapping = {
        dataset_attr.user_tag: Role.USER.value,
        dataset_attr.assistant_tag: Role.ASSISTANT.value,
        dataset_attr.observation_tag: Role.OBSERVATION.value,
        dataset_attr.function_tag: Role.FUNCTION.value,
        dataset_attr.system_tag: Role.SYSTEM.value,
    }
    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
    accept_tags = (odd_tags, even_tags)
    messages = example[dataset_attr.messages]
    if (
        dataset_attr.system_tag
        and len(messages) != 0
        and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag
    ):
        system = messages[0][dataset_attr.content_tag]
        messages = messages[1:]
    else:
        system = example[dataset_attr.system] if dataset_attr.system else ""

    aligned_messages = []
    broken_data = False
    for turn_idx, message in enumerate(messages):
        if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
            logger.warning_rank0(f"Invalid role tag in {messages}.")
            broken_data = True

        aligned_messages.append(
            {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
        )

    if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
        dataset_attr.ranking and len(aligned_messages) % 2 == 0
    ):
        logger.warning_rank0(f"Invalid message count in {messages}.")
        broken_data = True

    if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]
        if example[dataset_attr.kto_tag]:
            response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
        else:
            response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
    elif (
        dataset_attr.ranking
        and isinstance(example[dataset_attr.chosen], dict)
        and isinstance(example[dataset_attr.rejected], dict)
    ):  # pairwise example
        chosen = example[dataset_attr.chosen]
        rejected = example[dataset_attr.rejected]
        if (
            chosen[dataset_attr.role_tag] not in accept_tags[-1]
            or rejected[dataset_attr.role_tag] not in accept_tags[-1]
        ):
            logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
            broken_data = True

        prompt = aligned_messages
        response = [
            {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
            {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
        ]
    else:  # normal example
        prompt = aligned_messages[:-1]
        response = aligned_messages[-1:]

    if broken_data:
        logger.warning_rank0("Skipping this abnormal example.")
        prompt, response = [], []

    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
    convert_videos = partial(_convert_videos, dataset_attr=dataset_attr, data_args=data_args)
    output = {
        "_prompt": prompt,
        "_response": response,
        "_system": system,
        "_tools": example[dataset_attr.tools] if dataset_attr.tools else "",
        "_images": convert_images(example[dataset_attr.images]) if dataset_attr.images else None,
        "_videos": convert_videos(example[dataset_attr.videos]) if dataset_attr.videos else None,
    }
    return output


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        _system: "..."
        _tools: "...",
        _images: [],
        _videos: [],
    """
    if dataset_attr.formatting == "alpaca":
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
    else:
        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)

    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Converting format of dataset",
        )

    return dataset.map(
        convert_func,
        batched=False,
        remove_columns=column_names,
        **kwargs,
    )
src/llamafactory/data/collator.py
...
...
@@ -18,11 +18,12 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

+import numpy as np
import torch
import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq

-from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.packages import is_pillow_available
...
...
@@ -80,7 +81,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""
    Data collator that supports VLMs.

-    Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
+    Features should contain input_ids, attention_mask, labels, and optionally contain images, videos and audios.
    """

    template: Optional["Template"] = None
...
...
@@ -91,26 +92,54 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
            raise ValueError("Template is required for MultiModalDataCollator.")

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
-        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
+        batch_images, batch_videos, batch_audios = [], [], []
+        batch_imglens, batch_vidlens, batch_audlens, batch_input_ids = [], [], [], []
        for feature in features:
            images = feature.pop("images", None) or []
            videos = feature.pop("videos", None) or []
+            audios = feature.pop("audios", None) or []
            batch_images.extend(images)
            batch_videos.extend(videos)
+            batch_audios.extend(audios)
            batch_imglens.append(len(images))
            batch_vidlens.append(len(videos))
+            batch_audlens.append(len(audios))
            batch_input_ids.append(feature["input_ids"])

+        fake_input_ids = []
        if (
-            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+            self.template.mm_plugin.image_token is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
        ):  # avoid process hanging in zero3/fsdp case
            fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
-            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
-            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
-            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
-                fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
-            )
+            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], [], self.processor)
+            _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+                _fake_input_ids, None, fake_images, [], [], self.tokenizer, self.processor
+            )
+            fake_input_ids.extend(_fake_input_ids)
+            batch_images = fake_images
+            batch_imglens[0] = 1
+
+        if (
+            self.template.mm_plugin.audio_token is not None and sum(batch_audlens) == 0
+        ):  # avoid process hanging in zero3/fsdp case
+            fake_messages = [{"role": "user", "content": AUDIO_PLACEHOLDER}]
+            fake_audios = [np.zeros(1600)]
+            fake_messages = self.template.mm_plugin.process_messages(fake_messages, [], [], fake_audios, self.processor)
+            _fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            _fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+                _fake_input_ids, None, [], [], fake_audios, self.tokenizer, self.processor
+            )
+            fake_input_ids.extend(_fake_input_ids)
+            batch_audios = fake_audios
+            batch_audlens[0] = 1
+
+        if len(fake_input_ids) != 0:
            if self.tokenizer.padding_side == "right":
                features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
                features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
...
...
@@ -120,12 +149,17 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
                features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]

-            batch_images = fake_images
-            batch_imglens[0] = 1
            batch_input_ids[0] = features[0]["input_ids"]

        mm_inputs = self.template.mm_plugin.get_mm_inputs(
-            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
+            batch_images, batch_videos, batch_audios, batch_imglens, batch_vidlens, batch_audlens, batch_input_ids, self.processor,
        )
        if "token_type_ids" in mm_inputs:
            token_type_ids = mm_inputs.pop("token_type_ids")
...
...
@@ -135,12 +169,16 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
        features: Dict[str, "torch.Tensor"] = super().__call__(features)

        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
-            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
-                input_ids=features["input_ids"],
-                image_grid_thw=mm_inputs.get("image_grid_thw", None),
-                video_grid_thw=mm_inputs.get("video_grid_thw", None),
-                attention_mask=features["attention_mask"],
-            )
+            rope_index_kwargs = {
+                "input_ids": features["input_ids"],
+                "image_grid_thw": mm_inputs.get("image_grid_thw"),
+                "video_grid_thw": mm_inputs.get("video_grid_thw"),
+                "attention_mask": features["attention_mask"],
+            }
+            if "second_per_grid_ts" in mm_inputs:
+                rope_index_kwargs["second_per_grid_ts"] = mm_inputs.get("second_per_grid_ts")
+
+            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)

        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
...
...
@@ -149,8 +187,6 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))

        features.update(mm_inputs)
-        if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
-            features = features.data  # use default_collate() instead of BatchEncoding.to()
-
        if "image_bound" in features:  # for minicpmv inputs
            bsz, seq_length = features["input_ids"].shape
...
...
@@ -204,6 +240,7 @@ class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                "labels": feature[f"{key}_labels"],
                "images": feature["images"],
                "videos": feature["videos"],
+                "audios": feature["audios"],
            }
            concatenated_features.append(target_feature)
...
...
@@ -227,6 +264,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
                "labels": feature["labels"],
                "images": feature["images"],
                "videos": feature["videos"],
+                "audios": feature["audios"],
            }
            kl_feature = {
                "input_ids": feature["kl_input_ids"],
...
@@ -234,6 +272,7 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
"labels"
:
feature
[
"kl_labels"
],
"images"
:
feature
[
"images"
],
"videos"
:
feature
[
"videos"
],
"audios"
:
feature
[
"audios"
],
}
target_features
.
append
(
target_feature
)
kl_features
.
append
(
kl_feature
)
...
...
@@ -244,6 +283,8 @@ class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
+        if "cross_attention_mask" in kl_batch:  # for mllama inputs.
+            batch["kl_cross_attention_mask"] = kl_batch["cross_attention_mask"]
+
        if "token_type_ids" in kl_batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]
...
...
src/llamafactory/data/converter.py (new file, 0 → 100644)
# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, Union

from ..extras import logging
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .parser import DatasetAttr


logger = logging.get_logger(__name__)


@dataclass
class DatasetConverter:
    dataset_attr: "DatasetAttr"
    data_args: "DataArguments"

    def _find_medias(self, medias: Union[Any, Sequence[Any]]) -> Optional[List[Any]]:
        r"""
        Optionally concatenates media path to media dir when loading from local disk.
        """
        if not isinstance(medias, list):
            medias = [medias] if medias is not None else []
        elif len(medias) == 0:
            return None
        else:
            medias = medias[:]

        if self.dataset_attr.load_from in ["script", "file"] and isinstance(medias[0], str):
            for i in range(len(medias)):
                if os.path.isfile(os.path.join(self.data_args.media_dir, medias[i])):
                    medias[i] = os.path.join(self.data_args.media_dir, medias[i])
                else:
                    logger.warning_rank0_once(f"Media {medias[i]} does not exist in `media_dir`. Use original path.")

        return medias

    @abstractmethod
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        r"""
        Converts a single example in the dataset to the standard format.
        """
        ...


@dataclass
class AlpacaDatasetConverter(DatasetConverter):
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        prompt = []
        if self.dataset_attr.history and isinstance(example[self.dataset_attr.history], list):
            for old_prompt, old_response in example[self.dataset_attr.history]:
                prompt.append({"role": Role.USER.value, "content": old_prompt})
                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

        query = []
        if self.dataset_attr.prompt and example[self.dataset_attr.prompt]:
            query.append(example[self.dataset_attr.prompt])

        if self.dataset_attr.query and example[self.dataset_attr.query]:
            query.append(example[self.dataset_attr.query])

        prompt.append({"role": Role.USER.value, "content": "\n".join(query)})  # "prompt\nquery"

        if self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
            response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
            if example[self.dataset_attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            self.dataset_attr.ranking
            and isinstance(example[self.dataset_attr.chosen], str)
            and isinstance(example[self.dataset_attr.rejected], str)
        ):  # pairwise example
            response = [
                {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.chosen]},
                {"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.rejected]},
            ]
        elif self.dataset_attr.response and isinstance(example[self.dataset_attr.response], str):  # normal example
            response = [{"role": Role.ASSISTANT.value, "content": example[self.dataset_attr.response]}]
        else:  # unsupervised
            response = []

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": example[self.dataset_attr.system] if self.dataset_attr.system else "",
            "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
        }
        return output


@dataclass
class SharegptDatasetConverter(DatasetConverter):
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        tag_mapping = {
            self.dataset_attr.user_tag: Role.USER.value,
            self.dataset_attr.assistant_tag: Role.ASSISTANT.value,
            self.dataset_attr.observation_tag: Role.OBSERVATION.value,
            self.dataset_attr.function_tag: Role.FUNCTION.value,
            self.dataset_attr.system_tag: Role.SYSTEM.value,
        }
        odd_tags = (self.dataset_attr.user_tag, self.dataset_attr.observation_tag)
        even_tags = (self.dataset_attr.assistant_tag, self.dataset_attr.function_tag)
        accept_tags = (odd_tags, even_tags)
        messages = example[self.dataset_attr.messages]
        if (
            self.dataset_attr.system_tag
            and len(messages) != 0
            and messages[0][self.dataset_attr.role_tag] == self.dataset_attr.system_tag
        ):
            system = messages[0][self.dataset_attr.content_tag]
            messages = messages[1:]
        else:
            system = example[self.dataset_attr.system] if self.dataset_attr.system else ""

        aligned_messages = []
        broken_data = False
        for turn_idx, message in enumerate(messages):
            if message[self.dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
                logger.warning_rank0(f"Invalid role tag in {messages}.")
                broken_data = True
                break

            aligned_messages.append(
                {
                    "role": tag_mapping[message[self.dataset_attr.role_tag]],
                    "content": message[self.dataset_attr.content_tag],
                }
            )

        if (not self.dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
            self.dataset_attr.ranking and len(aligned_messages) % 2 == 0
        ):
            logger.warning_rank0(f"Invalid message count in {messages}.")
            broken_data = True

        if broken_data:
            logger.warning_rank0("Skipping this abnormal example.")
            prompt, response = [], []
        elif self.dataset_attr.kto_tag and isinstance(example[self.dataset_attr.kto_tag], bool):  # kto example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]
            if example[self.dataset_attr.kto_tag]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            self.dataset_attr.ranking
            and isinstance(example[self.dataset_attr.chosen], dict)
            and isinstance(example[self.dataset_attr.rejected], dict)
        ):  # pairwise example
            chosen = example[self.dataset_attr.chosen]
            rejected = example[self.dataset_attr.rejected]
            if (
                chosen[self.dataset_attr.role_tag] not in accept_tags[-1]
                or rejected[self.dataset_attr.role_tag] not in accept_tags[-1]
            ):
                logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
                broken_data = True

            prompt = aligned_messages
            response = [
                {
                    "role": tag_mapping[chosen[self.dataset_attr.role_tag]],
                    "content": chosen[self.dataset_attr.content_tag],
                },
                {
                    "role": tag_mapping[rejected[self.dataset_attr.role_tag]],
                    "content": rejected[self.dataset_attr.content_tag],
                },
            ]
        else:  # normal example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]

        output = {
            "_prompt": prompt,
            "_response": response,
            "_system": system,
            "_tools": example[self.dataset_attr.tools] if self.dataset_attr.tools else "",
            "_images": self._find_medias(example[self.dataset_attr.images]) if self.dataset_attr.images else None,
            "_videos": self._find_medias(example[self.dataset_attr.videos]) if self.dataset_attr.videos else None,
            "_audios": self._find_medias(example[self.dataset_attr.audios]) if self.dataset_attr.audios else None,
        }
        return output


DATASET_CONVERTERS = {
    "alpaca": AlpacaDatasetConverter,
    "sharegpt": SharegptDatasetConverter,
}


def register_dataset_converter(name: str, dataset_converter: Type["DatasetConverter"]) -> None:
    r"""
    Register a new dataset converter.
    """
    if name in DATASET_CONVERTERS:
        raise ValueError(f"Dataset converter {name} already exists.")

    DATASET_CONVERTERS[name] = dataset_converter


def get_dataset_converter(name: str, dataset_attr: "DatasetAttr", data_args: "DataArguments") -> "DatasetConverter":
    r"""
    Gets a dataset converter.
    """
    if name not in DATASET_CONVERTERS:
        raise ValueError(f"Dataset converter {name} not found.")

    return DATASET_CONVERTERS[name](dataset_attr, data_args)


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
        _prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        _response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        _system: "..."
        _tools: "...",
        _images: [],
        _videos: [],
        _audios: [],
    """
    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Converting format of dataset",
        )

    dataset_converter = get_dataset_converter(dataset_attr.formatting, dataset_attr, data_args)
    return dataset.map(
        dataset_converter,
        batched=False,
        remove_columns=column_names,
        **kwargs,
    )
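The new module also makes the converter table extensible. Below is a hedged sketch of registering a custom converter through the functions defined above; the format name and the question/answer field names are illustrative only, and the import path assumes the file location src/llamafactory/data/converter.py.

from dataclasses import dataclass
from typing import Any, Dict

from llamafactory.data.converter import DatasetConverter, register_dataset_converter
from llamafactory.data.data_utils import Role


@dataclass
class QADatasetConverter(DatasetConverter):
    # maps a hypothetical {"question": ..., "answer": ...} record to the standard format
    def __call__(self, example: Dict[str, Any]) -> Dict[str, Any]:
        return {
            "_prompt": [{"role": Role.USER.value, "content": example["question"]}],
            "_response": [{"role": Role.ASSISTANT.value, "content": example["answer"]}],
            "_system": "",
            "_tools": "",
            "_images": None,
            "_videos": None,
            "_audios": None,
        }


register_dataset_converter("qa_pairs", QADatasetConverter)
# align_dataset() can then resolve it via get_dataset_converter("qa_pairs", dataset_attr, data_args).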
src/llamafactory/data/data_utils.py
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...