OpenDAS / LLaMA-Factory

Commit 7ea81099, "update llama4"
Authored Apr 07, 2025 by chenych
Parent commit: 84987715
Changes: 139 files in the commit; this page shows 20 changed files with 404 additions and 125 deletions (+404 / -125).
Changed files shown on this page:

  scripts/convert_ckpt/llamafy_qwen.py    +7    -7
  scripts/convert_ckpt/tiny_llama4.py     +39   -0    (new file)
  scripts/eval_bleu_rouge.py              +78   -0    (new file)
  scripts/llama_pro.py                    +5    -5
  scripts/loftq_init.py                   +3    -3
  scripts/pissa_init.py                   +3    -3
  scripts/qwen_omni_merge.py              +132  -0    (new file)
  scripts/stat_utils/cal_flops.py         +2    -2
  scripts/stat_utils/cal_lr.py            +4    -5
  scripts/stat_utils/cal_mfu.py           +4    -8
  scripts/stat_utils/cal_ppl.py           +10   -14
  scripts/stat_utils/length_cdf.py        +2    -2
  scripts/vllm_infer.py                   +17   -5
  setup.py                                +9    -8
  src/llamafactory/__init__.py            +5    -6
  src/llamafactory/api/app.py             +1    -3
  src/llamafactory/api/chat.py            +43   -9
  src/llamafactory/api/common.py          +2    -2
  src/llamafactory/api/protocol.py        +19   -16
  src/llamafactory/chat/base_engine.py    +19   -27
scripts/convert_ckpt/llamafy_qwen.py  (view file @ 7ea81099)

@@ -15,7 +15,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import Any, Dict
+from typing import Any

 import fire
 import torch

@@ -37,14 +37,14 @@ CONFIG_NAME = "config.json"
 def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
-    qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    qwen_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
         if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
             with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
                 for key in f.keys():
                     qwen_state_dict[key] = f.get_tensor(key)

-    llama_state_dict: Dict[str, torch.Tensor] = OrderedDict()
+    llama_state_dict: dict[str, torch.Tensor] = OrderedDict()
     torch_dtype = None
     for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
         if torch_dtype is None:

@@ -112,9 +112,9 @@ def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetenso
 def save_config(input_dir: str, output_dir: str, torch_dtype: str):
     with open(os.path.join(input_dir, CONFIG_NAME), encoding="utf-8") as f:
-        qwen_config_dict: Dict[str, Any] = json.load(f)
+        qwen_config_dict: dict[str, Any] = json.load(f)

-    llama2_config_dict: Dict[str, Any] = OrderedDict()
+    llama2_config_dict: dict[str, Any] = OrderedDict()
     llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
     llama2_config_dict["hidden_act"] = "silu"
     llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]

@@ -147,8 +147,8 @@ def llamafy_qwen(
     shard_size: str = "2GB",
     save_safetensors: bool = False,
 ):
-    r"""
-    Converts the Qwen models in the same format as LLaMA2.
+    r"""Convert the Qwen models in the same format as LLaMA2.
+
     Usage: python llamafy_qwen.py --input_dir input --output_dir output
     Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
     """
scripts/convert_ckpt/tiny_llama4.py  (new file, mode 0 → 100644; view file @ 7ea81099)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import Llama4Config, Llama4ForConditionalGeneration, Llama4TextConfig, Llama4VisionConfig


if __name__ == "__main__":
    vision_config = Llama4VisionConfig(
        hidden_size=1408,
        image_size=336,
        intermediate_size=5632,
        num_attention_heads=16,
        num_hidden_layers=4,
        vision_output_dim=4096,
    )
    text_config = Llama4TextConfig(
        hidden_size=512,
        intermediate_size=1024,
        intermediate_size_mlp=1024,
        num_hidden_layers=4,
        num_attention_heads=8,
        num_key_value_heads=2,
        head_dim=512 // 8,
        num_local_experts=2,
    )
    config = Llama4Config(vision_config=vision_config, text_config=text_config)
    model = Llama4ForConditionalGeneration._from_config(config)
    model.save_pretrained("tiny-llama4")
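
The script above only materializes a randomly initialized, deliberately small Llama 4 checkpoint; it writes the config and weights but no tokenizer or processor, which makes it most useful as a fixture for smoke tests. A minimal sketch of loading it back for such a test (an assumption about typical use, not part of the commit):

from transformers import Llama4ForConditionalGeneration

# Reload the tiny randomly initialized model written by the script above.
model = Llama4ForConditionalGeneration.from_pretrained("tiny-llama4")
print(f"parameter count: {sum(p.numel() for p in model.parameters()):,}")  # sanity-check that it is tiny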
scripts/eval_bleu_rouge.py  (new file, mode 0 → 100644; view file @ 7ea81099)

# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import time

import fire
from datasets import load_dataset


try:
    import jieba
    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
    from rouge_chinese import Rouge

    jieba.setLogLevel(logging.CRITICAL)
    jieba.initialize()
except ImportError:
    print("Please install llamafactory with `pip install -e .[metrics]`.")
    raise


def compute_metrics(sample):
    hypothesis = list(jieba.cut(sample["predict"]))
    reference = list(jieba.cut(sample["label"]))

    bleu_score = sentence_bleu(
        [list(sample["label"])],
        list(sample["predict"]),
        smoothing_function=SmoothingFunction().method3,
    )

    if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
        result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
    else:
        rouge = Rouge()
        scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
        result = scores[0]

    metric_result = {}
    for k, v in result.items():
        metric_result[k] = round(v["f"] * 100, 4)

    metric_result["bleu-4"] = round(bleu_score * 100, 4)

    return metric_result


def main(filename: str):
    start_time = time.time()
    dataset = load_dataset("json", data_files=filename, split="train")
    dataset = dataset.map(compute_metrics, num_proc=8, remove_columns=dataset.column_names)
    score_dict = dataset.to_dict()

    average_score = {}
    for task, scores in sorted(score_dict.items(), key=lambda x: x[0]):
        print(f"{task}: {sum(scores) / len(scores):.4f}")
        average_score[task] = sum(scores) / len(scores)

    with open("predictions_score.json", "w", encoding="utf-8") as f:
        json.dump(average_score, f, indent=4)

    print(f"\nDone in {time.time() - start_time:.3f}s.\nScore file saved to predictions_score.json")


if __name__ == "__main__":
    fire.Fire(main)
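
The new script expects a JSON lines file in which every record carries a "predict" and a "label" field, as read by compute_metrics above; ROUGE is computed on jieba-segmented text, so it handles Chinese as well as whitespace-delimited languages. A minimal sketch of preparing and scoring such a file (the file name is a placeholder, not something the commit prescribes):

import json

# Hypothetical predictions file; each line must provide "predict" and "label".
rows = [
    {"predict": "巴黎是法国的首都。", "label": "法国的首都是巴黎。"},
    {"predict": "Paris is the capital of France.", "label": "The capital of France is Paris."},
]
with open("generated_predictions.jsonl", "w", encoding="utf-8") as f:
    for row in rows:
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

# Score it with the new script, which writes the averages to predictions_score.json:
#   python scripts/eval_bleu_rouge.py generated_predictions.jsonl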
scripts/llama_pro.py  (view file @ 7ea81099)

@@ -18,7 +18,7 @@
 import json
 import os
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING

 import fire
 import torch

@@ -44,11 +44,11 @@ def block_expansion(
     shard_size: str = "5GB",
     save_safetensors: bool = True,
 ):
-    r"""
-    Performs block expansion for LLaMA, Mistral, Qwen2 or Yi models.
+    r"""Perform block expansion for LLaMA, Mistral, Qwen2 or Yi models.
+
     Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
     """
-    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
+    config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
     num_layers = getattr(config, "num_hidden_layers")
     if num_layers % num_expand != 0:
         raise ValueError(f"`num_layers` {num_layers} should be divisible by `num_expand` {num_expand}.")

@@ -70,7 +70,7 @@ def block_expansion(
     split = num_layers // num_expand
     layer_cnt = 0
     state_dict = model.state_dict()
-    output_state_dict: Dict[str, "torch.Tensor"] = OrderedDict()
+    output_state_dict: dict[str, torch.Tensor] = OrderedDict()
     for i in range(num_layers):
         for key, value in state_dict.items():
             if f".{i:d}." in key:
scripts/loftq_init.py  (view file @ 7ea81099)

@@ -38,8 +38,8 @@ def quantize_loftq(
     lora_target: tuple = ("q_proj", "v_proj"),
     save_safetensors: bool = True,
 ):
-    r"""
-    Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ)
+    r"""Initialize LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ).
+
     Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
     """
     if isinstance(lora_target, str):

@@ -72,7 +72,7 @@ def quantize_loftq(
     print(f"Adapter weights saved in {loftq_dir}")

     # Save base model
-    base_model: "PreTrainedModel" = peft_model.unload()
+    base_model: PreTrainedModel = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
     print(f"Model weights saved in {output_dir}")
scripts/pissa_init.py  (view file @ 7ea81099)

@@ -37,8 +37,8 @@ def quantize_pissa(
     lora_target: tuple = ("q_proj", "v_proj"),
     save_safetensors: bool = True,
 ):
-    r"""
-    Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA)
+    r"""Initialize LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA).
+
     Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
     """
     if isinstance(lora_target, str):

@@ -67,7 +67,7 @@ def quantize_pissa(
     print(f"Adapter weights saved in {pissa_dir}")

     # Save base model
-    base_model: "PreTrainedModel" = peft_model.unload()
+    base_model: PreTrainedModel = peft_model.unload()
     base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
     tokenizer.save_pretrained(output_dir)
     print(f"Model weights saved in {output_dir}")
scripts/qwen_omni_merge.py  (new file, mode 0 → 100644; view file @ 7ea81099)

# Copyright 2025 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil

import fire
from peft import PeftModel
from transformers import AutoModel, AutoProcessor, AutoTokenizer, Qwen2_5OmniThinkerForConditionalGeneration


def merge_lora(
    base_model_path: str,
    lora_checkpoint_path: str,
    extra_file: str = "spk_dict.pt",
    submodule_name: str = "thinker",
    save_path: str = "./merged_model_checkpoint",
):
    """Load the original model, tokenizer, and processor configuration, merge the LoRA weights.

    For a specified submodule, and save the final merged model along with its configurations.
    Args:
        base_model_path (str): Path to the original model directory.
        lora_checkpoint_path (str): Path to the directory containing LoRA weights.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
        submodule_name (str): Name of the submodule to merge (default: "thinker").
        save_path (str): Directory where the merged model and configurations will be saved.
    """
    # 1. Load the original model, tokenizer, and processor
    model = AutoModel.from_pretrained(base_model_path)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    try:
        processor = AutoProcessor.from_pretrained(base_model_path)
    except Exception:
        print("Processor configuration not found, skipping processor load.")
        processor = None

    print("Successfully loaded the original model, tokenizer, and processor (if available).")

    # 2. Extract the submodule to be merged (e.g., model.thinker)
    if not hasattr(model, submodule_name):
        raise AttributeError(f"The model does not have a submodule named '{submodule_name}'.")

    base_submodule = getattr(model, submodule_name)
    print(f"Successfully extracted submodule: {submodule_name}.")

    # 3. Load the LoRA weights onto the extracted submodule
    lora_model = PeftModel.from_pretrained(base_submodule, lora_checkpoint_path)
    print("LoRA weights loaded successfully.")

    # 4. Merge the LoRA weights into the submodule and unload the LoRA modules
    merged_submodule = lora_model.merge_and_unload()
    print("LoRA weights merged successfully.")

    # 5. Replace the original submodule with the merged submodule in the model
    setattr(model, submodule_name, merged_submodule)

    # 6. Save the final merged model along with the tokenizer and processor configuration
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    if processor is not None:
        processor.save_pretrained(save_path)

    print(f"Merged model and configuration saved to {save_path}.")

    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")


def save_full_model(
    saved_thinker_path: str,
    base_model_path: str,
    save_path: str,
    extra_file: str = "spk_dict.pt",
):
    """Load the saved thinker module and the original model, replace the thinker in the original model.

    Then save the complete model along with its tokenizer and processor configuration.
    Args:
        saved_thinker_path (str): Path to the saved thinker weights.
        base_model_path (str): Directory path of the original model.
        save_path (str): Directory where the final complete model will be saved.
        extra_file (str): Name of the extra file to be copied (default: "spk_dict.pt").
    """
    # Load the thinker module
    thinker = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(saved_thinker_path, device_map="cpu")
    # Load the original model
    base_model = AutoModel.from_pretrained(base_model_path, device_map="cpu")
    # Replace the thinker module in the original model
    base_model.thinker = thinker

    # Load the processor and tokenizer
    processor = AutoProcessor.from_pretrained(base_model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)

    # Save the complete model along with its configurations
    base_model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    processor.save_pretrained(save_path)
    print(f"Complete model, tokenizer, and processor configuration have been saved to {save_path}.")

    source_file = os.path.join(base_model_path, extra_file)
    target_file = os.path.join(save_path, extra_file)
    if os.path.exists(source_file):
        shutil.copy(source_file, target_file)
        print(f"File '{extra_file}' copied from {base_model_path} to {save_path}.")
    else:
        print(f"File '{extra_file}' not found in {base_model_path}, skipping copy.")


if __name__ == "__main__":
    fire.Fire({"save_full": save_full_model, "merge_lora": merge_lora})
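
Both functions are exposed through fire.Fire as the subcommands merge_lora and save_full. A minimal sketch of driving merge_lora directly from Python; the model and adapter paths are placeholders, and the equivalent CLI call is shown in the comments:

# Equivalent CLI (paths are illustrative):
#   python scripts/qwen_omni_merge.py merge_lora \
#       --base_model_path Qwen/Qwen2.5-Omni-7B \
#       --lora_checkpoint_path saves/qwen2_5_omni/lora/sft \
#       --save_path ./merged_model_checkpoint
from qwen_omni_merge import merge_lora

merge_lora(
    base_model_path="Qwen/Qwen2.5-Omni-7B",              # original Omni checkpoint (placeholder)
    lora_checkpoint_path="saves/qwen2_5_omni/lora/sft",  # adapter trained on the thinker submodule (placeholder)
    save_path="./merged_model_checkpoint",
)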
scripts/stat_utils/cal_flops.py  (view file @ 7ea81099)

@@ -29,8 +29,8 @@ def calculate_flops(
     seq_length: int = 512,
     flash_attn: str = "auto",
 ):
-    r"""
-    Calculates the flops of pre-trained models.
+    r"""Calculate the flops of pre-trained models.
+
     Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
     """
     with get_accelerator().device(0):
scripts/stat_utils/cal_lr.py  (view file @ 7ea81099)

@@ -45,8 +45,8 @@ def calculate_lr(
     is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate,
     packing: bool = False,
 ):
-    r"""
-    Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
+    r"""Calculate the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
+
     Usage:
     python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
     """

@@ -89,9 +89,8 @@ def calculate_lr(
     lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)  # lr ~ sqrt(batch_size)
     lr = lr / 6.0 if is_mistral_or_gemma else lr
     print(
-        "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective token batch size {:.2f}".format(
-            lr, valid_ratio * 100, token_batch_size
-        )
+        f"Optimal learning rate is {lr:.2e} for valid ratio% {valid_ratio * 100:.2f} "
+        f"and effective token batch size {token_batch_size:.2f}"
     )
scripts/stat_utils/cal_mfu.py  (view file @ 7ea81099)

@@ -34,9 +34,7 @@ def compute_model_flops(
     include_recompute: bool = False,
     include_flashattn: bool = False,
 ) -> int:
-    r"""
-    Calculates the FLOPs of model per forward/backward pass.
-    """
+    r"""Calculate the FLOPs of model per forward/backward pass."""
     config = AutoConfig.from_pretrained(model_name_or_path)
     hidden_size = getattr(config, "hidden_size", None)
     vocab_size = getattr(config, "vocab_size", None)

@@ -86,9 +84,7 @@ def compute_model_flops(
 def compute_device_flops(world_size: int) -> float:
-    r"""
-    Calculates the FLOPs of the device capability per second.
-    """
+    r"""Calculate the FLOPs of the device capability per second."""
     device_name = torch.cuda.get_device_name()
     if "H100" in device_name or "H800" in device_name:
         return 989 * 1e12 * world_size

@@ -114,8 +110,8 @@ def calculate_mfu(
     liger_kernel: bool = False,
     unsloth_gc: bool = False,
 ) -> float:
-    r"""
-    Calculates MFU for given model and hyper-params.
+    r"""Calculate MFU for given model and hyper-params.
+
     Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
     """
     args = {
scripts/stat_utils/cal_ppl.py  (view file @ 7ea81099)

@@ -14,7 +14,7 @@
 import json
 from dataclasses import dataclass
-from typing import Any, Dict, Literal, Optional, Sequence
+from typing import Any, Literal, Optional

 import fire
 import torch

@@ -30,16 +30,12 @@ from llamafactory.model import load_model, load_tokenizer
 @dataclass
 class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
-    r"""
-    Data collator for pairwise data.
-    """
+    r"""Data collator for pairwise data."""

     train_on_prompt: bool = False

-    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
-        r"""
-        Pads batched data to the longest sequence in the batch.
-        """
+    def __call__(self, features: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
+        r"""Pad batched data to the longest sequence in the batch."""
         chosen_features = []
         for feature in features:
             chosen_features.append(

@@ -68,8 +64,8 @@ def calculate_ppl(
     max_samples: Optional[int] = None,
     train_on_prompt: bool = False,
 ):
-    r"""
-    Calculates the ppl on the dataset of the pre-trained models.
+    r"""Calculate the ppl on the dataset of the pre-trained models.
+
     Usage: export CUDA_VISIBLE_DEVICES=0
     python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
     """

@@ -111,17 +107,17 @@ def calculate_ppl(
     criterion = torch.nn.CrossEntropyLoss(reduction="none")
     total_ppl = 0
     perplexities = []
-    batch: Dict[str, "torch.Tensor"]
+    batch: dict[str, torch.Tensor]
     with torch.no_grad():
         for batch in tqdm(dataloader, desc="Computing perplexities"):
             batch = batch.to(model.device)
             outputs = model(**batch)
-            shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
-            shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
+            shift_logits: torch.Tensor = outputs["logits"][..., :-1, :]
+            shift_labels: torch.Tensor = batch["labels"][..., 1:]
             loss_mask = shift_labels != IGNORE_INDEX
             flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
             flatten_labels = shift_labels.contiguous().view(-1)
-            token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
+            token_logps: torch.Tensor = criterion(flatten_logits, flatten_labels)
             token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
             sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
             total_ppl += sentence_logps.exp().sum().item()
scripts/stat_utils/length_cdf.py  (view file @ 7ea81099)

@@ -29,8 +29,8 @@ def length_cdf(
     template: str = "default",
     interval: int = 1000,
 ):
-    r"""
-    Calculates the distribution of the input lengths in the dataset.
+    r"""Calculate the distribution of the input lengths in the dataset.
+
     Usage: export CUDA_VISIBLE_DEVICES=0
     python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
     """
scripts/vllm_infer.py  (view file @ 7ea81099)

@@ -52,11 +52,11 @@ def vllm_infer(
     image_max_pixels: int = 768 * 768,
     image_min_pixels: int = 32 * 32,
 ):
-    r"""
-    Performs batch generation using vLLM engine, which supports tensor parallelism.
+    r"""Perform batch generation using vLLM engine, which supports tensor parallelism.
+
     Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
     """
-    check_version("vllm>=0.4.3,<=0.7.3")
+    check_version("vllm>=0.4.3,<=0.8.2")
     if pipeline_parallel_size > get_device_count():
         raise ValueError("Pipeline parallel size should be smaller than the number of gpus.")

@@ -92,8 +92,20 @@ def vllm_infer(
             multi_modal_data = {
                 "image": template_obj.mm_plugin._regularize_images(
                     sample["images"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
-                )
+                )["images"]
             }
+        elif sample["videos"]:
+            multi_modal_data = {
+                "video": template_obj.mm_plugin._regularize_videos(
+                    sample["videos"], image_max_pixels=image_max_pixels, image_min_pixels=image_min_pixels
+                )["videos"]
+            }
+        elif sample["audios"]:
+            audio_data = template_obj.mm_plugin._regularize_audios(
+                sample["audios"],
+                sampling_rate=16000,
+            )
+            multi_modal_data = {"audio": zip(audio_data["audios"], audio_data["sampling_rates"])}
         else:
             multi_modal_data = None

@@ -131,7 +143,7 @@ def vllm_infer(
         "enable_lora": model_args.adapter_name_or_path is not None,
     }
     if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
-        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}
+        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2}
     if isinstance(model_args.vllm_config, dict):
         engine_args.update(model_args.vllm_config)
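
With this change vllm_infer.py routes the dataset's images, videos, or audios field into vLLM's multi_modal_data (audio is resampled to 16 kHz by _regularize_audios), and the engine is allowed up to 4 images, 2 videos, and 2 audios per prompt. A sketch of a dataset record that would exercise the new audio path; the field layout follows LLaMA-Factory's multimodal sharegpt convention, and the file name is a placeholder:

# Hypothetical record in a dataset registered in data/dataset_info.json.
sample = {
    "messages": [
        {"role": "user", "content": "<audio>Transcribe this clip."},
        {"role": "assistant", "content": ""},
    ],
    "audios": ["data/clips/example.wav"],  # consumed by mm_plugin._regularize_audios at sampling_rate=16000
}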
setup.py  (view file @ 7ea81099)

@@ -14,7 +14,6 @@
 import os
 import re
-from typing import List

 from setuptools import find_packages, setup

@@ -27,14 +26,14 @@ def get_version() -> str:
     return version


-def get_requires() -> List[str]:
+def get_requires() -> list[str]:
     with open("requirements.txt", encoding="utf-8") as f:
         file_content = f.read()
         lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
         return lines


-def get_console_scripts() -> List[str]:
+def get_console_scripts() -> list[str]:
     console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
     if os.getenv("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "y", "1"]:
         console_scripts.append("lmf = llamafactory.cli:main")

@@ -47,14 +46,15 @@ extra_require = {
     "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
     "deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
-    "liger-kernel": ["liger-kernel"],
+    "liger-kernel": ["liger-kernel>=0.5.5"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],
     "eetq": ["eetq"],
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<=0.7.3"],
+    "vllm": ["vllm>=0.4.3,<=0.8.2"],
+    "sglang": ["sglang[srt]>=0.4.4", "transformers==4.48.3"],
     "galore": ["galore-torch"],
     "apollo": ["apollo-torch"],
     "badam": ["badam>=1.2.1"],

@@ -69,6 +69,7 @@ extra_require = {
         "msgpack",
         "referencing",
         "jsonschema_specifications",
+        "transformers==4.48.3",
     ],
     "modelscope": ["modelscope"],
     "openmind": ["openmind"],

@@ -82,11 +83,11 @@ def main():
         name="llamafactory",
         version=get_version(),
         author="hiyouga",
-        author_email="hiyouga AT buaa.edu.cn",
-        description="Easy-to-use LLM fine-tuning framework",
+        author_email="hiyouga@buaa.edu.cn",
+        description="Unified Efficient Fine-Tuning of 100+ LLMs",
         long_description=open("README.md", encoding="utf-8").read(),
         long_description_content_type="text/markdown",
-        keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
+        keywords=["AI", "LLM", "GPT", "ChatGPT", "Llama", "Transformer", "DeepSeek", "Pytorch"],
         license="Apache 2.0 License",
         url="https://github.com/hiyouga/LLaMA-Factory",
         package_dir={"": "src"},
src/llamafactory/__init__.py  (view file @ 7ea81099)

@@ -12,18 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-r"""
-Efficient fine-tuning of large language models.
+r"""Efficient fine-tuning of large language models.

 Level:
   api, webui > chat, eval, train > data, model > hparams > extras

 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.49.0,!=4.46.*,!=4.47.*,!=4.48.0
-    datasets>=2.16.0,<=3.2.0
-    accelerate>=0.34.0,<=1.2.1
-    peft>=0.11.1,<=0.12.0
+    transformers>=4.41.2,<=4.51.0,!=4.46.*,!=4.47.*,!=4.48.0
+    datasets>=2.16.0,<=3.4.1
+    accelerate>=0.34.0,<=1.5.2
+    peft>=0.14.0,<=0.15.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
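
The docstring above doubles as the compatibility matrix for this commit: Llama 4 needs transformers 4.51, which in turn pulls the datasets, accelerate, and peft ceilings upward. A sketch of a pinned environment consistent with the new ranges (the exact versions are picked inside the stated bounds and are illustrative, not mandated by the commit):

transformers==4.51.0
datasets==3.4.1
accelerate==1.5.2
peft==0.15.0
trl==0.9.6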
src/llamafactory/api/app.py  (view file @ 7ea81099)

@@ -16,9 +16,7 @@ import asyncio
 import os
 from contextlib import asynccontextmanager
 from functools import partial
-from typing import Optional
-
-from typing_extensions import Annotated
+from typing import Annotated, Optional

 from ..chat import ChatModel
 from ..extras.constants import EngineName
src/llamafactory/api/chat.py  (view file @ 7ea81099)

@@ -18,11 +18,12 @@ import json
 import os
 import re
 import uuid
-from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING, Optional

 from ..data import Role as DataRole
 from ..extras import logging
-from ..extras.constants import IMAGE_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.misc import is_env_enabled
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify

@@ -55,7 +56,7 @@ if is_requests_available():
 if TYPE_CHECKING:
     from ..chat import ChatModel
-    from ..data.mm_plugin import ImageInput
+    from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest

@@ -71,7 +72,14 @@ ROLE_MAPPING = {
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
+) -> tuple[
+    list[dict[str, str]],
+    Optional[str],
+    Optional[str],
+    Optional[list["ImageInput"]],
+    Optional[list["VideoInput"]],
+    Optional[list["AudioInput"]],
+]:
     if is_env_enabled("API_VERBOSE", "1"):
         logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")

@@ -87,7 +95,7 @@ def _process_request(
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

     input_messages = []
-    images = []
+    images, videos, audios = [], [], []
     for i, message in enumerate(request.messages):
         if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
             raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

@@ -106,7 +114,7 @@ def _process_request(
             for input_item in message.content:
                 if input_item.type == "text":
                     text_content += input_item.text
-                else:
+                elif input_item.type == "image_url":
                     text_content += IMAGE_PLACEHOLDER
                     image_url = input_item.image_url.url
                     if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image

@@ -117,6 +125,28 @@ def _process_request(
                         image_stream = requests.get(image_url, stream=True).raw

                     images.append(Image.open(image_stream).convert("RGB"))
+                elif input_item.type == "video_url":
+                    text_content += VIDEO_PLACEHOLDER
+                    video_url = input_item.video_url.url
+                    if os.path.isfile(video_url):  # local file
+                        video_stream = open(video_url, "rb")
+                    else:  # web uri
+                        video_stream = requests.get(video_url, stream=True).raw
+
+                    videos.append(video_stream)
+                elif input_item.type == "audio_url":
+                    text_content += AUDIO_PLACEHOLDER
+                    audio_url = input_item.audio_url.url
+                    if os.path.isfile(audio_url):  # local file
+                        audio_stream = open(audio_url, "rb")
+                    else:  # web uri
+                        audio_stream = requests.get(audio_url, stream=True).raw
+
+                    audios.append(audio_stream)
+                else:
+                    raise HTTPException(
+                        status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid input type {input_item.type}."
+                    )

             input_messages.append({"role": ROLE_MAPPING[message.role], "content": text_content})
         else:

@@ -131,7 +161,7 @@ def _process_request(
     else:
         tools = None

-    return input_messages, system, tools, images or None
+    return input_messages, system, tools, images or None, videos or None, audios or None


 def _create_stream_chat_completion_chunk(

@@ -150,12 +180,14 @@ async def create_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> "ChatCompletionResponse":
     completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-    input_messages, system, tools, images = _process_request(request)
+    input_messages, system, tools, images, videos, audios = _process_request(request)
     responses = await chat_model.achat(
         input_messages,
         system,
         tools,
         images,
+        videos,
+        audios,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,

@@ -201,7 +233,7 @@ async def create_stream_chat_completion_response(
     request: "ChatCompletionRequest", chat_model: "ChatModel"
 ) -> AsyncGenerator[str, None]:
     completion_id = f"chatcmpl-{uuid.uuid4().hex}"
-    input_messages, system, tools, images = _process_request(request)
+    input_messages, system, tools, images, videos, audios = _process_request(request)
     if tools:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")

@@ -216,6 +248,8 @@ async def create_stream_chat_completion_response(
         system,
         tools,
         images,
+        videos,
+        audios,
         do_sample=request.do_sample,
         temperature=request.temperature,
         top_p=request.top_p,
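
With these changes the OpenAI-compatible /v1/chat/completions endpoint accepts image_url, video_url, and audio_url content parts, mirroring the new MultimodalInputItem fields in protocol.py below. A minimal sketch of a request body that exercises the audio path (the model name and file path are placeholders):

import json

payload = {
    "model": "qwen2_5_omni",  # placeholder name of the model served by `llamafactory-cli api`
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is said in this clip?"},
                {"type": "audio_url", "audio_url": {"url": "/path/to/clip.wav"}},
            ],
        }
    ],
}
print(json.dumps(payload, indent=2))  # POST this body to /v1/chat/completions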
src/llamafactory/api/common.py  (view file @ 7ea81099)

@@ -13,14 +13,14 @@
 # limitations under the License.

 import json
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any


 if TYPE_CHECKING:
     from pydantic import BaseModel


-def dictify(data: "BaseModel") -> Dict[str, Any]:
+def dictify(data: "BaseModel") -> dict[str, Any]:
     try:  # pydantic v2
         return data.model_dump(exclude_unset=True)
     except AttributeError:  # pydantic v1
src/llamafactory/api/protocol.py  (view file @ 7ea81099)

@@ -14,7 +14,7 @@
 import time
 from enum import Enum, unique
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union

 from pydantic import BaseModel, Field
 from typing_extensions import Literal

@@ -45,7 +45,7 @@ class ModelCard(BaseModel):
 class ModelList(BaseModel):
     object: Literal["list"] = "list"
-    data: List[ModelCard] = []
+    data: list[ModelCard] = []


 class Function(BaseModel):

@@ -56,7 +56,7 @@ class Function(BaseModel):
 class FunctionDefinition(BaseModel):
     name: str
     description: str
-    parameters: Dict[str, Any]
+    parameters: dict[str, Any]


 class FunctionAvailable(BaseModel):

@@ -70,38 +70,41 @@ class FunctionCall(BaseModel):
     function: Function


-class ImageURL(BaseModel):
+class URL(BaseModel):
     url: str
+    detail: Literal["auto", "low", "high"] = "auto"


 class MultimodalInputItem(BaseModel):
-    type: Literal["text", "image_url"]
+    type: Literal["text", "image_url", "video_url", "audio_url"]
     text: Optional[str] = None
-    image_url: Optional[ImageURL] = None
+    image_url: Optional[URL] = None
+    video_url: Optional[URL] = None
+    audio_url: Optional[URL] = None


 class ChatMessage(BaseModel):
     role: Role
-    content: Optional[Union[str, List[MultimodalInputItem]]] = None
-    tool_calls: Optional[List[FunctionCall]] = None
+    content: Optional[Union[str, list[MultimodalInputItem]]] = None
+    tool_calls: Optional[list[FunctionCall]] = None


 class ChatCompletionMessage(BaseModel):
     role: Optional[Role] = None
     content: Optional[str] = None
-    tool_calls: Optional[List[FunctionCall]] = None
+    tool_calls: Optional[list[FunctionCall]] = None


 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[ChatMessage]
-    tools: Optional[List[FunctionAvailable]] = None
+    messages: list[ChatMessage]
+    tools: Optional[list[FunctionAvailable]] = None
     do_sample: Optional[bool] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None
     n: int = 1
     max_tokens: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
+    stop: Optional[Union[str, list[str]]] = None
     stream: bool = False

@@ -128,7 +131,7 @@ class ChatCompletionResponse(BaseModel):
     object: Literal["chat.completion"] = "chat.completion"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionResponseChoice]
+    choices: list[ChatCompletionResponseChoice]
     usage: ChatCompletionResponseUsage

@@ -137,12 +140,12 @@ class ChatCompletionStreamResponse(BaseModel):
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
-    choices: List[ChatCompletionStreamResponseChoice]
+    choices: list[ChatCompletionStreamResponseChoice]


 class ScoreEvaluationRequest(BaseModel):
     model: str
-    messages: List[str]
+    messages: list[str]
     max_length: Optional[int] = None

@@ -150,4 +153,4 @@ class ScoreEvaluationResponse(BaseModel):
     id: str
     object: Literal["score.evaluation"] = "score.evaluation"
     model: str
-    scores: List[float]
+    scores: list[float]
src/llamafactory/chat/base_engine.py  (view file @ 7ea81099)

@@ -13,8 +13,9 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
+from collections.abc import AsyncGenerator
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union


 if TYPE_CHECKING:

@@ -36,8 +37,7 @@ class Response:
 class BaseEngine(ABC):
-    r"""
-    Base class for inference engine of chat models.
+    r"""Base class for inference engine of chat models.
     Must implements async methods: chat(), stream_chat() and get_scores().
     """

@@ -47,7 +47,7 @@ class BaseEngine(ABC):
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool
     template: "Template"
-    generating_args: Dict[str, Any]
+    generating_args: dict[str, Any]

     @abstractmethod
     def __init__(

@@ -57,50 +57,42 @@ class BaseEngine(ABC):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
-        r"""
-        Initializes an inference engine.
-        """
+        r"""Initialize an inference engine."""
         ...

     @abstractmethod
     async def chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
-    ) -> List["Response"]:
-        r"""
-        Gets a list of responses of the chat model.
-        """
+    ) -> list["Response"]:
+        r"""Get a list of responses of the chat model."""
         ...

     @abstractmethod
     async def stream_chat(
         self,
-        messages: Sequence[Dict[str, str]],
+        messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        images: Optional[Sequence["ImageInput"]] = None,
-        videos: Optional[Sequence["VideoInput"]] = None,
-        audios: Optional[Sequence["AudioInput"]] = None,
+        images: Optional[list["ImageInput"]] = None,
+        videos: Optional[list["VideoInput"]] = None,
+        audios: Optional[list["AudioInput"]] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
-        r"""
-        Gets the response token-by-token of the chat model.
-        """
+        r"""Get the response token-by-token of the chat model."""
         ...

     @abstractmethod
     async def get_scores(
         self,
-        batch_input: List[str],
+        batch_input: list[str],
         **input_kwargs,
-    ) -> List[float]:
-        r"""
-        Gets a list of scores of the reward model.
-        """
+    ) -> list[float]:
+        r"""Get a list of scores of the reward model."""
         ...
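
For code that builds on BaseEngine, the visible effect of this file's change is limited to the annotations: Sequence and Dict give way to the builtin list and dict, and AsyncGenerator now comes from collections.abc. A sketch of a stub written against the new chat() signature (purely illustrative; HuggingfaceEngine and VllmEngine in the repository are the real implementations):

from typing import Any, Optional


class EchoEngine:  # hypothetical stand-in that only mirrors the new chat() signature
    async def chat(
        self,
        messages: list[dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[list[Any]] = None,
        videos: Optional[list[Any]] = None,
        audios: Optional[list[Any]] = None,
        **input_kwargs,
    ) -> list[str]:
        # A real engine would generate with the model here; this stub just echoes the last user turn.
        return [messages[-1]["content"]]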