ModelZoo / Qwen2_pytorch / Commits / 032b90a1

Commit 032b90a1, authored Sep 12, 2024 by luopl

init commit

Pipeline #1684 canceled with stages
Changes: 233 · Pipelines: 1

Showing 20 changed files with 2671 additions and 0 deletions (+2671, -0)
LLaMA-Factory/src/llamafactory/api/common.py                  +34   -0
LLaMA-Factory/src/llamafactory/api/protocol.py                +153  -0
LLaMA-Factory/src/llamafactory/chat/__init__.py               +19   -0
LLaMA-Factory/src/llamafactory/chat/base_engine.py            +78   -0
LLaMA-Factory/src/llamafactory/chat/chat_model.py             +155  -0
LLaMA-Factory/src/llamafactory/chat/hf_engine.py              +343  -0
LLaMA-Factory/src/llamafactory/chat/vllm_engine.py            +242  -0
LLaMA-Factory/src/llamafactory/cli.py                         +121  -0
LLaMA-Factory/src/llamafactory/data/__init__.py               +31   -0
LLaMA-Factory/src/llamafactory/data/aligner.py                +239  -0
LLaMA-Factory/src/llamafactory/data/collator.py               +155  -0
LLaMA-Factory/src/llamafactory/data/data_utils.py             +87   -0
LLaMA-Factory/src/llamafactory/data/formatter.py              +140  -0
LLaMA-Factory/src/llamafactory/data/loader.py                 +276  -0
LLaMA-Factory/src/llamafactory/data/parser.py                 +153  -0
LLaMA-Factory/src/llamafactory/data/preprocess.py             +110  -0
LLaMA-Factory/src/llamafactory/data/processors/__init__.py    +0    -0
LLaMA-Factory/src/llamafactory/data/processors/feedback.py    +143  -0
LLaMA-Factory/src/llamafactory/data/processors/pairwise.py    +138  -0
LLaMA-Factory/src/llamafactory/data/processors/pretrain.py    +54   -0

LLaMA-Factory/src/llamafactory/api/common.py  (new file, 0 → 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import TYPE_CHECKING, Any, Dict


if TYPE_CHECKING:
    from pydantic import BaseModel


def dictify(data: "BaseModel") -> Dict[str, Any]:
    try:  # pydantic v2
        return data.model_dump(exclude_unset=True)
    except AttributeError:  # pydantic v1
        return data.dict(exclude_unset=True)


def jsonify(data: "BaseModel") -> str:
    try:  # pydantic v2
        return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
    except AttributeError:  # pydantic v1
        return data.json(exclude_unset=True, ensure_ascii=False)
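
Not part of the commit: a minimal usage sketch of `dictify` and `jsonify`, assuming pydantic (v1 or v2) is installed and the package is importable as `llamafactory`. The `Greeting` model is hypothetical, for illustration only.

import sys  # only used to keep the sketch self-contained
from pydantic import BaseModel

from llamafactory.api.common import dictify, jsonify


class Greeting(BaseModel):  # hypothetical model for illustration
    text: str
    lang: str = "en"


msg = Greeting(text="hello")
print(dictify(msg))   # {'text': 'hello'} -- fields left unset are excluded
print(jsonify(msg))   # '{"text": "hello"}' -- non-ASCII is preserved (ensure_ascii=False)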

LLaMA-Factory/src/llamafactory/api/protocol.py  (new file, 0 → 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from enum import Enum, unique
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field
from typing_extensions import Literal


@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    TOOL = "tool"


@unique
class Finish(str, Enum):
    STOP = "stop"
    LENGTH = "length"
    TOOL = "tool_calls"


class ModelCard(BaseModel):
    id: str
    object: Literal["model"] = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: Literal["owner"] = "owner"


class ModelList(BaseModel):
    object: Literal["list"] = "list"
    data: List[ModelCard] = []


class Function(BaseModel):
    name: str
    arguments: str


class FunctionDefinition(BaseModel):
    name: str
    description: str
    parameters: Dict[str, Any]


class FunctionAvailable(BaseModel):
    type: Literal["function", "code_interpreter"] = "function"
    function: Optional[FunctionDefinition] = None


class FunctionCall(BaseModel):
    id: str
    type: Literal["function"] = "function"
    function: Function


class ImageURL(BaseModel):
    url: str


class MultimodalInputItem(BaseModel):
    type: Literal["text", "image_url"]
    text: Optional[str] = None
    image_url: Optional[ImageURL] = None


class ChatMessage(BaseModel):
    role: Role
    content: Optional[Union[str, List[MultimodalInputItem]]] = None
    tool_calls: Optional[List[FunctionCall]] = None


class ChatCompletionMessage(BaseModel):
    role: Optional[Role] = None
    content: Optional[str] = None
    tool_calls: Optional[List[FunctionCall]] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    tools: Optional[List[FunctionAvailable]] = None
    do_sample: Optional[bool] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: int = 1
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: bool = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatCompletionMessage
    finish_reason: Finish


class ChatCompletionStreamResponseChoice(BaseModel):
    index: int
    delta: ChatCompletionMessage
    finish_reason: Optional[Finish] = None


class ChatCompletionResponseUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: ChatCompletionResponseUsage


class ChatCompletionStreamResponse(BaseModel):
    id: str
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionStreamResponseChoice]


class ScoreEvaluationRequest(BaseModel):
    model: str
    messages: List[str]
    max_length: Optional[int] = None


class ScoreEvaluationResponse(BaseModel):
    id: str
    object: Literal["score.evaluation"] = "score.evaluation"
    model: str
    scores: List[float]
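
Not part of the commit: a short sketch of how these models compose into an OpenAI-style request body, assuming the package is importable as `llamafactory`. The model name is a placeholder.

from llamafactory.api.common import jsonify
from llamafactory.api.protocol import ChatCompletionRequest, ChatMessage, Role

# Build a non-streaming chat completion request and serialize it.
request = ChatCompletionRequest(
    model="qwen2-7b-instruct",  # placeholder model name
    messages=[
        ChatMessage(role=Role.SYSTEM, content="You are a helpful assistant."),
        ChatMessage(role=Role.USER, content="Hello!"),
    ],
    temperature=0.7,
)
print(jsonify(request))  # fields left unset (tools, stop, ...) are omitted from the JSON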

LLaMA-Factory/src/llamafactory/chat/__init__.py  (new file, 0 → 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_engine import BaseEngine
from .chat_model import ChatModel


__all__ = ["BaseEngine", "ChatModel"]

LLaMA-Factory/src/llamafactory/chat/base_engine.py  (new file, 0 → 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union


if TYPE_CHECKING:
    from numpy.typing import NDArray
    from transformers import PreTrainedModel, PreTrainedTokenizer
    from vllm import AsyncLLMEngine

    from ..data import Template
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


@dataclass
class Response:
    response_text: str
    response_length: int
    prompt_length: int
    finish_reason: Literal["stop", "length"]


class BaseEngine(ABC):
    model: Union["PreTrainedModel", "AsyncLLMEngine"]
    tokenizer: "PreTrainedTokenizer"
    can_generate: bool
    template: "Template"
    generating_args: Dict[str, Any]

    @abstractmethod
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None: ...

    @abstractmethod
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]: ...

    @abstractmethod
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]: ...

    @abstractmethod
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]: ...

LLaMA-Factory/src/llamafactory/chat/chat_model.py  (new file, 0 → 100644)
# Copyright 2024 THUDM and the LlamaFactory team.
#
# This code is inspired by the THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine


if TYPE_CHECKING:
    from numpy.typing import NDArray

    from .base_engine import BaseEngine, Response


def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()


class ChatModel:
    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
        if model_args.infer_backend == "huggingface":
            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif model_args.infer_backend == "vllm":
            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
        else:
            raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))

        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        task = asyncio.run_coroutine_threadsafe(self.achat(messages, system, tools, image, **input_kwargs), self._loop)
        return task.result()

    async def achat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        return await self.engine.chat(messages, system, tools, image, **input_kwargs)

    def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
        generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
                yield task.result()
            except StopAsyncIteration:
                break

    async def astream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        async for new_token in self.engine.stream_chat(messages, system, tools, image, **input_kwargs):
            yield new_token

    def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        return await self.engine.get_scores(batch_input, **input_kwargs)


def run_chat() -> None:
    if os.name != "nt":
        try:
            import readline  # noqa: F401
        except ImportError:
            print("Install `readline` for a better experience.")

    chat_model = ChatModel()
    messages = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:
            raise

        if query.strip() == "exit":
            break

        if query.strip() == "clear":
            messages = []
            torch_gc()
            print("History has been removed.")
            continue

        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        response = ""
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            response += new_text
        print()
        messages.append({"role": "assistant", "content": response})
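
Not part of the commit: a hedged sketch of driving `ChatModel` from synchronous code. The checkpoint and template names are placeholders; the accepted keys are ultimately defined by `get_infer_args`, which is not shown in this diff.

from llamafactory.chat import ChatModel

# The args dict is parsed by get_infer_args(); the keys below are illustrative.
chat_model = ChatModel(dict(
    model_name_or_path="Qwen/Qwen2-7B-Instruct",  # placeholder checkpoint
    template="qwen",
    infer_backend="huggingface",
))

messages = [{"role": "user", "content": "Briefly introduce yourself."}]
for new_text in chat_model.stream_chat(messages):  # tokens arrive incrementally
    print(new_text, end="", flush=True)
print()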

LLaMA-Factory/src/llamafactory/chat/hf_engine.py  (new file, 0 → 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from transformers import GenerationConfig, TextIteratorStreamer

from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from numpy.typing import NDArray
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
    from transformers.image_processing_utils import BaseImageProcessor
    from trl import PreTrainedModelWrapper

    from ..data import Template
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = get_logger(__name__)


class HuggingfaceEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )  # must after fixing tokenizer to resize vocab
        self.generating_args = generating_args.to_dict()
        try:
            asyncio.get_event_loop()
        except RuntimeError:
            logger.warning("There is no current event loop, creating a new one.")
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))

    @staticmethod
    def _process_args(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Tuple[Dict[str, Any], int]:
        if (
            processor is not None
            and image is not None
            and not hasattr(processor, "image_seq_length")
            and template.image_token not in messages[0]["content"]
        ):  # llava-like models
            messages[0]["content"] = template.image_token + messages[0]["content"]

        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
        pixel_values = None
        prompt_ids, _ = template.encode_oneturn(
            tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
        )
        if processor is not None and image is not None:  # add image features
            image_processor: "BaseImageProcessor" = getattr(processor, "image_processor")
            batch_feature = image_processor(image, return_tensors="pt")
            pixel_values = batch_feature.to(model.device)["pixel_values"]  # shape (B, C, H, W)
            if hasattr(processor, "image_seq_length"):  # paligemma models
                image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
                prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.bool)

        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

        if stop is not None:
            logger.warning("Stop parameter is not supported by the huggingface engine yet.")

        generating_args = generating_args.copy()
        generating_args.update(
            dict(
                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
                temperature=temperature if temperature is not None else generating_args["temperature"],
                top_p=top_p if top_p is not None else generating_args["top_p"],
                top_k=top_k if top_k is not None else generating_args["top_k"],
                num_return_sequences=num_return_sequences,
                repetition_penalty=repetition_penalty
                if repetition_penalty is not None
                else generating_args["repetition_penalty"],
                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                pad_token_id=tokenizer.pad_token_id,
            )
        )

        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
            generating_args["do_sample"] = True
            generating_args["temperature"] = generating_args["temperature"] or 1.0

        if not generating_args["temperature"]:
            generating_args["do_sample"] = False

        if not generating_args["do_sample"]:
            generating_args.pop("temperature", None)
            generating_args.pop("top_p", None)

        if max_length:
            generating_args.pop("max_new_tokens", None)
            generating_args["max_length"] = max_length

        if max_new_tokens:
            generating_args.pop("max_length", None)
            generating_args["max_new_tokens"] = max_new_tokens

        gen_kwargs = dict(
            inputs=inputs,
            attention_mask=attention_mask,
            generation_config=GenerationConfig(**generating_args),
            logits_processor=get_logits_processor(),
        )

        if pixel_values is not None:
            gen_kwargs["pixel_values"] = pixel_values

        return gen_kwargs, prompt_length

    @staticmethod
    @torch.inference_mode()
    def _chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List["Response"]:
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
        )
        generate_output = model.generate(**gen_kwargs)
        response_ids = generate_output[:, prompt_length:]
        response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        results = []
        for i in range(len(response)):
            eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
            results.append(
                Response(
                    response_text=response[i],
                    response_length=response_length,
                    prompt_length=prompt_length,
                    finish_reason="stop" if len(eos_index) else "length",
                )
            )

        return results

    @staticmethod
    @torch.inference_mode()
    def _stream_chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Callable[[], str]:
        gen_kwargs, _ = HuggingfaceEngine._process_args(
            model, tokenizer, processor, template, generating_args, messages, system, tools, image, input_kwargs
        )
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs["streamer"] = streamer
        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()

        def stream():
            try:
                return streamer.__next__()
            except StopIteration:
                raise StopAsyncIteration()

        return stream

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: List[str],
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List[float]:
        max_length = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=True,
        ).to(device)

        input_ids: torch.Tensor = inputs["input_ids"]
        _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)

        if getattr(model.config, "model_type", None) == "chatglm":
            values = torch.transpose(values, 0, 1)

        scores = []
        for i in range(input_ids.size(0)):
            end_indexes = (input_ids[i] != tokenizer.pad_token_id).nonzero()
            end_index = end_indexes[-1].item() if len(end_indexes) else 0
            scores.append(values[i, end_index].nan_to_num().item())

        return scores

    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            image,
            input_kwargs,
        )
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._chat, *input_args)

    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            image,
            input_kwargs,
        )
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                stream = self._stream_chat(*input_args)
                while True:
                    try:
                        yield await loop.run_in_executor(pool, stream)
                    except StopAsyncIteration:
                        break

    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")

        loop = asyncio.get_running_loop()
        input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._get_scores, *input_args)

LLaMA-Factory/src/llamafactory/chat/vllm_engine.py  (new file, 0 → 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union

from ..data import get_template_and_fix_tokenizer
from ..extras.logging import get_logger
from ..extras.misc import get_device_count
from ..extras.packages import is_vllm_available, is_vllm_version_greater_than_0_5, is_vllm_version_greater_than_0_5_1
from ..model import load_config, load_tokenizer
from ..model.model_utils.quantization import QuantizationMethod
from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
from .base_engine import BaseEngine, Response


if is_vllm_available():
    from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams
    from vllm.lora.request import LoRARequest

    if is_vllm_version_greater_than_0_5_1():
        pass
    elif is_vllm_version_greater_than_0_5():
        from vllm.multimodal.image import ImagePixelData
    else:
        from vllm.sequence import MultiModalData


if TYPE_CHECKING:
    from numpy.typing import NDArray
    from transformers.image_processing_utils import BaseImageProcessor

    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = get_logger(__name__)


class VllmEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        config = load_config(model_args)  # may download model from ms hub
        if getattr(config, "quantization_config", None):  # gptq models should use float16
            quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
            quant_method = quantization_config.get("quant_method", "")
            if quant_method == QuantizationMethod.GPTQ and model_args.infer_dtype == "auto":
                model_args.infer_dtype = "float16"

        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args.template, data_args.tool_format)
        self.generating_args = generating_args.to_dict()

        engine_args = {
            "model": model_args.model_name_or_path,
            "trust_remote_code": True,
            "download_dir": model_args.cache_dir,
            "dtype": model_args.infer_dtype,
            "max_model_len": model_args.vllm_maxlen,
            "tensor_parallel_size": get_device_count() or 1,
            "gpu_memory_utilization": model_args.vllm_gpu_util,
            "disable_log_stats": True,
            "disable_log_requests": True,
            "enforce_eager": model_args.vllm_enforce_eager,
            "enable_lora": model_args.adapter_name_or_path is not None,
            "max_lora_rank": model_args.vllm_max_lora_rank,
        }

        if model_args.visual_inputs:
            image_size = config.vision_config.image_size
            patch_size = config.vision_config.patch_size
            self.image_feature_size = (image_size // patch_size) ** 2
            engine_args["image_input_type"] = "pixel_values"
            engine_args["image_token_id"] = self.tokenizer.convert_tokens_to_ids(self.template.image_token)
            engine_args["image_input_shape"] = "1,3,{},{}".format(image_size, image_size)
            engine_args["image_feature_size"] = self.image_feature_size
            if getattr(config, "is_yi_vl_derived_model", None):
                import vllm.model_executor.models.llava

                logger.info("Detected Yi-VL model, applying projector patch.")
                vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

        self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
        if model_args.adapter_name_or_path is not None:
            self.lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
        else:
            self.lora_request = None

    async def _generate(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncIterator["RequestOutput"]:
        request_id = "chatcmpl-{}".format(uuid.uuid4().hex)

        if (
            self.processor is not None
            and image is not None
            and not hasattr(self.processor, "image_seq_length")
            and self.template.image_token not in messages[0]["content"]
        ):  # llava-like models (TODO: paligemma models)
            messages[0]["content"] = self.template.image_token * self.image_feature_size + messages[0]["content"]

        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(
            tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
        )

        if self.processor is not None and image is not None:  # add image features
            image_processor: "BaseImageProcessor" = getattr(self.processor, "image_processor")
            pixel_values = image_processor(image, return_tensors="pt")["pixel_values"]
            if is_vllm_version_greater_than_0_5_1():
                multi_modal_data = {"image": pixel_values}
            elif is_vllm_version_greater_than_0_5():
                multi_modal_data = ImagePixelData(image=pixel_values)
            else:  # TODO: remove vllm 0.4.3 support
                multi_modal_data = MultiModalData(type=MultiModalData.Type.IMAGE, data=pixel_values)
        else:
            multi_modal_data = None

        prompt_length = len(prompt_ids)

        use_beam_search: bool = self.generating_args["num_beams"] > 1
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

        if "max_new_tokens" in self.generating_args:
            max_tokens = self.generating_args["max_new_tokens"]
        elif "max_length" in self.generating_args:
            if self.generating_args["max_length"] > prompt_length:
                max_tokens = self.generating_args["max_length"] - prompt_length
            else:
                max_tokens = 1

        if max_length:
            max_tokens = max_length - prompt_length if max_length > prompt_length else 1

        if max_new_tokens:
            max_tokens = max_new_tokens

        sampling_params = SamplingParams(
            n=num_return_sequences,
            repetition_penalty=(
                repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"]
            )
            or 1.0,  # repetition_penalty must > 0
            temperature=temperature if temperature is not None else self.generating_args["temperature"],
            top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0,  # top_p must > 0
            top_k=top_k if top_k is not None else self.generating_args["top_k"],
            use_beam_search=use_beam_search,
            length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"],
            stop=stop,
            stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            max_tokens=max_tokens,
            skip_special_tokens=True,
        )

        result_generator = self.model.generate(
            inputs={"prompt_token_ids": prompt_ids, "multi_modal_data": multi_modal_data},
            sampling_params=sampling_params,
            request_id=request_id,
            lora_request=self.lora_request,
        )
        return result_generator

    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> List["Response"]:
        final_output = None
        generator = await self._generate(messages, system, tools, image, **input_kwargs)
        async for request_output in generator:
            final_output = request_output

        results = []
        for output in final_output.outputs:
            results.append(
                Response(
                    response_text=output.text,
                    response_length=len(output.token_ids),
                    prompt_length=len(final_output.prompt_token_ids),
                    finish_reason=output.finish_reason,
                )
            )

        return results

    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        image: Optional["NDArray"] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        generated_text = ""
        generator = await self._generate(messages, system, tools, image, **input_kwargs)
        async for result in generator:
            delta_text = result.outputs[0].text[len(generated_text):]
            generated_text = result.outputs[0].text
            yield delta_text

    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        raise NotImplementedError("vLLM engine does not support get_scores.")

LLaMA-Factory/src/llamafactory/cli.py  (new file, 0 → 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import subprocess
import sys
from enum import Enum, unique

from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras.env import VERSION, print_env
from .extras.logging import get_logger
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui


USAGE = (
    "-" * 70
    + "\n"
    + "| Usage:                                                            |\n"
    + "|   llamafactory-cli api -h: launch an OpenAI-style API server      |\n"
    + "|   llamafactory-cli chat -h: launch a chat interface in CLI        |\n"
    + "|   llamafactory-cli eval -h: evaluate models                       |\n"
    + "|   llamafactory-cli export -h: merge LoRA adapters and export model|\n"
    + "|   llamafactory-cli train -h: train models                         |\n"
    + "|   llamafactory-cli webchat -h: launch a chat interface in Web UI  |\n"
    + "|   llamafactory-cli webui: launch LlamaBoard                       |\n"
    + "|   llamafactory-cli version: show version info                     |\n"
    + "-" * 70
)

WELCOME = (
    "-" * 58
    + "\n"
    + "| Welcome to LLaMA Factory, version {}".format(VERSION)
    + " " * (21 - len(VERSION))
    + "|\n|"
    + " " * 56
    + "|\n"
    + "| Project page: https://github.com/hiyouga/LLaMA-Factory |\n"
    + "-" * 58
)

logger = get_logger(__name__)


@unique
class Command(str, Enum):
    API = "api"
    CHAT = "chat"
    ENV = "env"
    EVAL = "eval"
    EXPORT = "export"
    TRAIN = "train"
    WEBDEMO = "webchat"
    WEBUI = "webui"
    VER = "version"
    HELP = "help"


def main():
    command = sys.argv.pop(1) if len(sys.argv) != 1 else Command.HELP
    if command == Command.API:
        run_api()
    elif command == Command.CHAT:
        run_chat()
    elif command == Command.ENV:
        print_env()
    elif command == Command.EVAL:
        run_eval()
    elif command == Command.EXPORT:
        export_model()
    elif command == Command.TRAIN:
        force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
        if force_torchrun or get_device_count() > 1:
            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
            logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
            process = subprocess.run(
                (
                    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
                ).format(
                    nnodes=os.environ.get("NNODES", "1"),
                    node_rank=os.environ.get("RANK", "0"),
                    nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
                    master_addr=master_addr,
                    master_port=master_port,
                    file_name=launcher.__file__,
                    args=" ".join(sys.argv[1:]),
                ),
                shell=True,
            )
            sys.exit(process.returncode)
        else:
            run_exp()
    elif command == Command.WEBDEMO:
        run_web_demo()
    elif command == Command.WEBUI:
        run_web_ui()
    elif command == Command.VER:
        print(WELCOME)
    elif command == Command.HELP:
        print(USAGE)
    else:
        raise NotImplementedError("Unknown command: {}".format(command))
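
Not part of the commit: the dispatch above can also be exercised programmatically; a minimal sketch, assuming the package is installed so `llamafactory.cli` is importable.

import sys

from llamafactory.cli import main

# Equivalent to running `llamafactory-cli version` from a shell:
# main() pops sys.argv[1] and matches it against the Command enum.
sys.argv = ["llamafactory-cli", "version"]
main()  # prints the WELCOME banner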

LLaMA-Factory/src/llamafactory/data/__init__.py  (new file, 0 → 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding, SFTDataCollatorWith4DAttentionMask
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer


__all__ = [
    "KTODataCollatorWithPadding",
    "PairwiseDataCollatorWithPadding",
    "SFTDataCollatorWith4DAttentionMask",
    "Role",
    "split_dataset",
    "get_dataset",
    "TEMPLATES",
    "Template",
    "get_template_and_fix_tokenizer",
]

LLaMA-Factory/src/llamafactory/data/aligner.py  (new file, 0 → 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Union

from datasets import Features

from ..extras.logging import get_logger
from .data_utils import Role


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import Seq2SeqTrainingArguments

    from ..hparams import DataArguments
    from .parser import DatasetAttr


logger = get_logger(__name__)


def _convert_images(images: List[Any], dataset_attr: "DatasetAttr", data_args: "DataArguments") -> List[Any]:
    r"""
    Optionally concatenates image path to dataset dir when loading from local disk.
    """
    outputs = []
    if dataset_attr.load_from in ["script", "file"]:
        for image in images:
            if isinstance(image, str) and os.path.isfile(os.path.join(data_args.dataset_dir, image)):
                outputs.append(os.path.join(data_args.dataset_dir, image))
            else:
                outputs.append(image)

    return outputs


def convert_alpaca(
    examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
    r"""
    Converts alpaca format dataset to the standard format.
    """
    outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
    for i in range(len(examples[dataset_attr.prompt])):
        prompt = []
        if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
            for old_prompt, old_response in examples[dataset_attr.history][i]:
                prompt.append({"role": Role.USER.value, "content": old_prompt})
                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

        content = []
        if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
            content.append(examples[dataset_attr.prompt][i])

        if dataset_attr.query and examples[dataset_attr.query][i]:
            content.append(examples[dataset_attr.query][i])

        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})  # "prompt\nquery"

        if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool):  # kto example
            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
            if examples[dataset_attr.kto_tag][i]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            dataset_attr.ranking
            and isinstance(examples[dataset_attr.chosen][i], str)
            and isinstance(examples[dataset_attr.rejected][i], str)
        ):  # pairwise example
            response = [
                {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.chosen][i]},
                {"role": Role.ASSISTANT.value, "content": examples[dataset_attr.rejected][i]},
            ]
        elif dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):  # normal example
            response = [{"role": Role.ASSISTANT.value, "content": examples[dataset_attr.response][i]}]
        else:  # unsupervised
            response = []

        outputs["prompt"].append(prompt)
        outputs["response"].append(response)
        outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
        outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
        outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])

    return outputs


def convert_sharegpt(
    examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
    r"""
    Converts sharegpt format dataset to the standard format.
    """
    outputs = {"prompt": [], "response": [], "system": [], "tools": [], "images": []}
    convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
    tag_mapping = {
        dataset_attr.user_tag: Role.USER.value,
        dataset_attr.assistant_tag: Role.ASSISTANT.value,
        dataset_attr.observation_tag: Role.OBSERVATION.value,
        dataset_attr.function_tag: Role.FUNCTION.value,
        dataset_attr.system_tag: Role.SYSTEM.value,
    }
    odd_tags = (dataset_attr.user_tag, dataset_attr.observation_tag)
    even_tags = (dataset_attr.assistant_tag, dataset_attr.function_tag)
    accept_tags = (odd_tags, even_tags)
    for i, messages in enumerate(examples[dataset_attr.messages]):
        if len(messages) == 0:
            continue

        if dataset_attr.system_tag and messages[0][dataset_attr.role_tag] == dataset_attr.system_tag:
            system = messages[0][dataset_attr.content_tag]
            messages = messages[1:]
        else:
            system = examples[dataset_attr.system][i] if dataset_attr.system else ""

        aligned_messages = []
        broken_data = False
        for turn_idx, message in enumerate(messages):
            if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
                logger.warning("Invalid role tag in {}.".format(messages))
                broken_data = True

            aligned_messages.append(
                {"role": tag_mapping[message[dataset_attr.role_tag]], "content": message[dataset_attr.content_tag]}
            )

        if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
            dataset_attr.ranking and len(aligned_messages) % 2 == 0
        ):
            logger.warning("Invalid message count in {}.".format(messages))
            broken_data = True

        if dataset_attr.kto_tag and isinstance(examples[dataset_attr.kto_tag][i], bool):  # kto example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]
            if examples[dataset_attr.kto_tag][i]:
                response = response + [{"role": Role.ASSISTANT.value, "content": ""}]
            else:
                response = [{"role": Role.ASSISTANT.value, "content": ""}] + response
        elif (
            dataset_attr.ranking
            and isinstance(examples[dataset_attr.chosen][i], dict)
            and isinstance(examples[dataset_attr.rejected][i], dict)
        ):  # pairwise example
            chosen = examples[dataset_attr.chosen][i]
            rejected = examples[dataset_attr.rejected][i]
            if (
                chosen[dataset_attr.role_tag] not in accept_tags[-1]
                or rejected[dataset_attr.role_tag] not in accept_tags[-1]
            ):
                logger.warning("Invalid role tag in {}.".format([chosen, rejected]))
                broken_data = True

            prompt = aligned_messages
            response = [
                {"role": tag_mapping[chosen[dataset_attr.role_tag]], "content": chosen[dataset_attr.content_tag]},
                {"role": tag_mapping[rejected[dataset_attr.role_tag]], "content": rejected[dataset_attr.content_tag]},
            ]
        else:  # normal example
            prompt = aligned_messages[:-1]
            response = aligned_messages[-1:]

        if broken_data:
            logger.warning("Skipping this abnormal example.")
            continue

        outputs["prompt"].append(prompt)
        outputs["response"].append(response)
        outputs["system"].append(system)
        outputs["tools"].append(examples[dataset_attr.tools][i] if dataset_attr.tools else "")
        outputs["images"].append(convert_images(examples[dataset_attr.images][i]) if dataset_attr.images else [])

    return outputs


def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    r"""
    Aligned dataset:
        prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        system: "..."
        tools: "...",
        images: [],
    """
    if dataset_attr.formatting == "alpaca":
        convert_func = partial(convert_alpaca, dataset_attr=dataset_attr, data_args=data_args)
    else:
        convert_func = partial(convert_sharegpt, dataset_attr=dataset_attr, data_args=data_args)

    column_names = list(next(iter(dataset)).keys())
    features = Features.from_dict(
        {
            "prompt": [
                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "response": [
                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "system": {"dtype": "string", "_type": "Value"},
            "tools": {"dtype": "string", "_type": "Value"},
            "images": [{"_type": "Image"}],
        }
    )
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Converting format of dataset",
        )

    return dataset.map(
        convert_func,
        batched=True,
        remove_columns=column_names,
        features=features,
        **kwargs,
    )
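
Not part of the commit: a small sketch of what `convert_alpaca` produces for one batched example. `SimpleNamespace` is a hypothetical stand-in for a real `DatasetAttr`, filling only the attributes the function reads; `data_args` is unused here because no images are present.

from types import SimpleNamespace

from llamafactory.data.aligner import convert_alpaca

# Hypothetical stand-in for DatasetAttr, for illustration only.
dataset_attr = SimpleNamespace(
    prompt="instruction", query="input", response="output",
    history=None, kto_tag=None, ranking=False,
    system=None, tools=None, images=None,
)
examples = {
    "instruction": ["Translate to French"],
    "input": ["Good morning"],
    "output": ["Bonjour"],
}
aligned = convert_alpaca(examples, dataset_attr, data_args=None)
print(aligned["prompt"][0])    # [{'role': 'user', 'content': 'Translate to French\nGood morning'}]
print(aligned["response"][0])  # [{'role': 'assistant', 'content': 'Bonjour'}]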

LLaMA-Factory/src/llamafactory/data/collator.py  (new file, 0 → 100644)
# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
Literal
,
Sequence
import
torch
from
transformers
import
DataCollatorForSeq2Seq
def
prepare_4d_attention_mask
(
attention_mask_with_indices
:
"torch.Tensor"
,
dtype
:
"torch.dtype"
)
->
"torch.Tensor"
:
r
"""
Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
while handles packed sequences and transforms the mask to lower triangular form to prevent future peeking.
e.g.
```python
# input
[[1, 1, 2, 2, 2, 0]]
# output
[
[
[
[o, x, x, x, x, x],
[o, o, x, x, x, x],
[x, x, o, x, x, x],
[x, x, o, o, x, x],
[x, x, o, o, o, x],
[x, x, x, x, x, x],
]
]
]
```
where `o` equals to `0.0`, `x` equals to `min_dtype`.
"""
bsz
,
seq_len
=
attention_mask_with_indices
.
size
()
min_dtype
=
torch
.
finfo
(
dtype
).
min
expanded_mask
=
attention_mask_with_indices
[:,
None
,
None
,
:].
expand
(
bsz
,
1
,
seq_len
,
seq_len
)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
padding_mask
=
torch
.
where
(
expanded_mask
!=
0
,
1
,
0
)
# Create a block-diagonal mask.
attention_mask_4d
=
torch
.
eq
(
expanded_mask
,
expanded_mask
.
transpose
(
-
1
,
-
2
)).
int
()
*
padding_mask
# Use the lower triangular mask to zero out the upper triangular part
attention_mask_4d
*=
torch
.
tril
(
torch
.
ones
((
seq_len
,
seq_len
),
dtype
=
torch
.
long
))
# Invert the attention mask.
attention_mask_4d
=
torch
.
where
(
attention_mask_4d
!=
0
,
torch
.
tensor
(
0
,
dtype
=
dtype
),
min_dtype
)
return
attention_mask_4d
@
dataclass
class
SFTDataCollatorWith4DAttentionMask
(
DataCollatorForSeq2Seq
):
r
"""
Data collator for 4d attention mask.
"""
block_diag_attn
:
bool
=
False
attn_implementation
:
Literal
[
"eager"
,
"sdpa"
,
"flash_attention_2"
]
=
"eager"
compute_dtype
:
"torch.dtype"
=
torch
.
float32
def
__call__
(
self
,
features
:
Sequence
[
Dict
[
str
,
Any
]])
->
Dict
[
str
,
"torch.Tensor"
]:
features
=
super
().
__call__
(
features
)
if
self
.
block_diag_attn
and
self
.
attn_implementation
!=
"flash_attention_2"
:
features
[
"attention_mask"
]
=
prepare_4d_attention_mask
(
features
[
"attention_mask"
],
self
.
compute_dtype
)
return
features
@
dataclass
class
PairwiseDataCollatorWithPadding
(
DataCollatorForSeq2Seq
):
r
"""
Data collator for pairwise data.
"""
def
__call__
(
self
,
features
:
Sequence
[
Dict
[
str
,
Any
]])
->
Dict
[
str
,
"torch.Tensor"
]:
r
"""
Pads batched data to the longest sequence in the batch.
We generate 2 * n examples where the first n examples represent chosen examples and
the last n examples represent rejected examples.
"""
concatenated_features
=
[]
for
key
in
(
"chosen"
,
"rejected"
):
for
feature
in
features
:
target_feature
=
{
"input_ids"
:
feature
[
"{}_input_ids"
.
format
(
key
)],
"attention_mask"
:
feature
[
"{}_attention_mask"
.
format
(
key
)],
"labels"
:
feature
[
"{}_labels"
.
format
(
key
)],
}
if
"pixel_values"
in
feature
:
target_feature
[
"pixel_values"
]
=
feature
[
"pixel_values"
]
if
"{}_token_type_ids"
.
format
(
key
)
in
feature
:
target_feature
[
"token_type_ids"
]
=
feature
[
"{}_token_type_ids"
.
format
(
key
)]
concatenated_features
.
append
(
target_feature
)
return
super
().
__call__
(
concatenated_features
)
@dataclass
class KTODataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        target_features = []
        kl_features = []
        kto_tags = []
        for feature in features:
            target_feature = {
                "input_ids": feature["input_ids"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
            }
            kl_feature = {
                "input_ids": feature["kl_input_ids"],
                "attention_mask": feature["kl_attention_mask"],
                "labels": feature["kl_labels"],
            }
            if "pixel_values" in feature:
                target_feature["pixel_values"] = feature["pixel_values"]

            if "token_type_ids" in feature:
                target_feature["token_type_ids"] = feature["token_type_ids"]
                kl_feature["token_type_ids"] = feature["kl_token_type_ids"]

            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "token_type_ids" in batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch
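Usage note (not part of this commit): a minimal sketch of how the collator and the 4D mask helper above fit together. The model id passed to `AutoTokenizer` is a hypothetical placeholder; any `PreTrainedTokenizer` with padding works.

```python
import torch
from transformers import AutoTokenizer

from llamafactory.data.collator import SFTDataCollatorWith4DAttentionMask, prepare_4d_attention_mask

# Packed sequence: two documents (indices 1 and 2) followed by one padding slot.
packed = torch.tensor([[1, 1, 2, 2, 2, 0]])
mask_4d = prepare_4d_attention_mask(packed, dtype=torch.float32)
print(mask_4d.shape)  # torch.Size([1, 1, 6, 6]); 0.0 where attending is allowed, min_dtype elsewhere

# Hypothetical tokenizer; real runs load it through the project's model loader.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
collator = SFTDataCollatorWith4DAttentionMask(
    tokenizer=tokenizer,
    label_pad_token_id=-100,
    block_diag_attn=True,          # build block-diagonal 4D masks for packed examples
    attn_implementation="eager",   # with flash_attention_2 the original 2D mask is kept
    compute_dtype=torch.bfloat16,
)
```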
LLaMA-Factory/src/llamafactory/data/data_utils.py
0 → 100644
View file @ 032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum, unique
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict, Union

from datasets import DatasetDict, concatenate_datasets, interleave_datasets

from ..extras.logging import get_logger


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset

    from ..hparams import DataArguments


logger = get_logger(__name__)


SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"


class DatasetModule(TypedDict):
    train_dataset: Optional[Union["Dataset", "IterableDataset"]]
    eval_dataset: Optional[Union["Dataset", "IterableDataset"]]
def merge_dataset(
    all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]:
    if len(all_datasets) == 1:
        return all_datasets[0]
    elif data_args.mix_strategy == "concat":
        if data_args.streaming:
            logger.warning("The samples between different datasets will not be mixed in streaming mode.")

        return concatenate_datasets(all_datasets)
    elif data_args.mix_strategy.startswith("interleave"):
        if not data_args.streaming:
            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")

        return interleave_datasets(
            datasets=all_datasets,
            probabilities=data_args.interleave_probs,
            seed=seed,
            stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
        )
    else:
        raise ValueError("Unknown mixing strategy.")
def split_dataset(
    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
) -> "DatasetDict":
    r"""
    Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
    """
    if data_args.streaming:
        dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
        val_set = dataset.take(int(data_args.val_size))
        train_set = dataset.skip(int(data_args.val_size))
        return DatasetDict({"train": train_set, "validation": val_set})
    else:
        val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
        dataset = dataset.train_test_split(test_size=val_size, seed=seed)
        return DatasetDict({"train": dataset["train"], "validation": dataset["test"]})
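Usage note (not part of this commit): `split_dataset` treats `val_size` as an absolute count when it is greater than 1 and as a fraction otherwise, and streaming datasets are split with `take`/`skip` instead of `train_test_split`. A toy, self-contained illustration of the non-streaming path:

```python
from datasets import Dataset

toy = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})

# Fractional split: 20% of the examples go to the validation set (mirrors 0 < val_size <= 1).
splits = toy.train_test_split(test_size=0.2, seed=42)
print(len(splits["train"]), len(splits["test"]))  # 8 2

# Absolute split: exactly 3 validation examples (mirrors val_size > 1, cast to int).
splits = toy.train_test_split(test_size=3, seed=42)
print(len(splits["train"]), len(splits["test"]))  # 7 3
```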
LLaMA-Factory/src/llamafactory/data/formatter.py
0 → 100644
View file @ 032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import List, Literal, Optional, Tuple, Union

from .data_utils import SLOTS
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
@dataclass
class Formatter(ABC):
    slots: SLOTS = field(default_factory=list)
    tool_format: Optional[Literal["default", "glm4"]] = None

    @abstractmethod
    def apply(self, **kwargs) -> SLOTS: ...

    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
        raise NotImplementedError
@dataclass
class EmptyFormatter(Formatter):
    def __post_init__(self):
        has_placeholder = False
        for slot in filter(lambda s: isinstance(s, str), self.slots):
            if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
                has_placeholder = True

        if has_placeholder:
            raise ValueError("Empty formatter should not contain any placeholder.")

    def apply(self, **kwargs) -> SLOTS:
        return self.slots
@dataclass
class StringFormatter(Formatter):
    def __post_init__(self):
        has_placeholder = False
        for slot in filter(lambda s: isinstance(s, str), self.slots):
            if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
                has_placeholder = True

        if not has_placeholder:
            raise ValueError("A placeholder is required in the string formatter.")

    def apply(self, **kwargs) -> SLOTS:
        elements = []
        for slot in self.slots:
            if isinstance(slot, str):
                for name, value in kwargs.items():
                    if not isinstance(value, str):
                        raise RuntimeError("Expected a string, got {}".format(value))

                    slot = slot.replace("{{" + name + "}}", value, 1)

                elements.append(slot)
            elif isinstance(slot, (dict, set)):
                elements.append(slot)
            else:
                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

        return elements
@dataclass
class FunctionFormatter(Formatter):
    def __post_init__(self):
        if self.tool_format == "default":
            self.slots = DefaultToolUtils.get_function_slots() + self.slots
        elif self.tool_format == "glm4":
            self.slots = GLM4ToolUtils.get_function_slots() + self.slots
        else:
            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))

    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        functions: List[Tuple[str, str]] = []
        try:
            tool_calls = json.loads(content)
            if not isinstance(tool_calls, list):  # parallel function call
                tool_calls = [tool_calls]

            for tool_call in tool_calls:
                functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))

        except json.JSONDecodeError:
            functions = []

        elements = []
        for name, arguments in functions:
            for slot in self.slots:
                if isinstance(slot, str):
                    slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
                    elements.append(slot)
                elif isinstance(slot, (dict, set)):
                    elements.append(slot)
                else:
                    raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

        return elements
@dataclass
class ToolFormatter(Formatter):
    def __post_init__(self):
        if self.tool_format == "default":
            self._tool_formatter = DefaultToolUtils.tool_formatter
            self._tool_extractor = DefaultToolUtils.tool_extractor
        elif self.tool_format == "glm4":
            self._tool_formatter = GLM4ToolUtils.tool_formatter
            self._tool_extractor = GLM4ToolUtils.tool_extractor
        else:
            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))

    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        try:
            tools = json.loads(content)
            return [self._tool_formatter(tools) if len(tools) != 0 else ""]
        except json.JSONDecodeError:
            return [""]

    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
        return self._tool_extractor(content)
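Usage note (not part of this commit): a minimal sketch of how the formatters above behave. The slot strings below are simplified placeholders, not the project's real chat templates, and the tool-call payload is made up.

```python
from llamafactory.data.formatter import FunctionFormatter, StringFormatter

# StringFormatter substitutes {{placeholders}} in its string slots.
fmt = StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n"])
print(fmt.apply(content="What is the weather in Berlin?"))
# ['<|im_start|>user\nWhat is the weather in Berlin?<|im_end|>\n']

# FunctionFormatter parses a JSON tool call and renders the function slots provided
# by DefaultToolUtils (name/arguments substituted), followed by the extra "\n" slot.
func_fmt = FunctionFormatter(slots=["\n"], tool_format="default")
tool_call = '{"name": "get_weather", "arguments": {"city": "Berlin"}}'
print(func_fmt.apply(content=tool_call))
```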
LLaMA-Factory/src/llamafactory/data/loader.py
0 → 100644
View file @ 032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from typing import TYPE_CHECKING, Dict, Literal, Optional, Sequence, Union

import numpy as np
from datasets import DatasetDict, load_dataset, load_from_disk
from transformers.utils.versions import require_version

from ..extras.constants import FILEEXT2TYPE
from ..extras.logging import get_logger
from ..extras.misc import has_tokenized_data
from .aligner import align_dataset
from .data_utils import merge_dataset, split_dataset
from .parser import get_dataset_list
from .preprocess import get_preprocess_and_print_func
from .template import get_template_and_fix_tokenizer


if TYPE_CHECKING:
    from datasets import Dataset, IterableDataset
    from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments

    from ..hparams import DataArguments, ModelArguments
    from .data_utils import DatasetModule
    from .parser import DatasetAttr
    from .template import Template


logger = get_logger(__name__)
def _load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    logger.info("Loading dataset {}...".format(dataset_attr))
    data_path, data_name, data_dir, data_files = None, None, None, None
    if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
        data_path = dataset_attr.dataset_name
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder

    elif dataset_attr.load_from == "script":
        data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        data_name = dataset_attr.subset
        data_dir = dataset_attr.folder

    elif dataset_attr.load_from == "file":
        data_files = []
        local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
        if os.path.isdir(local_path):  # is directory
            for file_name in os.listdir(local_path):
                data_files.append(os.path.join(local_path, file_name))
                if data_path is None:
                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
                    raise ValueError("File types should be identical.")
        elif os.path.isfile(local_path):  # is file
            data_files.append(local_path)
            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
        else:
            raise ValueError("File {} not found.".format(local_path))

        if data_path is None:
            raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
    else:
        raise NotImplementedError("Unknown load type: {}.".format(dataset_attr.load_from))

    if dataset_attr.load_from == "ms_hub":
        require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
        from modelscope import MsDataset
        from modelscope.utils.config_ds import MS_DATASETS_CACHE

        cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
        dataset = MsDataset.load(
            dataset_name=data_path,
            subset_name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=cache_dir,
            token=model_args.ms_hub_token,
            use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
        )
        if isinstance(dataset, MsDataset):
            dataset = dataset.to_hf_dataset()
    else:
        dataset = load_dataset(
            path=data_path,
            name=data_name,
            data_dir=data_dir,
            data_files=data_files,
            split=dataset_attr.split,
            cache_dir=model_args.cache_dir,
            token=model_args.hf_hub_token,
            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
            trust_remote_code=True,
        )

    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
        dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter

    if dataset_attr.num_samples is not None and not data_args.streaming:
        target_num = dataset_attr.num_samples
        indexes = np.random.permutation(len(dataset))[:target_num]
        target_num -= len(indexes)
        if target_num > 0:
            expand_indexes = np.random.choice(len(dataset), target_num)
            indexes = np.concatenate((indexes, expand_indexes), axis=0)

        assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
        dataset = dataset.select(indexes)
        logger.info("Sampled {} examples from dataset {}.".format(dataset_attr.num_samples, dataset_attr))

    if data_args.max_samples is not None:  # truncate dataset
        max_samples = min(data_args.max_samples, len(dataset))
        dataset = dataset.select(range(max_samples))

    return align_dataset(dataset, dataset_attr, data_args, training_args)
def _get_merged_dataset(
    dataset_names: Optional[Sequence[str]],
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
) -> Optional[Union["Dataset", "IterableDataset"]]:
    if dataset_names is None:
        return None

    datasets = []
    for dataset_attr in get_dataset_list(dataset_names, data_args.dataset_dir):
        if (stage == "rm" and dataset_attr.ranking is False) or (stage != "rm" and dataset_attr.ranking is True):
            raise ValueError("The dataset is not applicable in the current training stage.")

        datasets.append(_load_single_dataset(dataset_attr, model_args, data_args, training_args))

    return merge_dataset(datasets, data_args, seed=training_args.seed)
def _get_preprocessed_dataset(
    dataset: Optional[Union["Dataset", "IterableDataset"]],
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"] = None,
    is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]:
    if dataset is None:
        return None

    preprocess_func, print_function = get_preprocess_and_print_func(
        data_args, stage, template, tokenizer, processor, do_generate=(training_args.predict_with_generate and is_eval)
    )
    column_names = list(next(iter(dataset)).keys())
    kwargs = {}
    if not data_args.streaming:
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Running tokenizer on dataset",
        )

    dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)

    if training_args.should_log:
        try:
            print("eval example:" if is_eval else "training example:")
            print_function(next(iter(dataset)))
        except StopIteration:
            if stage == "pt":
                raise RuntimeError("Cannot find sufficient samples, consider increasing dataset size.")
            else:
                raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.")

    return dataset
def get_dataset(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule":
    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
    if data_args.train_on_prompt and template.efficient_eos:
        raise ValueError("Current template does not support `train_on_prompt`.")

    # Load tokenized dataset
    if data_args.tokenized_path is not None:
        if has_tokenized_data(data_args.tokenized_path):
            logger.warning("Loading dataset from disk will ignore other data arguments.")
            dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))

            dataset_module: Dict[str, "Dataset"] = {}
            if "train" in dataset_dict:
                dataset_module["train_dataset"] = dataset_dict["train"]
            if "validation" in dataset_dict:
                dataset_module["eval_dataset"] = dataset_dict["validation"]

            if data_args.streaming:
                dataset_module = {k: v.to_iterable_dataset() for k, v in dataset_module.items()}

            return dataset_module

        if data_args.streaming:
            raise ValueError("Turn off `streaming` when saving dataset to disk.")

    # Load and preprocess dataset
    with training_args.main_process_first(desc="load dataset"):
        dataset = _get_merged_dataset(data_args.dataset, model_args, data_args, training_args, stage)
        eval_dataset = _get_merged_dataset(data_args.eval_dataset, model_args, data_args, training_args, stage)

    with training_args.main_process_first(desc="pre-process dataset"):
        dataset = _get_preprocessed_dataset(
            dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=False
        )
        eval_dataset = _get_preprocessed_dataset(
            eval_dataset, data_args, training_args, stage, template, tokenizer, processor, is_eval=True
        )

        if data_args.val_size > 1e-6:
            dataset_dict = split_dataset(dataset, data_args, seed=training_args.seed)
        else:
            dataset_dict = {}
            if dataset is not None:
                if data_args.streaming:
                    dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)

                dataset_dict["train"] = dataset

            if eval_dataset is not None:
                if data_args.streaming:
                    eval_dataset = eval_dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)

                dataset_dict["validation"] = eval_dataset

            dataset_dict = DatasetDict(dataset_dict)

        if data_args.tokenized_path is not None:
            if training_args.should_save:
                dataset_dict.save_to_disk(data_args.tokenized_path)
                logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))

            sys.exit(0)

        dataset_module = {}
        if "train" in dataset_dict:
            dataset_module["train_dataset"] = dataset_dict["train"]
        if "validation" in dataset_dict:
            dataset_module["eval_dataset"] = dataset_dict["validation"]

        return dataset_module
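Usage note (not part of this commit): a minimal sketch of calling `get_dataset` for SFT. It assumes `model_args`, `data_args`, and `training_args` were already produced by the project's argument parser; loading the tokenizer with `AutoTokenizer` is a simplification of the project's own model loader.

```python
from transformers import AutoTokenizer

from llamafactory.data.loader import get_dataset

# Assumes model_args, data_args and training_args already exist
# (e.g. built by llamafactory.hparams.get_train_args from a YAML/CLI config).
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
dataset_module = get_dataset(model_args, data_args, training_args, stage="sft", tokenizer=tokenizer)

train_dataset = dataset_module.get("train_dataset")  # tokenized Dataset (or IterableDataset when streaming)
eval_dataset = dataset_module.get("eval_dataset")    # present only if a validation split was configured
```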
LLaMA-Factory/src/llamafactory/data/parser.py
0 → 100644
View file @ 032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Sequence

from transformers.utils import cached_file

from ..extras.constants import DATA_CONFIG
from ..extras.misc import use_modelscope
@dataclass
class DatasetAttr:
    r"""
    Dataset attributes.
    """

    # basic configs
    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
    dataset_name: str
    formatting: Literal["alpaca", "sharegpt"] = "alpaca"
    ranking: bool = False
    # extra configs
    subset: Optional[str] = None
    split: str = "train"
    folder: Optional[str] = None
    num_samples: Optional[int] = None
    # common columns
    system: Optional[str] = None
    tools: Optional[str] = None
    images: Optional[str] = None
    # rlhf columns
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    kto_tag: Optional[str] = None
    # alpaca columns
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    # sharegpt columns
    messages: Optional[str] = "conversations"
    # sharegpt tags
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"

    def __repr__(self) -> str:
        return self.dataset_name

    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        setattr(self, key, obj.get(key, default))
def get_dataset_list(dataset_names: Optional[Sequence[str]], dataset_dir: str) -> List["DatasetAttr"]:
    r"""
    Gets the attributes of the datasets.
    """
    if dataset_names is None:
        dataset_names = []

    if dataset_dir == "ONLINE":
        dataset_info = None
    else:
        if dataset_dir.startswith("REMOTE:"):
            config_path = cached_file(path_or_repo_id=dataset_dir[7:], filename=DATA_CONFIG, repo_type="dataset")
        else:
            config_path = os.path.join(dataset_dir, DATA_CONFIG)

        try:
            with open(config_path, "r") as f:
                dataset_info = json.load(f)
        except Exception as err:
            if len(dataset_names) != 0:
                raise ValueError("Cannot open {} due to {}.".format(config_path, str(err)))

            dataset_info = None

    dataset_list: List["DatasetAttr"] = []
    for name in dataset_names:
        if dataset_info is None:  # dataset_dir is ONLINE
            load_from = "ms_hub" if use_modelscope() else "hf_hub"
            dataset_attr = DatasetAttr(load_from, dataset_name=name)
            dataset_list.append(dataset_attr)
            continue

        if name not in dataset_info:
            raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))

        has_hf_url = "hf_hub_url" in dataset_info[name]
        has_ms_url = "ms_hub_url" in dataset_info[name]

        if has_hf_url or has_ms_url:
            if (use_modelscope() and has_ms_url) or (not has_hf_url):
                dataset_attr = DatasetAttr("ms_hub", dataset_name=dataset_info[name]["ms_hub_url"])
            else:
                dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
        elif "script_url" in dataset_info[name]:
            dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
        else:
            dataset_attr = DatasetAttr("file", dataset_name=dataset_info[name]["file_name"])

        dataset_attr.set_attr("formatting", dataset_info[name], default="alpaca")
        dataset_attr.set_attr("ranking", dataset_info[name], default=False)
        dataset_attr.set_attr("subset", dataset_info[name])
        dataset_attr.set_attr("split", dataset_info[name], default="train")
        dataset_attr.set_attr("folder", dataset_info[name])
        dataset_attr.set_attr("num_samples", dataset_info[name])

        if "columns" in dataset_info[name]:
            column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
            if dataset_attr.formatting == "alpaca":
                column_names.extend(["prompt", "query", "response", "history"])
            else:
                column_names.extend(["messages"])

            for column_name in column_names:
                dataset_attr.set_attr(column_name, dataset_info[name]["columns"])

        if dataset_attr.formatting == "sharegpt" and "tags" in dataset_info[name]:
            tag_names = (
                "role_tag",
                "content_tag",
                "user_tag",
                "assistant_tag",
                "observation_tag",
                "function_tag",
                "system_tag",
            )
            for tag in tag_names:
                dataset_attr.set_attr(tag, dataset_info[name]["tags"])

        dataset_list.append(dataset_attr)

    return dataset_list
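Usage note (not part of this commit): `get_dataset_list` reads entries from the dataset config file (`DATA_CONFIG`) inside `dataset_dir`. The entry below is a made-up sharegpt-style example, shown as a Python dict for readability; the dataset and file names are hypothetical.

```python
from llamafactory.data.parser import get_dataset_list

# Hypothetical entry that would live in <dataset_dir>/dataset_info.json:
dataset_info_entry = {
    "my_sharegpt_data": {
        "file_name": "my_sharegpt_data.json",  # resolved as load_from="file"
        "formatting": "sharegpt",
        "columns": {"messages": "conversations", "system": "system"},
        "tags": {"role_tag": "from", "content_tag": "value", "user_tag": "human", "assistant_tag": "gpt"},
    }
}

# Assuming the entry above exists in data/dataset_info.json:
attrs = get_dataset_list(["my_sharegpt_data"], dataset_dir="data")
print(attrs[0].formatting, attrs[0].messages)  # sharegpt conversations
```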
LLaMA-Factory/src/llamafactory/data/preprocess.py
0 → 100644
View file @ 032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple

from .processors.feedback import preprocess_feedback_dataset
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example
from .processors.pretrain import preprocess_pretrain_dataset
from .processors.supervised import (
    preprocess_packed_supervised_dataset,
    preprocess_supervised_dataset,
    print_supervised_dataset_example,
)
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ..hparams import DataArguments
    from .template import Template
def get_preprocess_and_print_func(
    data_args: "DataArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    do_generate: bool = False,
) -> Tuple[Callable, Callable]:
    if stage == "pt":
        preprocess_func = partial(
            preprocess_pretrain_dataset,
            tokenizer=tokenizer,
            data_args=data_args,
        )
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)
    elif stage == "sft" and not do_generate:
        if data_args.packing:
            if data_args.neat_packing:
                from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

                def __init__(self, data, **kwargs):
                    return TypedSequence.__init__(
                        self,
                        data,
                        type=kwargs.pop("type", None),
                        try_type=kwargs.pop("try_type", None),
                        optimized_int_type=kwargs.pop("optimized_int_type", None),
                    )

                OptimizedTypedSequence.__init__ = __init__

            preprocess_func = partial(
                preprocess_packed_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                data_args=data_args,
            )
        else:
            preprocess_func = partial(
                preprocess_supervised_dataset,
                template=template,
                tokenizer=tokenizer,
                processor=processor,
                data_args=data_args,
            )

        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    elif stage == "rm":
        preprocess_func = partial(
            preprocess_pairwise_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_pairwise_dataset_example, tokenizer=tokenizer)
    elif stage == "kto":
        preprocess_func = partial(
            preprocess_feedback_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
    else:
        preprocess_func = partial(
            preprocess_unsupervised_dataset,
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            data_args=data_args,
        )
        print_function = partial(print_unsupervised_dataset_example, tokenizer=tokenizer)

    return preprocess_func, print_function
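Usage note (not part of this commit): a minimal sketch of how the selected preprocessing function is applied. It assumes `data_args`, `template`, and `tokenizer` already exist and that `raw_dataset` is an aligned HF dataset (with prompt/response/system/tools columns); these names are placeholders.

```python
from llamafactory.data.preprocess import get_preprocess_and_print_func

# Assumes data_args, template and tokenizer already exist, and raw_dataset is an aligned HF Dataset.
preprocess_func, print_function = get_preprocess_and_print_func(
    data_args, stage="sft", template=template, tokenizer=tokenizer, processor=None, do_generate=False
)
tokenized = raw_dataset.map(preprocess_func, batched=True, remove_columns=raw_dataset.column_names)
print_function(next(iter(tokenized)))  # pretty-prints one tokenized training example
```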
LLaMA-Factory/src/llamafactory/data/processors/__init__.py
0 → 100644
View file @ 032b90a1
LLaMA-Factory/src/llamafactory/data/processors/feedback.py
0 → 100644
View file @ 032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..template import Template


logger = get_logger(__name__)
def _encode_feedback_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    kl_response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
        prompt[0]["content"] = template.image_token + prompt[0]["content"]

    if response[0]["content"]:  # desired example
        kto_tag = True
        messages = prompt + [response[0]]
    else:  # undesired example
        kto_tag = False
        messages = prompt + [response[1]]

    if kl_response[0]["content"]:
        kl_messages = prompt + [kl_response[0]]
    else:
        kl_messages = prompt + [kl_response[1]]

    prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
    kl_prompt_ids, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)

    if template.efficient_eos:
        response_ids += [tokenizer.eos_token_id]
        kl_response_ids += [tokenizer.eos_token_id]

    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
        kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids

    source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    response_ids = response_ids[:target_len]
    kl_source_len, kl_target_len = infer_seqlen(len(kl_prompt_ids), len(kl_response_ids), cutoff_len)
    kl_prompt_ids = kl_prompt_ids[:kl_source_len]
    kl_response_ids = kl_response_ids[:kl_target_len]

    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * source_len + response_ids
    kl_input_ids = kl_prompt_ids + kl_response_ids
    kl_labels = [IGNORE_INDEX] * kl_source_len + kl_response_ids

    return input_ids, labels, kl_input_ids, kl_labels, kto_tag
def preprocess_feedback_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    # create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
    kl_response = examples["response"][::-1]
    model_inputs = {
        "input_ids": [],
        "attention_mask": [],
        "labels": [],
        "kl_input_ids": [],
        "kl_attention_mask": [],
        "kl_labels": [],
        "kto_tags": [],
    }
    if processor is not None:
        model_inputs["pixel_values"] = []
        if hasattr(processor, "image_seq_length"):  # paligemma models
            model_inputs["token_type_ids"] = []
            model_inputs["kl_token_type_ids"] = []

    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
            continue

        input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
            prompt=examples["prompt"][i],
            response=examples["response"][i],
            kl_response=kl_response[i],
            system=examples["system"][i],
            tools=examples["tools"][i],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["kl_input_ids"].append(kl_input_ids)
        model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
        model_inputs["kl_labels"].append(kl_labels)
        model_inputs["kto_tags"].append(kto_tag)
        if processor is not None:
            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
            if hasattr(processor, "image_seq_length"):  # paligemma models
                model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
                model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))

    desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
    undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
    if desirable_num == 0 or undesirable_num == 0:
        logger.warning("Your dataset only has one preference type.")

    return model_inputs
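Usage note (not part of this commit): a sketch of the aligned batch shape the KTO processor expects. The content below is made up; `template`, `tokenizer`, and `data_args` are assumed to exist and are therefore left out of the (commented) call.

```python
# Illustrative shape of one aligned KTO example after the aligner step (made-up content):
examples = {
    "prompt": [[{"role": "user", "content": "Name a prime number."}]],
    # response[0] non-empty -> desirable (kto_tag=True); otherwise response[1] holds the undesirable answer.
    "response": [[{"role": "assistant", "content": "7 is prime."}, {"role": "assistant", "content": ""}]],
    "system": [""],
    "tools": [""],
}

# preprocess_feedback_dataset(examples, template, tokenizer, processor=None, data_args=data_args)
# would return input_ids/labels plus mismatched kl_input_ids/kl_labels (responses reversed across the
# batch to estimate the KL term) and one boolean kto_tags entry per example.
```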
LLaMA-Factory/src/llamafactory/data/processors/pairwise.py
0 → 100644
View file @ 032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer, ProcessorMixin

    from ...hparams import DataArguments
    from ..template import Template


logger = get_logger(__name__)
def _encode_pairwise_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    tools: Optional[str],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int]]:
    if processor is not None and not hasattr(processor, "image_seq_length"):  # llava-like models
        prompt[0]["content"] = template.image_token + prompt[0]["content"]

    chosen_messages = prompt + [response[0]]
    rejected_messages = prompt + [response[1]]
    prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, chosen_messages, system, tools)
    _, rejected_ids = template.encode_oneturn(tokenizer, rejected_messages, system, tools)

    if template.efficient_eos:
        chosen_ids += [tokenizer.eos_token_id]
        rejected_ids += [tokenizer.eos_token_id]

    if processor is not None and hasattr(processor, "image_seq_length"):  # paligemma models
        image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
        prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids

    # consider the response is more important
    source_len, target_len = infer_seqlen(len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), cutoff_len)
    prompt_ids = prompt_ids[:source_len]
    chosen_ids = chosen_ids[:target_len]
    rejected_ids = rejected_ids[:target_len]

    chosen_input_ids = prompt_ids + chosen_ids
    chosen_labels = [IGNORE_INDEX] * source_len + chosen_ids
    rejected_input_ids = prompt_ids + rejected_ids
    rejected_labels = [IGNORE_INDEX] * source_len + rejected_ids
    return chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels
def preprocess_pairwise_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
    model_inputs = {
        "chosen_input_ids": [],
        "chosen_attention_mask": [],
        "chosen_labels": [],
        "rejected_input_ids": [],
        "rejected_attention_mask": [],
        "rejected_labels": [],
    }
    if processor is not None:
        model_inputs["pixel_values"] = []
        if hasattr(processor, "image_seq_length"):  # paligemma models
            model_inputs["chosen_token_type_ids"] = []
            model_inputs["rejected_token_type_ids"] = []

    for i in range(len(examples["prompt"])):
        if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
            logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
            continue

        chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
            prompt=examples["prompt"][i],
            response=examples["response"][i],
            system=examples["system"][i],
            tools=examples["tools"][i],
            template=template,
            tokenizer=tokenizer,
            processor=processor,
            cutoff_len=data_args.cutoff_len,
        )
        model_inputs["chosen_input_ids"].append(chosen_input_ids)
        model_inputs["chosen_attention_mask"].append([1] * len(chosen_input_ids))
        model_inputs["chosen_labels"].append(chosen_labels)
        model_inputs["rejected_input_ids"].append(rejected_input_ids)
        model_inputs["rejected_attention_mask"].append([1] * len(rejected_input_ids))
        model_inputs["rejected_labels"].append(rejected_labels)
        if processor is not None:
            model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
            if hasattr(processor, "image_seq_length"):  # paligemma models
                model_inputs["chosen_token_type_ids"].append(
                    get_paligemma_token_type_ids(len(chosen_input_ids), processor)
                )
                model_inputs["rejected_token_type_ids"].append(
                    get_paligemma_token_type_ids(len(rejected_input_ids), processor)
                )

    return model_inputs
def print_pairwise_dataset_example(example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer") -> None:
    valid_chosen_labels = list(filter(lambda x: x != IGNORE_INDEX, example["chosen_labels"]))
    valid_rejected_labels = list(filter(lambda x: x != IGNORE_INDEX, example["rejected_labels"]))
    print("chosen_input_ids:\n{}".format(example["chosen_input_ids"]))
    print("chosen_inputs:\n{}".format(tokenizer.decode(example["chosen_input_ids"], skip_special_tokens=False)))
    print("chosen_label_ids:\n{}".format(example["chosen_labels"]))
    print("chosen_labels:\n{}".format(tokenizer.decode(valid_chosen_labels, skip_special_tokens=False)))
    print("rejected_input_ids:\n{}".format(example["rejected_input_ids"]))
    print("rejected_inputs:\n{}".format(tokenizer.decode(example["rejected_input_ids"], skip_special_tokens=False)))
    print("rejected_label_ids:\n{}".format(example["rejected_labels"]))
    print("rejected_labels:\n{}".format(tokenizer.decode(valid_rejected_labels, skip_special_tokens=False)))
LLaMA-Factory/src/llamafactory/data/processors/pretrain.py
0 → 100644
View file @ 032b90a1
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from ...hparams import DataArguments
def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[List[int]]]:
    # build grouped texts with format `X1 X2 X3 ...` if packing is enabled
    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
    text_examples = [messages[0]["content"] + eos_token for messages in examples["prompt"]]

    if not data_args.packing:
        if data_args.template == "gemma":
            text_examples = [tokenizer.bos_token + example for example in text_examples]

        result = tokenizer(text_examples, add_special_tokens=False, max_length=data_args.cutoff_len, truncation=True)
    else:
        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
        block_size = data_args.cutoff_len
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        if data_args.template == "gemma":
            for i in range(len(result["input_ids"])):
                result["input_ids"][i][0] = tokenizer.bos_token_id

    return result
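Usage note (not part of this commit): a toy, self-contained illustration of the packing branch above. The token ids stand in for three tokenized documents, and `block_size` plays the role of `cutoff_len`.

```python
from itertools import chain

# Toy token ids standing in for three tokenized documents.
tokenized = {"input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
block_size = 4

concatenated = {k: list(chain(*v)) for k, v in tokenized.items()}            # 9 tokens end to end
total_length = (len(concatenated["input_ids"]) // block_size) * block_size   # 8: the trailing remainder is dropped
packed = {
    k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
    for k, t in concatenated.items()
}
print(packed["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```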