ModelZoo / qwen2.5-coder_pytorch — Commits

Commit 53b3977b ("Initial commit"), authored Jul 11, 2025 by dongchy920.
Pipeline #2841 failed in 0 seconds.
Showing 20 changed files with 2257 additions and 0 deletions (+2257, -0).
LLaMA-Factory/scripts/loftq_init.py (+88, -0)
LLaMA-Factory/scripts/pissa_init.py (+86, -0)
LLaMA-Factory/scripts/stat_utils/cal_flops.py (+49, -0)
LLaMA-Factory/scripts/stat_utils/cal_lr.py (+98, -0)
LLaMA-Factory/scripts/stat_utils/cal_mfu.py (+163, -0)
LLaMA-Factory/scripts/stat_utils/cal_ppl.py (+136, -0)
LLaMA-Factory/scripts/stat_utils/length_cdf.py (+68, -0)
LLaMA-Factory/scripts/vllm_infer.py (+144, -0)
LLaMA-Factory/setup.py (+105, -0)
LLaMA-Factory/src/api.py (+33, -0)
LLaMA-Factory/src/llamafactory/__init__.py (+47, -0)
LLaMA-Factory/src/llamafactory/api/__init__.py (+0, -0)
LLaMA-Factory/src/llamafactory/api/app.py (+134, -0)
LLaMA-Factory/src/llamafactory/api/chat.py (+237, -0)
LLaMA-Factory/src/llamafactory/api/common.py (+34, -0)
LLaMA-Factory/src/llamafactory/api/protocol.py (+153, -0)
LLaMA-Factory/src/llamafactory/chat/__init__.py (+19, -0)
LLaMA-Factory/src/llamafactory/chat/base_engine.py (+102, -0)
LLaMA-Factory/src/llamafactory/chat/chat_model.py (+187, -0)
LLaMA-Factory/src/llamafactory/chat/hf_engine.py (+374, -0)
LLaMA-Factory/scripts/loftq_init.py (new file, mode 100644)

# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import TYPE_CHECKING

import fire
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer


if TYPE_CHECKING:
    from transformers import PreTrainedModel


def quantize_loftq(
    model_name_or_path: str,
    output_dir: str,
    loftq_bits: int = 4,
    loftq_iter: int = 4,
    lora_alpha: int = None,
    lora_rank: int = 16,
    lora_dropout: float = 0,
    lora_target: tuple = ("q_proj", "v_proj"),
    save_safetensors: bool = True,
):
    r"""
    Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ).
    Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
    """
    if isinstance(lora_target, str):
        lora_target = [name.strip() for name in lora_target.split(",")]

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")

    loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=True,
        r=lora_rank,
        lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
        lora_dropout=lora_dropout,
        target_modules=lora_target,
        init_lora_weights="loftq",
        loftq_config=loftq_config,
    )

    # Init LoftQ model
    print("Initializing LoftQ weights, it may take several minutes, wait patiently.")
    peft_model = get_peft_model(model, lora_config)
    loftq_dir = os.path.join(output_dir, "loftq_init")

    # Save LoftQ model
    setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
    setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply loftq again
    peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
    print(f"Adapter weights saved in {loftq_dir}")

    # Save base model
    base_model: "PreTrainedModel" = peft_model.unload()
    base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
    tokenizer.save_pretrained(output_dir)
    print(f"Model weights saved in {output_dir}")

    print("- Fine-tune this model with:")
    print(f"model_name_or_path: {output_dir}")
    print(f"adapter_name_or_path: {loftq_dir}")
    print("finetuning_type: lora")
    print(f"quantization_bit: {loftq_bits}")


if __name__ == "__main__":
    fire.Fire(quantize_loftq)
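The script is exposed through fire, so it can also be driven programmatically. A minimal usage sketch; the model and output paths below are placeholders and not part of this commit:

# Equivalent command line:
#   python scripts/loftq_init.py --model_name_or_path /path/to/Qwen2.5-Coder-7B \
#       --output_dir /path/to/qwen2.5-coder-loftq --loftq_bits 4 --lora_rank 16
from loftq_init import quantize_loftq  # assumes the scripts directory is on sys.path

quantize_loftq(
    model_name_or_path="/path/to/Qwen2.5-Coder-7B",   # placeholder path
    output_dir="/path/to/qwen2.5-coder-loftq",        # placeholder path
    loftq_bits=4,
    lora_rank=16,
    lora_target="q_proj,v_proj",  # a comma-separated string is also accepted
)

The printed hints at the end of the run give the exact model_name_or_path / adapter_name_or_path values to put into the fine-tuning config.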
LLaMA-Factory/scripts/pissa_init.py (new file, mode 100644)

# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import TYPE_CHECKING

import fire
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer


if TYPE_CHECKING:
    from transformers import PreTrainedModel


def quantize_pissa(
    model_name_or_path: str,
    output_dir: str,
    pissa_iter: int = 16,
    lora_alpha: int = None,
    lora_rank: int = 16,
    lora_dropout: float = 0,
    lora_target: tuple = ("q_proj", "v_proj"),
    save_safetensors: bool = True,
):
    r"""
    Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA).
    Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
    """
    if isinstance(lora_target, str):
        lora_target = [name.strip() for name in lora_target.split(",")]

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_rank,
        lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
        lora_dropout=lora_dropout,
        target_modules=lora_target,
        init_lora_weights="pissa" if pissa_iter == -1 else f"pissa_niter_{pissa_iter}",
    )

    # Init PiSSA model
    peft_model = get_peft_model(model, lora_config)
    pissa_dir = os.path.join(output_dir, "pissa_init")

    # Save PiSSA model
    setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
    setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply pissa again
    peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
    print(f"Adapter weights saved in {pissa_dir}")

    # Save base model
    base_model: "PreTrainedModel" = peft_model.unload()
    base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
    tokenizer.save_pretrained(output_dir)
    print(f"Model weights saved in {output_dir}")

    print("- Fine-tune this model with:")
    print(f"model_name_or_path: {output_dir}")
    print(f"adapter_name_or_path: {pissa_dir}")
    print("finetuning_type: lora")
    print("pissa_init: false")
    print("pissa_convert: true")
    print("- and optionally with:")
    print("quantization_bit: 4")


if __name__ == "__main__":
    fire.Fire(quantize_pissa)
LLaMA-Factory/scripts/stat_utils/cal_flops.py (new file, mode 100644)

# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
#
# This code is inspired by the Microsoft's DeepSpeed library.
# https://www.deepspeed.ai/tutorials/flops-profiler/
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import fire
import torch
from deepspeed.accelerator import get_accelerator  # type: ignore
from deepspeed.profiling.flops_profiler import get_model_profile  # type: ignore

from llamafactory.chat import ChatModel


def calculate_flops(
    model_name_or_path: str,
    batch_size: int = 1,
    seq_length: int = 512,
    flash_attn: str = "auto",
):
    r"""
    Calculates the flops of pre-trained models.
    Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
    """
    with get_accelerator().device(0):
        chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
        fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device)
        input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
        flops, macs, params = get_model_profile(
            chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True
        )
        print("FLOPs:", flops)
        print("MACs:", macs)
        print("Params:", params)


if __name__ == "__main__":
    fire.Fire(calculate_flops)
LLaMA-Factory/scripts/stat_utils/cal_lr.py (new file, mode 100644)

# Copyright 2024 imoneoi and the LlamaFactory team.
#
# This code is inspired by the imoneoi's OpenChat library.
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Literal

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling

from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer


BASE_LR = 3e-4  # 1.5e-4 for 30B-70B models
BASE_BS = 4_000_000  # from llama paper


def calculate_lr(
    model_name_or_path: str,
    batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
    stage: Literal["pt", "sft"] = "sft",
    dataset: str = "alpaca_en_demo",
    dataset_dir: str = "data",
    template: str = "default",
    cutoff_len: int = 1024,  # i.e. maximum input length during training
    is_mistral_or_gemma: bool = False,  # mistral and gemma models opt for a smaller learning rate
    packing: bool = False,
):
    r"""
    Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
    Usage:
    python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
    """
    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage=stage,
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            packing=packing,
            output_dir="dummy_dir",
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
    if stage == "pt":
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    elif stage == "sft":
        data_collator = MultiModalDataCollatorForSeq2Seq(
            template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
        )
    else:
        raise NotImplementedError(f"Stage is not supported: {stage}.")

    dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
    valid_tokens, total_tokens = 0, 0
    for batch in tqdm(dataloader):
        valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
        total_tokens += torch.numel(batch["labels"])

    valid_ratio = valid_tokens / total_tokens
    token_batch_size = cutoff_len * batch_size * valid_ratio
    lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)  # lr ~ sqrt(batch_size)
    lr = lr / 6.0 if is_mistral_or_gemma else lr
    print(
        "Optimal learning rate is {:.2e} for valid ratio {:.2f}% and effective token batch size {:.2f}".format(
            lr, valid_ratio * 100, token_batch_size
        )
    )


if __name__ == "__main__":
    fire.Fire(calculate_lr)
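The rule at the end of calculate_lr() is a square-root scaling of the LLaMA base recipe: the base learning rate is multiplied by sqrt(effective token batch size / 4M tokens). A small worked example with assumed numbers (cutoff_len 1024, total batch size 16, roughly 55% of tokens carrying a label, which is typical for instruction data):

import math

BASE_LR = 3e-4        # as in cal_lr.py
BASE_BS = 4_000_000   # tokens per step in the LLaMA recipe

# Illustrative values only; the script measures valid_ratio from the real dataset.
cutoff_len, batch_size, valid_ratio = 1024, 16, 0.55
token_batch_size = cutoff_len * batch_size * valid_ratio   # ~9011 tokens per step
lr = BASE_LR * math.sqrt(token_batch_size / BASE_BS)       # lr ~ sqrt(batch size)
print(f"{lr:.2e}")                                         # ~1.42e-05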
LLaMA-Factory/scripts/stat_utils/cal_mfu.py (new file, mode 100644)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import fire
import torch
import torch.distributed as dist
from transformers import AutoConfig

from llamafactory.train.tuner import run_exp


BASE = 2  # gemm (add + mul)


def compute_model_flops(
    model_name_or_path: str,
    total_batch_size: int,
    seq_length: int,
    include_backward: bool = True,
    include_recompute: bool = False,
    include_flashattn: bool = False,
) -> int:
    r"""
    Calculates the FLOPs of model per forward/backward pass.
    """
    config = AutoConfig.from_pretrained(model_name_or_path)
    hidden_size = getattr(config, "hidden_size", None)
    vocab_size = getattr(config, "vocab_size", None)
    intermediate_size = getattr(config, "intermediate_size", None)
    num_attention_heads = getattr(config, "num_attention_heads", None)
    num_key_value_heads = getattr(config, "num_key_value_heads", None)
    num_hidden_layers = getattr(config, "num_hidden_layers", None)
    tie_word_embeddings = getattr(config, "tie_word_embeddings", False)

    # mlp module
    mlp_flops_per_token = 3 * BASE * hidden_size * intermediate_size  # up, gate, down
    mlp_flops = total_batch_size * seq_length * num_hidden_layers * mlp_flops_per_token

    # attn projector module
    q_flops_per_token = BASE * hidden_size * hidden_size
    o_flops_per_token = BASE * hidden_size * hidden_size
    k_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    v_flops_per_token = BASE * hidden_size * hidden_size * num_key_value_heads // num_attention_heads
    attn_proj_flops_per_token = q_flops_per_token + o_flops_per_token + k_flops_per_token + v_flops_per_token
    attn_proj_flops = total_batch_size * seq_length * num_hidden_layers * attn_proj_flops_per_token

    # attn sdpa module
    sdpa_flops_per_layer = 2 * BASE * hidden_size * seq_length * seq_length  # (q * k^T) * v
    sdpa_flops = total_batch_size * num_hidden_layers * sdpa_flops_per_layer

    # embedding module
    embedding_flops_per_token = hidden_size * vocab_size
    embedding_flops = total_batch_size * seq_length * embedding_flops_per_token
    if tie_word_embeddings is False:
        embedding_flops *= 2

    non_embedding_flops = mlp_flops + attn_proj_flops + sdpa_flops
    non_embedding_coeff, embedding_coeff = 1, 1
    if include_backward:
        non_embedding_coeff += 2
        embedding_coeff += 2

    if include_recompute:
        non_embedding_coeff += 1

    total_flops = non_embedding_coeff * non_embedding_flops + embedding_coeff * embedding_flops
    if include_flashattn:
        total_flops += sdpa_flops

    return total_flops


def compute_device_flops(world_size: int) -> float:
    r"""
    Calculates the FLOPs of the device capability per second.
    """
    device_name = torch.cuda.get_device_name()
    if "H100" in device_name or "H800" in device_name:
        return 989 * 1e12 * world_size
    elif "A100" in device_name or "A800" in device_name:
        return 312 * 1e12 * world_size
    elif "V100" in device_name:
        return 125 * 1e12 * world_size
    elif "4090" in device_name:
        return 98 * 1e12 * world_size
    else:
        raise NotImplementedError(f"Device not supported: {device_name}.")


def calculate_mfu(
    model_name_or_path: str,
    batch_size: int = 1,
    seq_length: int = 1024,
    num_steps: int = 100,
    finetuning_type: str = "lora",
    flash_attn: str = "auto",
    deepspeed_stage: int = 0,
    disable_gc: bool = False,
    liger_kernel: bool = False,
    unsloth_gc: bool = False,
) -> float:
    r"""
    Calculates MFU for given model and hyper-params.
    Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
    """
    args = {
        "model_name_or_path": model_name_or_path,
        "flash_attn": flash_attn,
        "disable_gradient_checkpointing": disable_gc,
        "enable_liger_kernel": liger_kernel,
        "use_unsloth_gc": unsloth_gc,
        "stage": "pt",
        "do_train": True,
        "finetuning_type": finetuning_type,
        "dataset": "c4_demo",
        "cutoff_len": seq_length,
        "output_dir": os.path.join("saves", "test_mfu"),
        "logging_strategy": "no",
        "save_strategy": "no",
        "save_only_model": True,
        "overwrite_output_dir": True,
        "per_device_train_batch_size": batch_size,
        "max_steps": num_steps,
        "bf16": True,
    }
    if deepspeed_stage in [2, 3]:
        args["deepspeed"] = f"examples/deepspeed/ds_z{deepspeed_stage}_config.json"

    run_exp(args)
    with open(os.path.join("saves", "test_mfu", "all_results.json"), encoding="utf-8") as f:
        result = json.load(f)

    if dist.is_initialized():
        world_size = dist.get_world_size()
    else:
        world_size = 1

    total_batch_size = batch_size * world_size
    mfu_value = (
        result["train_steps_per_second"]
        * compute_model_flops(model_name_or_path, total_batch_size, seq_length)
        / compute_device_flops(world_size)
    )
    print(f"MFU: {mfu_value * 100:.2f}%")


if __name__ == "__main__":
    fire.Fire(calculate_mfu)
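The final number is simply measured steps per second times the analytic per-step FLOPs, divided by the device's peak throughput. A standalone sketch of that last ratio with illustrative, non-measured numbers (A100 peak of 312 TFLOP/s as in compute_device_flops(); the per-step FLOPs value below is a rough 6 x 7e9 parameters x 1024 tokens stand-in, not the output of compute_model_flops()):

# Illustrative values only; the script measures train_steps_per_second from a
# short real run and derives model FLOPs from the model config.
train_steps_per_second = 2.0
model_flops_per_step = 4.3e13          # ~ 6 * 7e9 params * 1024 tokens (assumed)
device_flops_per_second = 312 * 1e12   # single A100, as in compute_device_flops()

mfu = train_steps_per_second * model_flops_per_step / device_flops_per_second
print(f"MFU: {mfu * 100:.2f}%")        # ~27.56% under these assumptions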
LLaMA-Factory/scripts/stat_utils/cal_ppl.py (new file, mode 100644)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional, Sequence

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling

from llamafactory.data import MultiModalDataCollatorForSeq2Seq, get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer


@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """

    train_on_prompt: bool = False

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.
        """
        chosen_features = []
        for feature in features:
            chosen_features.append(
                {
                    "input_ids": feature["chosen_input_ids"],
                    "attention_mask": feature["chosen_attention_mask"],
                    "labels": feature["chosen_input_ids"] if self.train_on_prompt else feature["chosen_labels"],
                    "images": feature["images"],
                    "videos": feature["videos"],
                }
            )

        return super().__call__(chosen_features)


def calculate_ppl(
    model_name_or_path: str,
    save_name: str = "ppl.json",
    batch_size: int = 4,
    stage: Literal["pt", "sft", "rm"] = "sft",
    dataset: str = "alpaca_en_demo",
    dataset_dir: str = "data",
    template: str = "default",
    cutoff_len: int = 1024,
    max_samples: Optional[int] = None,
    train_on_prompt: bool = False,
):
    r"""
    Calculates the perplexity of pre-trained models on the dataset.
    Usage: export CUDA_VISIBLE_DEVICES=0
    python cal_ppl.py --model_name_or_path path_to_model --dataset alpaca_en_demo --save_name ppl.json
    """
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
        dict(
            stage=stage,
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            max_samples=max_samples,
            train_on_prompt=train_on_prompt,
            output_dir="dummy_dir",
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    trainset = get_dataset(template, model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
    model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
    if stage == "pt":
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    elif stage == "sft":
        data_collator = MultiModalDataCollatorForSeq2Seq(
            template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX
        )
    elif stage == "rm":
        data_collator = PairwiseDataCollatorWithPadding(
            template=template, tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
        )
    else:
        raise NotImplementedError(f"Stage is not supported: {stage}.")

    dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
    criterion = torch.nn.CrossEntropyLoss(reduction="none")
    total_ppl = 0
    perplexities = []
    batch: Dict[str, "torch.Tensor"]
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = batch.to(model.device)
            outputs = model(**batch)
            shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
            shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
            loss_mask = shift_labels != IGNORE_INDEX
            flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
            flatten_labels = shift_labels.contiguous().view(-1)
            token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
            token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
            sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
            total_ppl += sentence_logps.exp().sum().item()
            perplexities.extend(sentence_logps.exp().tolist())

    with open(save_name, "w", encoding="utf-8") as f:
        json.dump(perplexities, f, indent=2)

    print(f"Average perplexity is {total_ppl / len(perplexities):.2f}")
    print(f"Perplexities have been saved at {save_name}.")


if __name__ == "__main__":
    fire.Fire(calculate_ppl)
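The shift/mask/average steps in the loop can be checked in isolation. A toy sketch with random logits and a hand-written label mask, mirroring how calculate_ppl() turns per-token cross-entropy into a per-sample perplexity (the tensors here are made up, not model outputs):

import torch

IGNORE_INDEX = -100
logits = torch.randn(2, 5, 5)                       # (batch, seq_len, vocab) -- toy values
labels = torch.tensor([[2, 3, IGNORE_INDEX, 1, 4],
                       [0, IGNORE_INDEX, IGNORE_INDEX, 2, 3]])

criterion = torch.nn.CrossEntropyLoss(reduction="none")
shift_logits = logits[..., :-1, :]                  # position t predicts token t+1
shift_labels = labels[..., 1:]
loss_mask = shift_labels != IGNORE_INDEX
token_losses = criterion(shift_logits.reshape(-1, 5), shift_labels.reshape(-1)).view(2, -1)
sentence_logps = (token_losses * loss_mask).sum(-1) / loss_mask.sum(-1)
print(sentence_logps.exp())                         # per-sample perplexity, as saved to ppl.json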
LLaMA-Factory/scripts/stat_utils/length_cdf.py (new file, mode 100644)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict

import fire
from tqdm import tqdm

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer


def length_cdf(
    model_name_or_path: str,
    dataset: str = "alpaca_en_demo",
    dataset_dir: str = "data",
    template: str = "default",
    interval: int = 1000,
):
    r"""
    Calculates the distribution of the input lengths in the dataset.
    Usage: export CUDA_VISIBLE_DEVICES=0
    python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default
    """
    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=1_000_000,
            output_dir="dummy_dir",
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
    trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"]
    total_num = len(trainset)
    length_dict = defaultdict(int)
    for sample in tqdm(trainset["input_ids"]):
        length_dict[len(sample) // interval * interval] += 1

    length_tuples = list(length_dict.items())
    length_tuples.sort()
    count_accu, prob_accu = 0, 0
    for length, count in length_tuples:
        count_accu += count
        prob_accu += count / total_num * 100
        print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.")


if __name__ == "__main__":
    fire.Fire(length_cdf)
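The bucketing is a floor to the nearest interval followed by a cumulative count. A tiny self-contained sketch with made-up sample lengths, producing the same style of output as the script:

from collections import defaultdict

# Toy token lengths standing in for len(sample) over the tokenized dataset.
lengths = [120, 480, 950, 1024, 1300, 2100, 2600, 4097]
interval = 1000

length_dict = defaultdict(int)
for n in lengths:
    length_dict[n // interval * interval] += 1   # floor to the start of the bucket

count_accu, prob_accu = 0, 0
for bucket, count in sorted(length_dict.items()):
    count_accu += count
    prob_accu += count / len(lengths) * 100
    print(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {bucket + interval}.")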
LLaMA-Factory/scripts/vllm_infer.py (new file, mode 100644)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

import fire
from transformers import Seq2SeqTrainingArguments

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.extras.misc import get_device_count
from llamafactory.extras.packages import is_pillow_available, is_vllm_available
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer


if is_pillow_available():
    from PIL import Image
    from PIL.Image import Image as ImageObject


if is_vllm_available():
    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest


def vllm_infer(
    model_name_or_path: str,
    adapter_name_or_path: str = None,
    dataset: str = "alpaca_en_demo",
    dataset_dir: str = "data",
    template: str = "default",
    cutoff_len: int = 2048,
    max_samples: int = None,
    vllm_config: str = "{}",
    save_name: str = "generated_predictions.jsonl",
    temperature: float = 0.95,
    top_p: float = 0.7,
    top_k: int = 50,
    max_new_tokens: int = 1024,
    repetition_penalty: float = 1.0,
):
    r"""
    Performs batch generation using vLLM engine, which supports tensor parallelism.
    Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
    """
    model_args, data_args, _, generating_args = get_infer_args(
        dict(
            model_name_or_path=model_name_or_path,
            adapter_name_or_path=adapter_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            max_samples=max_samples,
            vllm_config=vllm_config,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
        )
    )

    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir")
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template_obj = get_template_and_fix_tokenizer(tokenizer, data_args)
    template_obj.mm_plugin.expand_mm_tokens = False  # for vllm generate
    dataset_module = get_dataset(template_obj, model_args, data_args, training_args, "ppo", **tokenizer_module)

    inputs, prompts, labels = [], [], []
    for sample in dataset_module["train_dataset"]:
        if sample["images"]:
            multi_modal_data = {"image": []}
            for image in sample["images"]:
                if not isinstance(image, (str, ImageObject)):
                    raise ValueError(f"Expected image input is a path or PIL.Image, but got {type(image)}.")

                if isinstance(image, str):
                    image = Image.open(image).convert("RGB")

                multi_modal_data["image"].append(image)
        else:
            multi_modal_data = None

        inputs.append({"prompt_token_ids": sample["input_ids"], "multi_modal_data": multi_modal_data})
        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
        labels.append(
            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
        )

    sampling_params = SamplingParams(
        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must be > 0
        temperature=generating_args.temperature,
        top_p=generating_args.top_p or 1.0,  # top_p must be > 0
        top_k=generating_args.top_k,
        stop_token_ids=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
        max_tokens=generating_args.max_new_tokens,
        skip_special_tokens=False,
    )
    if model_args.adapter_name_or_path is not None:
        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
    else:
        lora_request = None

    engine_args = {
        "model": model_args.model_name_or_path,
        "trust_remote_code": True,
        "dtype": model_args.infer_dtype,
        "tensor_parallel_size": get_device_count() or 1,
        "disable_log_stats": True,
        "enable_lora": model_args.adapter_name_or_path is not None,
    }
    if template_obj.mm_plugin.__class__.__name__ != "BasePlugin":
        engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2}

    if isinstance(model_args.vllm_config, dict):
        engine_args.update(model_args.vllm_config)

    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
    preds = [result.outputs[0].text for result in results]
    with open(save_name, "w", encoding="utf-8") as f:
        for text, pred, label in zip(prompts, preds, labels):
            f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")

    print("*" * 70)
    print(f"{len(prompts)} generated results have been saved at {save_name}.")
    print("*" * 70)


if __name__ == "__main__":
    fire.Fire(vllm_infer)
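Each output line is a JSON object with prompt, predict and label keys. A short sketch for reading the default output file back; the file only exists after the script has been run, so the path below is the default save_name rather than a guaranteed artifact:

import json

with open("generated_predictions.jsonl", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]   # one {"prompt", "predict", "label"} dict per line

print(len(records), "predictions")
print(records[0]["prompt"][:200])
print(records[0]["predict"][:200])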
LLaMA-Factory/setup.py (new file, mode 100644)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from typing import List

from setuptools import find_packages, setup


def get_version() -> str:
    with open(os.path.join("src", "llamafactory", "extras", "env.py"), encoding="utf-8") as f:
        file_content = f.read()
        pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
        (version,) = re.findall(pattern, file_content)
        return version


def get_requires() -> List[str]:
    with open("requirements.txt", encoding="utf-8") as f:
        file_content = f.read()
        lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
        return lines


def get_console_scripts() -> List[str]:
    console_scripts = ["llamafactory-cli = llamafactory.cli:main"]
    if os.environ.get("ENABLE_SHORT_CONSOLE", "1").lower() in ["true", "1"]:
        console_scripts.append("lmf = llamafactory.cli:main")

    return console_scripts


extra_require = {
    "torch": ["torch>=1.13.1"],
    "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
    "metrics": ["nltk", "jieba", "rouge-chinese"],
    "deepspeed": ["deepspeed>=0.10.0,<=0.14.4"],
    "liger-kernel": ["liger-kernel"],
    "bitsandbytes": ["bitsandbytes>=0.39.0"],
    "hqq": ["hqq"],
    "eetq": ["eetq"],
    "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
    "awq": ["autoawq"],
    "aqlm": ["aqlm[gpu]>=1.1.0"],
    "vllm": ["vllm>=0.4.3,<0.6.5"],
    "galore": ["galore-torch"],
    "badam": ["badam>=1.2.1"],
    "adam-mini": ["adam-mini"],
    "qwen": ["transformers_stream_generator"],
    "modelscope": ["modelscope"],
    "openmind": ["openmind"],
    "swanlab": ["swanlab"],
    "dev": ["pre-commit", "ruff", "pytest"],
}


def main():
    setup(
        name="llamafactory",
        version=get_version(),
        author="hiyouga",
        author_email="hiyouga" "@" "buaa.edu.cn",
        description="Easy-to-use LLM fine-tuning framework",
        long_description=open("README.md", encoding="utf-8").read(),
        long_description_content_type="text/markdown",
        keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
        license="Apache 2.0 License",
        url="https://github.com/hiyouga/LLaMA-Factory",
        package_dir={"": "src"},
        packages=find_packages("src"),
        python_requires=">=3.8.0",
        install_requires=get_requires(),
        extras_require=extra_require,
        entry_points={"console_scripts": get_console_scripts()},
        classifiers=[
            "Development Status :: 4 - Beta",
            "Intended Audience :: Developers",
            "Intended Audience :: Education",
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: Apache Software License",
            "Operating System :: OS Independent",
            "Programming Language :: Python :: 3",
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
        ],
    )


if __name__ == "__main__":
    main()
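get_version() relies on a regex over src/llamafactory/extras/env.py. A standalone check of that pattern against stand-in file content (the version string below is illustrative, not the packaged one):

import re

# Stand-in for the contents of src/llamafactory/extras/env.py.
file_content = 'VERSION = "0.9.1"\nOTHER = "x"\n'
pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
print(version)  # 0.9.1

Note that the console_scripts list also depends on the ENABLE_SHORT_CONSOLE environment variable at build time, which defaults to enabling the short "lmf" alias.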
LLaMA-Factory/src/api.py (new file, mode 100644)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import uvicorn

from llamafactory.api.app import create_app
from llamafactory.chat import ChatModel


def main():
    chat_model = ChatModel()
    app = create_app(chat_model)
    api_host = os.getenv("API_HOST", "0.0.0.0")
    api_port = int(os.getenv("API_PORT", "8000"))
    print(f"Visit http://localhost:{api_port}/docs for API document.")
    uvicorn.run(app, host=api_host, port=api_port)


if __name__ == "__main__":
    main()
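Once the server is up it speaks an OpenAI-style chat API whose routes are defined in llamafactory/api/app.py below. A hedged client sketch, assuming the default port, no API_KEY set, and the default placeholder model id returned by /v1/models:

import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "gpt-3.5-turbo",  # placeholder id; /v1/models reports the configured name
        "messages": [{"role": "user", "content": "Write a Python hello world."}],
    },
    timeout=120,
)
print(resp.json()["choices"][0]["message"]["content"])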
LLaMA-Factory/src/llamafactory/__init__.py (new file, mode 100644)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""
Efficient fine-tuning of large language models.

Level:
    api, webui > chat, eval, train > data, model > hparams > extras

Dependency graph:
    main:
        transformers>=4.41.2,<=4.46.1
        datasets>=2.16.0,<=3.1.0
        accelerate>=0.34.0,<=1.0.1
        peft>=0.11.1,<=0.12.0
        trl>=0.8.6,<=0.9.6
    attention:
        transformers>=4.42.4 (gemma+fa2)
    longlora:
        transformers>=4.41.2,<=4.46.1
    packing:
        transformers>=4.43.0,<=4.46.1

Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1
Force check imports: FORCE_CHECK_IMPORTS=1
Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1
Use openmind: USE_OPENMIND_HUB=1
"""

from .extras.env import VERSION


__version__ = VERSION
LLaMA-Factory/src/llamafactory/api/__init__.py (new file, mode 100644, empty)
LLaMA-Factory/src/llamafactory/api/app.py (new file, mode 100644)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from contextlib import asynccontextmanager
from functools import partial
from typing import Optional

from typing_extensions import Annotated

from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
from .chat import (
    create_chat_completion_response,
    create_score_evaluation_response,
    create_stream_chat_completion_response,
)
from .protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ModelCard,
    ModelList,
    ScoreEvaluationRequest,
    ScoreEvaluationResponse,
)


if is_fastapi_available():
    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer


if is_starlette_available():
    from sse_starlette import EventSourceResponse


if is_uvicorn_available():
    import uvicorn


async def sweeper() -> None:
    while True:
        torch_gc()
        await asyncio.sleep(300)


@asynccontextmanager
async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
    if chat_model.engine_type == "huggingface":
        asyncio.create_task(sweeper())

    yield
    torch_gc()


def create_app(chat_model: "ChatModel") -> "FastAPI":
    root_path = os.getenv("FASTAPI_ROOT_PATH", "")
    app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    api_key = os.getenv("API_KEY")
    security = HTTPBearer(auto_error=False)

    async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
        if api_key and (auth is None or auth.credentials != api_key):
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")

    @app.get(
        "/v1/models",
        response_model=ModelList,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def list_models():
        model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
        return ModelList(data=[model_card])

    @app.post(
        "/v1/chat/completions",
        response_model=ChatCompletionResponse,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def create_chat_completion(request: ChatCompletionRequest):
        if not chat_model.engine.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        if request.stream:
            generate = create_stream_chat_completion_response(request, chat_model)
            return EventSourceResponse(generate, media_type="text/event-stream")
        else:
            return await create_chat_completion_response(request, chat_model)

    @app.post(
        "/v1/score/evaluation",
        response_model=ScoreEvaluationResponse,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def create_score_evaluation(request: ScoreEvaluationRequest):
        if chat_model.engine.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        return await create_score_evaluation_response(request, chat_model)

    return app


def run_api() -> None:
    chat_model = ChatModel()
    app = create_app(chat_model)
    api_host = os.getenv("API_HOST", "0.0.0.0")
    api_port = int(os.getenv("API_PORT", "8000"))
    print(f"Visit http://localhost:{api_port}/docs for API document.")
    uvicorn.run(app, host=api_host, port=api_port)
LLaMA-Factory/src/llamafactory/api/chat.py (new file, mode 100644)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import io
import json
import os
import re
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

from ..data import Role as DataRole
from ..extras import logging
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
from .protocol import (
    ChatCompletionMessage,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseUsage,
    ChatCompletionStreamResponse,
    ChatCompletionStreamResponseChoice,
    Finish,
    Function,
    FunctionCall,
    Role,
    ScoreEvaluationResponse,
)


if is_fastapi_available():
    from fastapi import HTTPException, status


if is_pillow_available():
    from PIL import Image


if is_requests_available():
    import requests


if TYPE_CHECKING:
    from ..chat import ChatModel
    from ..data.mm_plugin import ImageInput
    from .protocol import ChatCompletionRequest, ScoreEvaluationRequest


logger = logging.get_logger(__name__)
ROLE_MAPPING = {
    Role.USER: DataRole.USER.value,
    Role.ASSISTANT: DataRole.ASSISTANT.value,
    Role.SYSTEM: DataRole.SYSTEM.value,
    Role.FUNCTION: DataRole.FUNCTION.value,
    Role.TOOL: DataRole.OBSERVATION.value,
}


def _process_request(
    request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
    logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")

    if len(request.messages) == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

    if request.messages[0].role == Role.SYSTEM:
        system = request.messages.pop(0).content
    else:
        system = None

    if len(request.messages) % 2 == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

    input_messages = []
    images = []
    for i, message in enumerate(request.messages):
        if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
        elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

        if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
            tool_calls = [
                {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
                for tool_call in message.tool_calls
            ]
            content = json.dumps(tool_calls, ensure_ascii=False)
            input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
        elif isinstance(message.content, list):
            for input_item in message.content:
                if input_item.type == "text":
                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
                else:
                    image_url = input_item.image_url.url
                    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
                        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
                    elif os.path.isfile(image_url):  # local file
                        image_stream = open(image_url, "rb")
                    else:  # web uri
                        image_stream = requests.get(image_url, stream=True).raw

                    images.append(Image.open(image_stream).convert("RGB"))
        else:
            input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

    tool_list = request.tools
    if isinstance(tool_list, list) and len(tool_list):
        try:
            tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
        except json.JSONDecodeError:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
    else:
        tools = None

    return input_messages, system, tools, images or None


def _create_stream_chat_completion_chunk(
    completion_id: str,
    model: str,
    delta: "ChatCompletionMessage",
    index: Optional[int] = 0,
    finish_reason: Optional["Finish"] = None,
) -> str:
    choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
    chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
    return jsonify(chunk)


async def create_chat_completion_response(
    request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    input_messages, system, tools, images = _process_request(request)
    responses = await chat_model.achat(
        input_messages,
        system,
        tools,
        images,
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_p=request.top_p,
        max_new_tokens=request.max_tokens,
        num_return_sequences=request.n,
        stop=request.stop,
    )

    prompt_length, response_length = 0, 0
    choices = []
    for i, response in enumerate(responses):
        if tools:
            result = chat_model.engine.template.extract_tool(response.response_text)
        else:
            result = response.response_text

        if isinstance(result, list):
            tool_calls = []
            for tool in result:
                function = Function(name=tool.name, arguments=tool.arguments)
                tool_calls.append(FunctionCall(id=f"call_{uuid.uuid4().hex}", function=function))

            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
            finish_reason = Finish.TOOL
        else:
            response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
            finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH

        choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
        prompt_length = response.prompt_length
        response_length += response.response_length

    usage = ChatCompletionResponseUsage(
        prompt_tokens=prompt_length,
        completion_tokens=response_length,
        total_tokens=prompt_length + response_length,
    )

    return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)


async def create_stream_chat_completion_response(
    request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    input_messages, system, tools, images = _process_request(request)
    if tools:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")

    if request.n > 1:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")

    yield _create_stream_chat_completion_chunk(
        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
    )
    async for new_token in chat_model.astream_chat(
        input_messages,
        system,
        tools,
        images,
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_p=request.top_p,
        max_new_tokens=request.max_tokens,
        stop=request.stop,
    ):
        if len(new_token) != 0:
            yield _create_stream_chat_completion_chunk(
                completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
            )

    yield _create_stream_chat_completion_chunk(
        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
    )
    yield "[DONE]"


async def create_score_evaluation_response(
    request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
    score_id = f"scoreval-{uuid.uuid4().hex}"
    if len(request.messages) == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

    scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
    return ScoreEvaluationResponse(id=score_id, model=request.model, scores=scores)
LLaMA-Factory/src/llamafactory/api/common.py (new file, mode 100644)

# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import TYPE_CHECKING, Any, Dict


if TYPE_CHECKING:
    from pydantic import BaseModel


def dictify(data: "BaseModel") -> Dict[str, Any]:
    try:  # pydantic v2
        return data.model_dump(exclude_unset=True)
    except AttributeError:  # pydantic v1
        return data.dict(exclude_unset=True)


def jsonify(data: "BaseModel") -> str:
    try:  # pydantic v2
        return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
    except AttributeError:  # pydantic v1
        return data.json(exclude_unset=True, ensure_ascii=False)
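The two helpers exist to smooth over the pydantic v1 vs v2 APIs. A quick demonstration with a throwaway model, assuming the llamafactory package is importable (e.g. after pip install -e .) and pydantic v2 is installed so the v2 branch is taken:

from typing import Optional

from pydantic import BaseModel

from llamafactory.api.common import dictify, jsonify


class Message(BaseModel):  # throwaway model for demonstration only
    role: str
    content: Optional[str] = None


msg = Message(role="user")
print(dictify(msg))   # {'role': 'user'} -- unset fields are dropped
print(jsonify(msg))   # {"role": "user"}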
LLaMA-Factory/src/llamafactory/api/protocol.py
0 → 100644
View file @
53b3977b
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
from enum import Enum, unique
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field
from typing_extensions import Literal


@unique
class Role(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    TOOL = "tool"


@unique
class Finish(str, Enum):
    STOP = "stop"
    LENGTH = "length"
    TOOL = "tool_calls"


class ModelCard(BaseModel):
    id: str
    object: Literal["model"] = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: Literal["owner"] = "owner"


class ModelList(BaseModel):
    object: Literal["list"] = "list"
    data: List[ModelCard] = []


class Function(BaseModel):
    name: str
    arguments: str


class FunctionDefinition(BaseModel):
    name: str
    description: str
    parameters: Dict[str, Any]


class FunctionAvailable(BaseModel):
    type: Literal["function", "code_interpreter"] = "function"
    function: Optional[FunctionDefinition] = None


class FunctionCall(BaseModel):
    id: str
    type: Literal["function"] = "function"
    function: Function


class ImageURL(BaseModel):
    url: str


class MultimodalInputItem(BaseModel):
    type: Literal["text", "image_url"]
    text: Optional[str] = None
    image_url: Optional[ImageURL] = None


class ChatMessage(BaseModel):
    role: Role
    content: Optional[Union[str, List[MultimodalInputItem]]] = None
    tool_calls: Optional[List[FunctionCall]] = None


class ChatCompletionMessage(BaseModel):
    role: Optional[Role] = None
    content: Optional[str] = None
    tool_calls: Optional[List[FunctionCall]] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    tools: Optional[List[FunctionAvailable]] = None
    do_sample: Optional[bool] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: int = 1
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    stream: bool = False


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatCompletionMessage
    finish_reason: Finish


class ChatCompletionStreamResponseChoice(BaseModel):
    index: int
    delta: ChatCompletionMessage
    finish_reason: Optional[Finish] = None


class ChatCompletionResponseUsage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionResponse(BaseModel):
    id: str
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: ChatCompletionResponseUsage


class ChatCompletionStreamResponse(BaseModel):
    id: str
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionStreamResponseChoice]


class ScoreEvaluationRequest(BaseModel):
    model: str
    messages: List[str]
    max_length: Optional[int] = None


class ScoreEvaluationResponse(BaseModel):
    id: str
    object: Literal["score.evaluation"] = "score.evaluation"
    model: str
    scores: List[float]
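The models above mirror the OpenAI chat-completion schema, so an incoming JSON payload can be validated directly by pydantic. The sketch below is illustrative only: the model id and message contents are placeholder values, and it assumes the package is installed so that `llamafactory.api.protocol` is importable (on pydantic v1 the final call would be `.json(...)` instead of `.model_dump_json(...)`).

# Minimal sketch: validating an OpenAI-style payload with the models above.
# The model id and message texts are hypothetical placeholders.
from llamafactory.api.protocol import ChatCompletionRequest, Role

payload = {
    "model": "qwen2.5-coder-7b-instruct",  # hypothetical model id
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a hello-world in Python."},
    ],
    "temperature": 0.7,
    "stream": False,
}

request = ChatCompletionRequest(**payload)  # pydantic coerces dicts into ChatMessage
assert request.messages[0].role == Role.SYSTEM
print(request.model_dump_json(exclude_unset=True))  # pydantic v2; use .json() on v1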
LLaMA-Factory/src/llamafactory/chat/__init__.py
0 → 100644
View file @
53b3977b
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_engine import BaseEngine
from .chat_model import ChatModel


__all__ = ["BaseEngine", "ChatModel"]
LLaMA-Factory/src/llamafactory/chat/base_engine.py
0 → 100644
View file @
53b3977b
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Literal, Optional, Sequence, Union


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer
    from vllm import AsyncLLMEngine

    from ..data import Template
    from ..data.mm_plugin import ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


@dataclass
class Response:
    response_text: str
    response_length: int
    prompt_length: int
    finish_reason: Literal["stop", "length"]


class BaseEngine(ABC):
    r"""
    Base class for the inference engine of chat models.

    Must implement the async methods: chat(), stream_chat() and get_scores().
    """

    model: Union["PreTrainedModel", "AsyncLLMEngine"]
    tokenizer: "PreTrainedTokenizer"
    can_generate: bool
    template: "Template"
    generating_args: Dict[str, Any]

    @abstractmethod
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        r"""
        Initializes an inference engine.
        """
        ...

    @abstractmethod
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Gets a list of responses of the chat model.
        """
        ...

    @abstractmethod
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""
        Gets the response token-by-token of the chat model.
        """
        ...

    @abstractmethod
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Gets a list of scores of the reward model.
        """
        ...
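For orientation, a subclass only has to override the four abstract methods above; the real implementations in this package are HuggingfaceEngine and VllmEngine. The sketch below is a hypothetical, non-functional "EchoEngine" that merely illustrates the abstract surface and the Response dataclass; it performs no actual inference.

# Minimal sketch of a BaseEngine subclass; "EchoEngine" is hypothetical and
# only demonstrates the interface shape (no model is loaded).
from typing import Any, AsyncGenerator, Dict, List

from llamafactory.chat.base_engine import BaseEngine, Response


class EchoEngine(BaseEngine):
    def __init__(self, model_args, data_args, finetuning_args, generating_args) -> None:
        self.can_generate = True
        self.generating_args: Dict[str, Any] = {}

    async def chat(self, messages, system=None, tools=None, images=None, videos=None, **input_kwargs) -> List[Response]:
        text = messages[-1]["content"]  # echo the last user turn back
        return [Response(response_text=text, response_length=len(text), prompt_length=0, finish_reason="stop")]

    async def stream_chat(self, messages, system=None, tools=None, images=None, videos=None, **input_kwargs) -> AsyncGenerator[str, None]:
        for char in messages[-1]["content"]:  # stream the echo character by character
            yield char

    async def get_scores(self, batch_input: List[str], **input_kwargs) -> List[float]:
        return [0.0 for _ in batch_input]  # dummy scores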
LLaMA-Factory/src/llamafactory/chat/chat_model.py
0 → 100644
View file @
53b3977b
# Copyright 2024 THUDM and the LlamaFactory team.
#
# This code is inspired by the THUDM's ChatGLM implementation.
# https://github.com/THUDM/ChatGLM-6B/blob/main/cli_demo.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence

from ..extras.misc import torch_gc
from ..hparams import get_infer_args
from .hf_engine import HuggingfaceEngine
from .vllm_engine import VllmEngine


if TYPE_CHECKING:
    from ..data.mm_plugin import ImageInput, VideoInput
    from .base_engine import BaseEngine, Response


def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()


class ChatModel:
    r"""
    General class for chat models. Backed by huggingface or vllm engines.

    Supports both sync and async methods.
    Sync methods: chat(), stream_chat() and get_scores().
    Async methods: achat(), astream_chat() and aget_scores().
    """

    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
        model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
        self.engine_type = model_args.infer_backend
        if model_args.infer_backend == "huggingface":
            self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
        elif model_args.infer_backend == "vllm":
            self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
        else:
            raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")

        self._loop = asyncio.new_event_loop()
        self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
        self._thread.start()

    def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Gets a list of responses of the chat model.
        """
        task = asyncio.run_coroutine_threadsafe(
            self.achat(messages, system, tools, images, videos, **input_kwargs), self._loop
        )
        return task.result()

    async def achat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        r"""
        Asynchronously gets a list of responses of the chat model.
        """
        return await self.engine.chat(messages, system, tools, images, videos, **input_kwargs)

    def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> Generator[str, None, None]:
        r"""
        Gets the response token-by-token of the chat model.
        """
        generator = self.astream_chat(messages, system, tools, images, videos, **input_kwargs)
        while True:
            try:
                task = asyncio.run_coroutine_threadsafe(generator.__anext__(), self._loop)
                yield task.result()
            except StopAsyncIteration:
                break

    async def astream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        r"""
        Asynchronously gets the response token-by-token of the chat model.
        """
        async for new_token in self.engine.stream_chat(messages, system, tools, images, videos, **input_kwargs):
            yield new_token

    def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Gets a list of scores of the reward model.
        """
        task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
        return task.result()

    async def aget_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        r"""
        Asynchronously gets a list of scores of the reward model.
        """
        return await self.engine.get_scores(batch_input, **input_kwargs)


def run_chat() -> None:
    if os.name != "nt":
        try:
            import readline  # noqa: F401
        except ImportError:
            print("Install `readline` for a better experience.")

    chat_model = ChatModel()
    messages = []
    print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")

    while True:
        try:
            query = input("\nUser: ")
        except UnicodeDecodeError:
            print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
            continue
        except Exception:
            raise

        if query.strip() == "exit":
            break

        if query.strip() == "clear":
            messages = []
            torch_gc()
            print("History has been removed.")
            continue

        messages.append({"role": "user", "content": query})
        print("Assistant: ", end="", flush=True)

        response = ""
        for new_text in chat_model.stream_chat(messages):
            print(new_text, end="", flush=True)
            response += new_text
        print()
        messages.append({"role": "assistant", "content": response})
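As a usage sketch: ChatModel accepts the same keys that get_infer_args parses, so a dict of arguments is enough to drive either backend. The checkpoint path and template name below are illustrative placeholders, not values shipped with the repository.

# Usage sketch for ChatModel; argument values are hypothetical placeholders.
from llamafactory.chat import ChatModel

chat_model = ChatModel(args={
    "model_name_or_path": "Qwen/Qwen2.5-Coder-7B-Instruct",  # hypothetical checkpoint
    "template": "qwen",
    "infer_backend": "huggingface",
})

messages = [{"role": "user", "content": "Explain what a semaphore does."}]

# Blocking call: returns a list of Response objects.
responses = chat_model.chat(messages)
print(responses[0].response_text)

# Streaming call: yields decoded text chunks as they are generated.
for chunk in chat_model.stream_chat(messages):
    print(chunk, end="", flush=True)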
LLaMA-Factory/src/llamafactory/chat/hf_engine.py
0 → 100644
View file @
53b3977b
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import concurrent.futures
import os
from threading import Thread
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override

from ..data import get_template_and_fix_tokenizer
from ..extras import logging
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.misc import get_logits_processor
from ..model import load_model, load_tokenizer
from .base_engine import BaseEngine, Response


if TYPE_CHECKING:
    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
    from trl import PreTrainedModelWrapper

    from ..data import Template
    from ..data.mm_plugin import ImageInput, VideoInput
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = logging.get_logger(__name__)


class HuggingfaceEngine(BaseEngine):
    def __init__(
        self,
        model_args: "ModelArguments",
        data_args: "DataArguments",
        finetuning_args: "FinetuningArguments",
        generating_args: "GeneratingArguments",
    ) -> None:
        self.can_generate = finetuning_args.stage == "sft"
        tokenizer_module = load_tokenizer(model_args)
        self.tokenizer = tokenizer_module["tokenizer"]
        self.processor = tokenizer_module["processor"]
        self.tokenizer.padding_side = "left" if self.can_generate else "right"
        self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
        self.model = load_model(
            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
        )  # must be called after fixing the tokenizer to resize the vocab
        self.generating_args = generating_args.to_dict()
        try:
            asyncio.get_event_loop()
        except RuntimeError:
            logger.warning_once("There is no current event loop, creating a new one.")
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

    @staticmethod
    def _process_args(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Tuple[Dict[str, Any], int]:
        mm_input_dict = {"images": [], "videos": [], "imglens": [0], "vidlens": [0]}
        if images is not None:
            mm_input_dict.update({"images": images, "imglens": [len(images)]})
            if not any(IMAGE_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = IMAGE_PLACEHOLDER * len(images) + messages[0]["content"]

        if videos is not None:
            mm_input_dict.update({"videos": videos, "vidlens": [len(videos)]})
            if not any(VIDEO_PLACEHOLDER in message["content"] for message in messages):
                messages[0]["content"] = VIDEO_PLACEHOLDER * len(videos) + messages[0]["content"]

        messages = template.mm_plugin.process_messages(
            messages, mm_input_dict["images"], mm_input_dict["videos"], processor
        )
        paired_messages = messages + [{"role": "assistant", "content": ""}]
        system = system or generating_args["default_system"]
        prompt_ids, _ = template.encode_oneturn(tokenizer, paired_messages, system, tools)
        prompt_ids, _ = template.mm_plugin.process_token_ids(
            prompt_ids, None, mm_input_dict["images"], mm_input_dict["videos"], tokenizer, processor
        )
        prompt_length = len(prompt_ids)
        inputs = torch.tensor([prompt_ids], device=model.device)
        attention_mask = torch.ones_like(inputs, dtype=torch.bool)

        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
        temperature: Optional[float] = input_kwargs.pop("temperature", None)
        top_p: Optional[float] = input_kwargs.pop("top_p", None)
        top_k: Optional[float] = input_kwargs.pop("top_k", None)
        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

        if stop is not None:
            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")

        generating_args = generating_args.copy()
        generating_args.update(
            dict(
                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
                temperature=temperature if temperature is not None else generating_args["temperature"],
                top_p=top_p if top_p is not None else generating_args["top_p"],
                top_k=top_k if top_k is not None else generating_args["top_k"],
                num_return_sequences=num_return_sequences,
                repetition_penalty=repetition_penalty
                if repetition_penalty is not None
                else generating_args["repetition_penalty"],
                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                pad_token_id=tokenizer.pad_token_id,
            )
        )

        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
            generating_args["do_sample"] = True
            generating_args["temperature"] = generating_args["temperature"] or 1.0

        if not generating_args["temperature"]:
            generating_args["do_sample"] = False

        if not generating_args["do_sample"]:
            generating_args.pop("temperature", None)
            generating_args.pop("top_p", None)

        if max_length:
            generating_args.pop("max_new_tokens", None)
            generating_args["max_length"] = max_length

        if max_new_tokens:
            generating_args.pop("max_length", None)
            generating_args["max_new_tokens"] = max_new_tokens

        gen_kwargs = dict(
            inputs=inputs,
            attention_mask=attention_mask,
            generation_config=GenerationConfig(**generating_args),
            logits_processor=get_logits_processor(),
        )

        mm_inputs = template.mm_plugin.get_mm_inputs(**mm_input_dict, batch_ids=[prompt_ids], processor=processor)
        for key, value in mm_inputs.items():
            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):  # for pixtral inputs
                value = torch.stack(value)  # assume they have same sizes
            elif not isinstance(value, torch.Tensor):
                value = torch.tensor(value)

            if torch.is_floating_point(value):
                value = value.to(model.dtype)

            gen_kwargs[key] = value.to(model.device)

        return gen_kwargs, prompt_length

    @staticmethod
    @torch.inference_mode()
    def _chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List["Response"]:
        gen_kwargs, prompt_length = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            input_kwargs,
        )
        generate_output = model.generate(**gen_kwargs)
        response_ids = generate_output[:, prompt_length:]
        response = tokenizer.batch_decode(
            response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
        )
        results = []
        for i in range(len(response)):
            eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
            response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i])
            results.append(
                Response(
                    response_text=response[i],
                    response_length=response_length,
                    prompt_length=prompt_length,
                    finish_reason="stop" if len(eos_index) else "length",
                )
            )

        return results

    @staticmethod
    @torch.inference_mode()
    def _stream_chat(
        model: "PreTrainedModel",
        tokenizer: "PreTrainedTokenizer",
        processor: Optional["ProcessorMixin"],
        template: "Template",
        generating_args: Dict[str, Any],
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> Callable[[], str]:
        gen_kwargs, _ = HuggingfaceEngine._process_args(
            model,
            tokenizer,
            processor,
            template,
            generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            input_kwargs,
        )
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
        )
        gen_kwargs["streamer"] = streamer
        thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
        thread.start()

        def stream():
            try:
                return streamer.__next__()
            except StopIteration:
                raise StopAsyncIteration()

        return stream

    @staticmethod
    @torch.inference_mode()
    def _get_scores(
        model: "PreTrainedModelWrapper",
        tokenizer: "PreTrainedTokenizer",
        batch_input: List[str],
        input_kwargs: Optional[Dict[str, Any]] = {},
    ) -> List[float]:
        max_length: Optional[int] = input_kwargs.pop("max_length", None)
        device = getattr(model.pretrained_model, "device", "cuda")
        inputs: Dict[str, "torch.Tensor"] = tokenizer(
            batch_input,
            padding=True,
            truncation=True,
            max_length=max_length or getattr(model.config, "max_position_embeddings", 1024),
            return_tensors="pt",
            add_special_tokens=False,
        ).to(device)
        values: "torch.Tensor" = model(**inputs, return_dict=True, use_cache=False)[-1]
        scores = values.gather(dim=-1, index=(inputs["attention_mask"].sum(dim=-1, keepdim=True) - 1))
        return scores

    @override
    async def chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> List["Response"]:
        if not self.can_generate:
            raise ValueError("The current model does not support `chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            input_kwargs,
        )
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._chat, *input_args)

    @override
    async def stream_chat(
        self,
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
        images: Optional[Sequence["ImageInput"]] = None,
        videos: Optional[Sequence["VideoInput"]] = None,
        **input_kwargs,
    ) -> AsyncGenerator[str, None]:
        if not self.can_generate:
            raise ValueError("The current model does not support `stream_chat`.")

        loop = asyncio.get_running_loop()
        input_args = (
            self.model,
            self.tokenizer,
            self.processor,
            self.template,
            self.generating_args,
            messages,
            system,
            tools,
            images,
            videos,
            input_kwargs,
        )
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                stream = self._stream_chat(*input_args)
                while True:
                    try:
                        yield await loop.run_in_executor(pool, stream)
                    except StopAsyncIteration:
                        break

    @override
    async def get_scores(
        self,
        batch_input: List[str],
        **input_kwargs,
    ) -> List[float]:
        if self.can_generate:
            raise ValueError("Cannot get scores using an auto-regressive model.")

        loop = asyncio.get_running_loop()
        input_args = (self.model, self.tokenizer, batch_input, input_kwargs)
        async with self.semaphore:
            with concurrent.futures.ThreadPoolExecutor() as pool:
                return await loop.run_in_executor(pool, self._get_scores, *input_args)
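The engine bounds in-flight generations with an asyncio.Semaphore whose size comes from the MAX_CONCURRENT environment variable (default 1), and each request runs model.generate in a thread-pool executor so the event loop stays responsive. A minimal sketch of exercising that limit, assuming a locally available checkpoint (the path, template name, and prompts below are illustrative placeholders):

# Sketch: issuing two async requests under the MAX_CONCURRENT cap.
import asyncio
import os

os.environ["MAX_CONCURRENT"] = "2"  # read once in HuggingfaceEngine.__init__

from llamafactory.chat import ChatModel


async def main() -> None:
    chat_model = ChatModel(args={
        "model_name_or_path": "path/to/model",  # hypothetical checkpoint path
        "template": "qwen",
        "infer_backend": "huggingface",
    })
    prompts = ["Summarize asyncio semaphores.", "What does padding_side='left' do?"]
    tasks = [chat_model.achat([{"role": "user", "content": p}]) for p in prompts]
    results = await asyncio.gather(*tasks)  # at most MAX_CONCURRENT generate at once
    for res in results:
        print(res[0].response_text)


if __name__ == "__main__":
    asyncio.run(main())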