ModelZoo / Qwen2_pytorch · Commits · 032b90a1

Commit 032b90a1 authored Sep 12, 2024 by luopl

init commit

Pipeline #1684 canceled with stages
Changes: 233 · Pipelines: 1

Showing 20 changed files with 1654 additions and 0 deletions (+1654, -0)
LLaMA-Factory/examples/train_lora/qwen2_lora_sft.yaml      +39  -0
LLaMA-Factory/examples/train_lora/qwen2_lora_sft_ds3.yaml  +40  -0
LLaMA-Factory/pyproject.toml                               +33  -0
LLaMA-Factory/requirements.txt                             +21  -0
LLaMA-Factory/scripts/cal_flops.py                         +50  -0
LLaMA-Factory/scripts/cal_lr.py                            +96  -0
LLaMA-Factory/scripts/cal_ppl.py                           +132 -0
LLaMA-Factory/scripts/length_cdf.py                        +67  -0
LLaMA-Factory/scripts/llama_pro.py                         +131 -0
LLaMA-Factory/scripts/llamafy_baichuan2.py                 +106 -0
LLaMA-Factory/scripts/llamafy_qwen.py                      +159 -0
LLaMA-Factory/scripts/loftq_init.py                        +89  -0
LLaMA-Factory/scripts/pissa_init.py                        +87  -0
LLaMA-Factory/scripts/test_toolcall.py                     +79  -0
LLaMA-Factory/setup.py                                     +92  -0
LLaMA-Factory/src/api.py                                   +33  -0
LLaMA-Factory/src/llamafactory/__init__.py                 +41  -0
LLaMA-Factory/src/llamafactory/api/__init__.py             +0   -0
LLaMA-Factory/src/llamafactory/api/app.py                  +122 -0
LLaMA-Factory/src/llamafactory/api/chat.py                 +237 -0
LLaMA-Factory/examples/train_lora/qwen2_lora_sft.yaml (new file, mode 100644)

### model
model_name_or_path: /data/model/Qwen2-72B-Instruct

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: q_proj,v_proj

### dataset
dataset: identity,alpaca_en_demo
template: qwen
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/qwen2/lora/sft
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 8
learning_rate: 1.0e-4
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 500
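As an aside (not part of the commit), the effective batch size and rough optimizer-step count implied by this config can be worked out with a few lines of arithmetic; the world size below is an illustrative assumption, not something the YAML fixes.

# Back-of-the-envelope math for the config above (illustrative only).
per_device_train_batch_size = 1
gradient_accumulation_steps = 8
world_size = 8          # assumption: number of GPUs, not specified in the YAML
num_samples = 1000      # max_samples caps the dataset at 1000 examples
num_train_epochs = 3.0

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size
steps_per_epoch = num_samples // effective_batch_size
total_steps = int(steps_per_epoch * num_train_epochs)
print(effective_batch_size, steps_per_epoch, total_steps)  # 64, 15, 45 under these assumptions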
LLaMA-Factory/examples/train_lora/qwen2_lora_sft_ds3.yaml (new file, mode 100644)

### model
model_name_or_path: /data/model/Qwen2-72B

### method
stage: sft
do_train: true
finetuning_type: lora
lora_target: q_proj,v_proj
deepspeed: examples/deepspeed/ds_z3_config.json

### dataset
dataset: identity,alpaca_zh_demo,alpaca_en_demo
template: qwen
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 8

### output
output_dir: saves/qwen2_72b/lora/sft/0912
logging_steps: 10
save_steps: 500
plot_loss: true
overwrite_output_dir: true

### train
per_device_train_batch_size: 1
gradient_accumulation_steps: 1
learning_rate: 1.0e-5
num_train_epochs: 3.0
lr_scheduler_type: cosine
warmup_ratio: 0.1
bf16: true
ddp_timeout: 180000000

### eval
val_size: 0.1
per_device_eval_batch_size: 1
eval_strategy: steps
eval_steps: 250
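As an aside (not part of the commit), the scripts later in this diff feed plain dicts to llamafactory.hparams.get_train_args, so a config like the one above can, in principle, be consumed the same way. Treating the YAML as a flat keyword dict and the exact shape of get_train_args's return value are assumptions based on those scripts, not on documented CLI behavior.

# Sketch: hand the YAML above to the same entry point the scripts below use.
import yaml  # pyyaml is listed in requirements.txt

from llamafactory.hparams import get_train_args

with open("LLaMA-Factory/examples/train_lora/qwen2_lora_sft_ds3.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)  # the "### ..." comment lines are ignored by the YAML parser

model_args, data_args, training_args, finetuning_args, _ = get_train_args(config)
print(training_args.output_dir)  # saves/qwen2_72b/lora/sft/0912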
LLaMA-Factory/pyproject.toml (new file, mode 100644)

[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.ruff]
target-version = "py38"
line-length = 119
indent-width = 4

[tool.ruff.lint]
ignore = ["C408", "C901", "E501", "E731", "E741", "W605"]
select = ["C", "E", "F", "I", "W"]

[tool.ruff.lint.isort]
lines-after-imports = 2
known-first-party = ["llamafactory"]
known-third-party = [
    "accelerate",
    "datasets",
    "gradio",
    "numpy",
    "peft",
    "torch",
    "transformers",
    "trl",
]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
docstring-code-format = true
skip-magic-trailing-comma = false
line-ending = "auto"
LLaMA-Factory/requirements.txt (new file, mode 100644)
transformers>=4.41.2
datasets>=2.16.0
accelerate>=0.30.1
peft>=0.11.1
trl>=0.8.6
gradio>=4.0.0
pandas>=2.0.0
scipy
einops
sentencepiece
tiktoken
protobuf
uvicorn
pydantic
fastapi
sse-starlette
matplotlib>=3.7.0
fire
packaging
pyyaml
numpy<2.0.0
LLaMA-Factory/scripts/cal_flops.py (new file, mode 100644)
# coding=utf-8
# Copyright 2024 Microsoft Corporation and the LlamaFactory team.
#
# This code is inspired by Microsoft's DeepSpeed library.
# https://www.deepspeed.ai/tutorials/flops-profiler/
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fire
import torch
from deepspeed.accelerator import get_accelerator  # type: ignore
from deepspeed.profiling.flops_profiler import get_model_profile  # type: ignore

from llamafactory.chat import ChatModel


def calculate_flops(
    model_name_or_path: str,
    batch_size: int = 1,
    seq_length: int = 256,
    flash_attn: str = "auto",
):
    r"""
    Calculates the flops of pre-trained models.
    Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
    """
    with get_accelerator().device(0):
        chat_model = ChatModel(dict(model_name_or_path=model_name_or_path, template="empty", flash_attn=flash_attn))
        fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.engine.model.device)
        input_dict = {"input_ids": fake_input, "labels": fake_input.clone()}
        flops, macs, params = get_model_profile(chat_model.engine.model, kwargs=input_dict, print_profile=True, detailed=True)
        print("FLOPs:", flops)
        print("MACs:", macs)
        print("Params:", params)


if __name__ == "__main__":
    fire.Fire(calculate_flops)
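As an aside (not part of the commit), the profiler can also be invoked programmatically instead of through the fire CLI shown in the docstring; the import assumes the script's directory is on the path, and the model path is simply the one used by the YAML configs above.

# Sketch: call the FLOPs profiler directly (runs on device 0 via get_accelerator().device(0)).
from cal_flops import calculate_flops  # assumes you run this from LLaMA-Factory/scripts

calculate_flops(
    model_name_or_path="/data/model/Qwen2-72B-Instruct",  # example path borrowed from the configs above
    batch_size=1,
    seq_length=512,
)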
LLaMA-Factory/scripts/cal_lr.py (new file, mode 100644)
# coding=utf-8
# Copyright 2024 imoneoi and the LlamaFactory team.
#
# This code is inspired by imoneoi's OpenChat library.
# https://github.com/imoneoi/openchat/blob/3.6.0/ochat/training_deepspeed/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Literal

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

from llamafactory.data import get_dataset
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer


BASE_LR = 3e-4  # 1.5e-4 for 30B-70B models
BASE_BS = 4_000_000  # from llama paper


def calculate_lr(
    model_name_or_path: str,
    batch_size: int,  # total batch size, namely (batch size * gradient accumulation * world size)
    stage: Literal["pt", "sft"] = "sft",
    dataset: str = "alpaca_en",
    dataset_dir: str = "data",
    template: str = "default",
    cutoff_len: int = 1024,  # i.e. maximum input length during training
    is_mistral: bool = False,  # mistral models use a smaller learning rate
    packing: bool = False,
):
    r"""
    Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
    Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
    """
    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage=stage,
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            packing=packing,
            output_dir="dummy_dir",
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
    if stage == "pt":
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    elif stage == "sft":
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
    else:
        raise NotImplementedError("Stage is not supported: {}.".format(stage))

    dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
    valid_tokens, total_tokens = 0, 0
    for batch in tqdm(dataloader):
        valid_tokens += torch.sum(batch["labels"] != IGNORE_INDEX).item()
        total_tokens += torch.numel(batch["labels"])

    batch_max_len = cutoff_len * batch_size  # max tokens in a batch
    valid_ratio = valid_tokens / total_tokens
    batch_valid_len = batch_max_len * valid_ratio
    lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)  # lr ~ sqrt(batch_size)
    lr = lr / 6.0 if is_mistral else lr
    print(
        "Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
            lr, valid_ratio * 100, batch_valid_len
        )
    )


if __name__ == "__main__":
    fire.Fire(calculate_lr)
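As an aside (not part of the commit), a worked instance of the scaling rule above: with a total batch size of 16, cutoff_len of 1024 and a made-up 70% valid-token ratio, lr = BASE_LR * sqrt(batch_valid_len / BASE_BS) comes out near 1.6e-5.

import math

BASE_LR, BASE_BS = 3e-4, 4_000_000
batch_size, cutoff_len, valid_ratio = 16, 1024, 0.70   # valid_ratio is an illustrative value

batch_valid_len = cutoff_len * batch_size * valid_ratio  # ~11469 effective tokens per batch
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS)      # lr ~ sqrt(batch size), as in the script
print("{:.2e}".format(lr))                               # ~1.61e-05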
LLaMA-Factory/scripts/cal_ppl.py (new file, mode 100644)
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional, Sequence

import fire
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

from llamafactory.data import get_dataset
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_model, load_tokenizer


@dataclass
class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """

    train_on_prompt: bool = False

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        r"""
        Pads batched data to the longest sequence in the batch.
        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.
        """
        chosen_features = []
        for feature in features:
            prompt_len, answer_len = len(feature["prompt_ids"]), len(feature["chosen_ids"])
            input_ids = feature["prompt_ids"] + feature["chosen_ids"]
            attention_mask = [1] * (prompt_len + answer_len)
            labels = input_ids if self.train_on_prompt else [IGNORE_INDEX] * prompt_len + feature["chosen_ids"]
            chosen_features.append({"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels})

        return super().__call__(chosen_features)


def cal_ppl(
    model_name_or_path: str,
    save_name: str,
    batch_size: int = 4,
    stage: Literal["pt", "sft", "rm"] = "sft",
    dataset: str = "alpaca_en",
    dataset_dir: str = "data",
    template: str = "default",
    cutoff_len: int = 1024,
    max_samples: Optional[int] = None,
    train_on_prompt: bool = False,
):
    r"""
    Calculates the ppl on the dataset of the pre-trained models.
    Usage: python cal_ppl.py --model_name_or_path path_to_model --save_name ppl.json
    """
    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
        dict(
            stage=stage,
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            max_samples=max_samples,
            train_on_prompt=train_on_prompt,
            output_dir="dummy_dir",
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    trainset = get_dataset(model_args, data_args, training_args, stage, **tokenizer_module)["train_dataset"]
    model = load_model(tokenizer, model_args, finetuning_args, is_trainable=False)
    if stage == "pt":
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    elif stage == "sft":
        data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
    elif stage == "rm":
        data_collator = PairwiseDataCollatorWithPadding(
            tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX, train_on_prompt=train_on_prompt
        )
    else:
        raise NotImplementedError("Stage is not supported: {}.".format(stage))

    dataloader = DataLoader(trainset, batch_size, shuffle=False, collate_fn=data_collator, pin_memory=True)
    criterion = torch.nn.CrossEntropyLoss(reduction="none")
    total_ppl = 0
    perplexities = []
    batch: Dict[str, "torch.Tensor"]
    with torch.no_grad():
        for batch in tqdm(dataloader):
            batch = batch.to(model.device)
            outputs = model(**batch)
            shift_logits: "torch.Tensor" = outputs["logits"][..., :-1, :]
            shift_labels: "torch.Tensor" = batch["labels"][..., 1:]
            loss_mask = shift_labels != IGNORE_INDEX
            flatten_logits = shift_logits.contiguous().view(shift_labels.size(0) * shift_labels.size(1), -1)
            flatten_labels = shift_labels.contiguous().view(-1)
            token_logps: "torch.Tensor" = criterion(flatten_logits, flatten_labels)
            token_logps = token_logps.contiguous().view(shift_logits.size(0), -1)
            sentence_logps = (token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
            total_ppl += sentence_logps.exp().sum().item()
            perplexities.extend(sentence_logps.exp().tolist())

    with open(save_name, "w", encoding="utf-8") as f:
        json.dump(perplexities, f, indent=2)

    print("Average perplexity is {:.2f}".format(total_ppl / len(perplexities)))
    print("Perplexities have been saved at {}.".format(save_name))


if __name__ == "__main__":
    fire.Fire(cal_ppl)
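As an aside (not part of the commit), the per-sentence perplexity above is simply exp of the mean token-level cross-entropy over the non-ignored positions; a minimal reproduction on a toy tensor:

import torch

IGNORE_INDEX = -100
# Toy per-token negative log-likelihoods and labels for one sequence of length 5.
token_nll = torch.tensor([[2.0, 1.0, 0.5, 3.0, 0.0]])
labels = torch.tensor([[IGNORE_INDEX, 7, 3, 9, IGNORE_INDEX]])  # first and last positions masked out

loss_mask = labels != IGNORE_INDEX
sentence_nll = (token_nll * loss_mask).sum(-1) / loss_mask.sum(-1)  # mean NLL over the 3 valid tokens = 1.5
print(sentence_nll.exp())  # tensor([4.4817]) -- the sentence perplexity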
LLaMA-Factory/scripts/length_cdf.py (new file, mode 100644)
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict

import fire
from tqdm import tqdm

from llamafactory.data import get_dataset
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer


def length_cdf(
    model_name_or_path: str,
    dataset: str = "alpaca_en",
    dataset_dir: str = "data",
    template: str = "default",
    interval: int = 1000,
):
    r"""
    Calculates the distribution of the input lengths in the dataset.
    Usage: python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en --template default
    """
    model_args, data_args, training_args, _, _ = get_train_args(
        dict(
            stage="sft",
            model_name_or_path=model_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=1_000_000,
            output_dir="dummy_dir",
            overwrite_cache=True,
            do_train=True,
        )
    )
    tokenizer_module = load_tokenizer(model_args)
    trainset = get_dataset(model_args, data_args, training_args, stage="sft", **tokenizer_module)["train_dataset"]
    total_num = len(trainset)
    length_dict = defaultdict(int)
    for sample in tqdm(trainset["input_ids"]):
        length_dict[len(sample) // interval * interval] += 1

    length_tuples = list(length_dict.items())
    length_tuples.sort()
    count_accu, prob_accu = 0, 0
    for length, count in length_tuples:
        count_accu += count
        prob_accu += count / total_num * 100
        print("{:d} ({:.2f}%) samples have length < {}.".format(count_accu, prob_accu, length + interval))


if __name__ == "__main__":
    fire.Fire(length_cdf)
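As an aside (not part of the commit), the bucketing above floors each length to a multiple of interval, so the printed thresholds are exclusive upper bounds; a toy run of the same logic:

from collections import defaultdict

interval = 1000
lengths = [120, 980, 1005, 1999, 2048, 3500]  # made-up sample lengths

length_dict = defaultdict(int)
for n in lengths:
    length_dict[n // interval * interval] += 1  # floor each length to its bucket start

print(sorted(length_dict.items()))  # [(0, 2), (1000, 2), (2000, 1), (3000, 1)]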
LLaMA-Factory/scripts/llama_pro.py (new file, mode 100644)
# coding=utf-8
# Copyright 2024 Tencent Inc. and the LlamaFactory team.
#
# This code is inspired by Tencent's LLaMA-Pro library.
# https://github.com/TencentARC/LLaMA-Pro/blob/main/scripts/block_expansion.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import OrderedDict
from typing import TYPE_CHECKING, Optional

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel


def change_name(name: str, old_index: int, new_index: int) -> str:
    return name.replace(".{:d}.".format(old_index), ".{:d}.".format(new_index))


def block_expansion(
    model_name_or_path: str,
    output_dir: str,
    num_expand: int,
    shard_size: Optional[str] = "2GB",
    save_safetensors: Optional[bool] = False,
):
    r"""
    Performs block expansion for LLaMA, Mistral, Qwen1.5 or Yi models.
    Usage: python llama_pro.py --model_name_or_path meta-llama/Llama-2-7b-hf --output_dir llama2_pro --num_expand 8
    """
    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)
    num_layers = getattr(config, "num_hidden_layers")
    setattr(config, "num_hidden_layers", num_layers + num_expand)
    config.save_pretrained(output_dir)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.save_pretrained(output_dir)

    config: "PretrainedConfig" = AutoConfig.from_pretrained(model_name_or_path)  # load the original one
    if save_safetensors:
        setattr(config, "tie_word_embeddings", False)  # safetensors does not allow shared weights

    model: "PreTrainedModel" = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        config=config,
        torch_dtype="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    state_dict = model.state_dict()

    if num_layers % num_expand != 0:
        raise ValueError("`num_layers` {} should be divisible by `num_expand` {}.".format(num_layers, num_expand))

    split = num_layers // num_expand
    layer_cnt = 0
    output_state_dict = OrderedDict()
    for i in range(num_layers):
        for key, value in state_dict.items():
            if ".{:d}.".format(i) in key:
                output_state_dict[change_name(key, i, layer_cnt)] = value

        print("Add layer {} copied from layer {}".format(layer_cnt, i))
        layer_cnt += 1
        if (i + 1) % split == 0:
            for key, value in state_dict.items():
                if ".{:d}.".format(i) in key:
                    if "down_proj" in key or "o_proj" in key:
                        output_state_dict[change_name(key, i, layer_cnt)] = torch.zeros_like(value)
                    else:
                        output_state_dict[change_name(key, i, layer_cnt)] = torch.clone(value)

            print("Add layer {} expanded from layer {}".format(layer_cnt, i))
            layer_cnt += 1

    for key, value in state_dict.items():
        if key not in output_state_dict:
            output_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(output_state_dict, max_shard_size=shard_size, weights_name=weights_name)

    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)

        print("Model weights saved in {}".format(output_dir))

    print("- Fine-tune this model with:")
    print("model_name_or_path: {}".format(output_dir))
    print("finetuning_type: freeze")
    print("freeze_trainable_layers: {}".format(num_expand))
    print("use_llama_pro: true")


if __name__ == "__main__":
    fire.Fire(block_expansion)
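As an aside (not part of the commit), the copy/expand loop above can be traced without any weights to see which output layer comes from which source layer; with 8 original layers and num_expand=2, every 4th layer is followed by a zero-initialized copy of itself.

# Trace of the layer-indexing logic in block_expansion (no model needed).
num_layers, num_expand = 8, 2
split = num_layers // num_expand  # 4

layer_cnt = 0
plan = []
for i in range(num_layers):
    plan.append((layer_cnt, i, "copy"))
    layer_cnt += 1
    if (i + 1) % split == 0:
        plan.append((layer_cnt, i, "expand (o_proj/down_proj zeroed)"))
        layer_cnt += 1

for new_idx, src_idx, kind in plan:
    print("new layer {:2d} <- source layer {:2d} [{}]".format(new_idx, src_idx, kind))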
LLaMA-Factory/scripts/llamafy_baichuan2.py (new file, mode 100644)
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional

import fire
import torch
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)


CONFIG_NAME = "config.json"


def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool):
    baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
            shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
            baichuan2_state_dict.update(shard_weight)

    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for key, value in tqdm(baichuan2_state_dict.items(), desc="Convert format"):
        if "W_pack" in key:  # split the fused qkv projection
            proj_size = value.size(0) // 3
            llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
            llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size : 2 * proj_size, :]
            llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2 * proj_size :, :]
        elif "lm_head" in key:
            llama2_state_dict[key] = torch.nn.functional.normalize(value)
        else:
            llama2_state_dict[key] = value

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)

    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)

        print("Model weights saved in {}".format(output_dir))


def save_config(input_dir: str, output_dir: str):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        llama2_config_dict: Dict[str, Any] = json.load(f)

    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
    llama2_config_dict.pop("auto_map", None)
    llama2_config_dict.pop("tokenizer_class", None)
    llama2_config_dict["model_type"] = "llama"

    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(llama2_config_dict, f, indent=2)

    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


def llamafy_baichuan2(
    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
    r"""
    Converts the Baichuan2-7B model into the same format as LLaMA2-7B.
    Usage: python llamafy_baichuan2.py --input_dir input --output_dir output
    Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
    """
    try:
        os.makedirs(output_dir, exist_ok=False)
    except Exception as e:
        raise RuntimeError("Output dir {} already exists.".format(output_dir)) from e

    save_weight(input_dir, output_dir, shard_size, save_safetensors)
    save_config(input_dir, output_dir)


if __name__ == "__main__":
    fire.Fire(llamafy_baichuan2)
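As an aside (not part of the commit), the core of the conversion is slicing Baichuan2's fused W_pack projection into separate q/k/v matrices; a shape-only sketch with a random tensor:

import torch

hidden_size = 8  # tiny stand-in for the real hidden size
w_pack = torch.randn(3 * hidden_size, hidden_size)  # fused [q; k; v] projection

proj_size = w_pack.size(0) // 3
q_proj = w_pack[:proj_size, :]
k_proj = w_pack[proj_size : 2 * proj_size, :]
v_proj = w_pack[2 * proj_size :, :]
print(q_proj.shape, k_proj.shape, v_proj.shape)  # three (8, 8) matrices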
LLaMA-Factory/scripts/llamafy_qwen.py (new file, mode 100644)
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from collections import OrderedDict
from typing import Any, Dict, Optional

import fire
import torch
from safetensors import safe_open
from safetensors.torch import save_file
from tqdm import tqdm
from transformers.modeling_utils import (
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    shard_checkpoint,
)
from transformers.utils import check_min_version


try:
    check_min_version("4.34.0")
except Exception:
    raise ValueError("Please upgrade `transformers` to 4.34.0")


CONFIG_NAME = "config.json"


def save_weight(input_dir: str, output_dir: str, shard_size: str, save_safetensors: bool) -> str:
    qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    for filepath in tqdm(os.listdir(input_dir), desc="Load weights"):
        if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
            with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
                for key in f.keys():
                    qwen_state_dict[key] = f.get_tensor(key)

    llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
    torch_dtype = None
    for key, value in tqdm(qwen_state_dict.items(), desc="Convert format"):
        if torch_dtype is None:
            torch_dtype = value.dtype

        if "wte" in key:
            llama2_state_dict["model.embed_tokens.weight"] = value
        elif "ln_f" in key:
            llama2_state_dict["model.norm.weight"] = value
        else:
            key = key.replace("transformer.h", "model.layers")
            if "attn.c_attn" in key:  # split the fused qkv projection
                proj_size = value.size(0) // 3
                llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
                llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[proj_size : 2 * proj_size, ...]
                llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2 * proj_size :, ...]
            elif "attn.c_proj" in key:
                llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
                llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = torch.zeros_like(
                    value[:, 0]
                ).squeeze()
            elif "ln_1" in key:
                llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
            elif "ln_2" in key:
                llama2_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
            elif "mlp.w1" in key:
                llama2_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
            elif "mlp.w2" in key:
                llama2_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
            elif "mlp.c_proj" in key:
                llama2_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
            elif "lm_head" in key:
                llama2_state_dict[key] = value
            else:
                raise KeyError("Unable to process key {}".format(key))

    weights_name = SAFE_WEIGHTS_NAME if save_safetensors else WEIGHTS_NAME
    shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=weights_name)

    for shard_file, shard in tqdm(shards.items(), desc="Save weights"):
        if save_safetensors:
            save_file(shard, os.path.join(output_dir, shard_file), metadata={"format": "pt"})
        else:
            torch.save(shard, os.path.join(output_dir, shard_file))

    if index is None:
        print("Model weights saved in {}".format(os.path.join(output_dir, weights_name)))
    else:
        index_name = SAFE_WEIGHTS_INDEX_NAME if save_safetensors else WEIGHTS_INDEX_NAME
        with open(os.path.join(output_dir, index_name), "w", encoding="utf-8") as f:
            json.dump(index, f, indent=2, sort_keys=True)

        print("Model weights saved in {}".format(output_dir))

    return str(torch_dtype).replace("torch.", "")


def save_config(input_dir: str, output_dir: str, torch_dtype: str):
    with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
        qwen_config_dict: Dict[str, Any] = json.load(f)

    llama2_config_dict: Dict[str, Any] = OrderedDict()
    llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
    llama2_config_dict["hidden_act"] = "silu"
    llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
    llama2_config_dict["initializer_range"] = qwen_config_dict["initializer_range"]
    llama2_config_dict["intermediate_size"] = qwen_config_dict["intermediate_size"] // 2
    llama2_config_dict["max_position_embeddings"] = qwen_config_dict["max_position_embeddings"]
    llama2_config_dict["model_type"] = "llama"
    llama2_config_dict["num_attention_heads"] = qwen_config_dict["num_attention_heads"]
    llama2_config_dict["num_hidden_layers"] = qwen_config_dict["num_hidden_layers"]
    llama2_config_dict["num_key_value_heads"] = qwen_config_dict["hidden_size"] // qwen_config_dict["kv_channels"]
    llama2_config_dict["pretraining_tp"] = 1
    llama2_config_dict["rms_norm_eps"] = qwen_config_dict["layer_norm_epsilon"]
    llama2_config_dict["rope_scaling"] = None
    llama2_config_dict["tie_word_embeddings"] = qwen_config_dict["tie_word_embeddings"]
    llama2_config_dict["torch_dtype"] = torch_dtype
    llama2_config_dict["transformers_version"] = "4.34.0"
    llama2_config_dict["use_cache"] = True
    llama2_config_dict["vocab_size"] = qwen_config_dict["vocab_size"]
    llama2_config_dict["attention_bias"] = True

    with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
        json.dump(llama2_config_dict, f, indent=2)

    print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))


def llamafy_qwen(
    input_dir: str, output_dir: str, shard_size: Optional[str] = "2GB", save_safetensors: Optional[bool] = False
):
    r"""
    Converts the Qwen models into the same format as LLaMA2.
    Usage: python llamafy_qwen.py --input_dir input --output_dir output
    Converted model: https://huggingface.co/hiyouga/Qwen-14B-Chat-LLaMAfied
    """
    try:
        os.makedirs(output_dir, exist_ok=False)
    except Exception as e:
        raise RuntimeError("Output dir {} already exists.".format(output_dir)) from e

    torch_dtype = save_weight(input_dir, output_dir, shard_size, save_safetensors)
    save_config(input_dir, output_dir, torch_dtype)


if __name__ == "__main__":
    fire.Fire(llamafy_qwen)
LLaMA-Factory/scripts/loftq_init.py (new file, mode 100644)
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/examples/loftq_finetuning/quantize_save_load.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING

import fire
from peft import LoftQConfig, LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer


if TYPE_CHECKING:
    from transformers import PreTrainedModel


def quantize_loftq(
    model_name_or_path: str,
    output_dir: str,
    loftq_bits: int = 4,
    loftq_iter: int = 4,
    lora_alpha: int = None,
    lora_rank: int = 16,
    lora_dropout: float = 0,
    lora_target: tuple = ("q_proj", "v_proj"),
    save_safetensors: bool = True,
):
    r"""
    Initializes LoRA weights with LoRA-fine-tuning-aware Quantization (LoftQ).
    Usage: python loftq_init.py --model_name_or_path path_to_model --output_dir output_dir
    """
    if isinstance(lora_target, str):
        lora_target = [name.strip() for name in lora_target.split(",")]

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")

    loftq_config = LoftQConfig(loftq_bits=loftq_bits, loftq_iter=loftq_iter)
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=True,
        r=lora_rank,
        lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
        lora_dropout=lora_dropout,
        target_modules=lora_target,
        init_lora_weights="loftq",
        loftq_config=loftq_config,
    )

    # Init LoftQ model
    print("Initializing LoftQ weights, this may take several minutes, please wait patiently.")
    peft_model = get_peft_model(model, lora_config)
    loftq_dir = os.path.join(output_dir, "loftq_init")

    # Save LoftQ model
    setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
    setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply loftq again
    peft_model.save_pretrained(loftq_dir, safe_serialization=save_safetensors)
    print("Adapter weights saved in {}".format(loftq_dir))

    # Save base model
    base_model: "PreTrainedModel" = peft_model.unload()
    base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
    tokenizer.save_pretrained(output_dir)
    print("Model weights saved in {}".format(output_dir))

    print("- Fine-tune this model with:")
    print("model_name_or_path: {}".format(output_dir))
    print("adapter_name_or_path: {}".format(loftq_dir))
    print("finetuning_type: lora")
    print("quantization_bit: {}".format(loftq_bits))


if __name__ == "__main__":
    fire.Fire(quantize_loftq)
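As an aside (not part of the commit), once the script has written the base model to output_dir and the adapter to output_dir/loftq_init, the pair can be reloaded with plain PEFT as sketched below; the directory names mirror the script's outputs, and quantizing the base model at load time (as the printed quantization_bit hint suggests) is deliberately left out for brevity.

# Minimal sketch: reload the artifacts produced by quantize_loftq with PEFT.
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

output_dir = "path/to/output_dir"  # same directory passed to quantize_loftq

base_model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(output_dir)
model = PeftModel.from_pretrained(base_model, "{}/loftq_init".format(output_dir))  # LoftQ-initialized LoRA adapter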
LLaMA-Factory/scripts/pissa_init.py (new file, mode 100644)
# coding=utf-8
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is based on HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.11.0/examples/pissa_finetuning/preprocess.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import TYPE_CHECKING

import fire
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer


if TYPE_CHECKING:
    from transformers import PreTrainedModel


def quantize_pissa(
    model_name_or_path: str,
    output_dir: str,
    pissa_iter: int = 4,
    lora_alpha: int = None,
    lora_rank: int = 16,
    lora_dropout: float = 0,
    lora_target: tuple = ("q_proj", "v_proj"),
    save_safetensors: bool = True,
):
    r"""
    Initializes LoRA weights with Principal Singular values and Singular vectors Adaptation (PiSSA).
    Usage: python pissa_init.py --model_name_or_path path_to_model --output_dir output_dir
    """
    if isinstance(lora_target, str):
        lora_target = [name.strip() for name in lora_target.split(",")]

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype="auto")

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=lora_rank,
        lora_alpha=lora_alpha if lora_alpha is not None else lora_rank * 2,
        lora_dropout=lora_dropout,
        target_modules=lora_target,
        init_lora_weights="pissa" if pissa_iter == -1 else "pissa_niter_{}".format(pissa_iter),
    )

    # Init PiSSA model
    peft_model = get_peft_model(model, lora_config)
    pissa_dir = os.path.join(output_dir, "pissa_init")

    # Save PiSSA model
    setattr(peft_model.peft_config["default"], "base_model_name_or_path", os.path.abspath(output_dir))
    setattr(peft_model.peft_config["default"], "init_lora_weights", True)  # don't apply pissa again
    peft_model.save_pretrained(pissa_dir, safe_serialization=save_safetensors)
    print("Adapter weights saved in {}".format(pissa_dir))

    # Save base model
    base_model: "PreTrainedModel" = peft_model.unload()
    base_model.save_pretrained(output_dir, safe_serialization=save_safetensors)
    tokenizer.save_pretrained(output_dir)
    print("Model weights saved in {}".format(output_dir))

    print("- Fine-tune this model with:")
    print("model_name_or_path: {}".format(output_dir))
    print("adapter_name_or_path: {}".format(pissa_dir))
    print("finetuning_type: lora")
    print("pissa_init: false")
    print("pissa_convert: true")
    print("- and optionally with:")
    print("quantization_bit: 4")


if __name__ == "__main__":
    fire.Fire(quantize_pissa)
LLaMA-Factory/scripts/test_toolcall.py (new file, mode 100644)
# coding=utf-8
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from typing import Sequence

from openai import OpenAI
from transformers.utils.versions import require_version


require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
    grade_to_score = {"A": 4, "B": 3, "C": 2}
    total_score, total_hour = 0, 0
    for grade, hour in zip(grades, hours):
        total_score += grade_to_score[grade] * hour
        total_hour += hour

    return round(total_score / total_hour, 2)


def main():
    client = OpenAI(
        api_key="{}".format(os.environ.get("API_KEY", "0")),
        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
    )
    tools = [
        {
            "type": "function",
            "function": {
                "name": "calculate_gpa",
                "description": "Calculate the Grade Point Average (GPA) based on grades and credit hours",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "grades": {"type": "array", "items": {"type": "string"}, "description": "The grades"},
                        "hours": {"type": "array", "items": {"type": "integer"}, "description": "The credit hours"},
                    },
                    "required": ["grades", "hours"],
                },
            },
        }
    ]
    tool_map = {"calculate_gpa": calculate_gpa}

    messages = []
    messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    if result.choices[0].message.tool_calls is None:
        raise ValueError("Cannot retrieve function call from the response.")

    messages.append(result.choices[0].message)
    tool_call = result.choices[0].message.tool_calls[0].function
    print(tool_call)
    # Function(arguments='{"grades": ["A", "A", "B", "C"], "hours": [3, 4, 3, 2]}', name='calculate_gpa')
    name, arguments = tool_call.name, json.loads(tool_call.arguments)
    tool_result = tool_map[name](**arguments)
    messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
    result = client.chat.completions.create(messages=messages, model="test", tools=tools)
    print(result.choices[0].message.content)
    # Based on the grades and credit hours you provided, your Grade Point Average (GPA) is 3.42.


if __name__ == "__main__":
    main()
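As an aside (not part of the commit), the 3.42 in the final comment is just the weighted average computed by calculate_gpa; checking it by hand:

# Grades A, A, B, C with credit hours 3, 4, 3, 2 (the prompt used above).
scores = {"A": 4, "B": 3, "C": 2}
grades, hours = ["A", "A", "B", "C"], [3, 4, 3, 2]

total = sum(scores[g] * h for g, h in zip(grades, hours))  # 12 + 16 + 9 + 4 = 41
print(round(total / sum(hours), 2))  # 41 / 12 = 3.42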
LLaMA-Factory/setup.py (new file, mode 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re

from setuptools import find_packages, setup


def get_version():
    with open(os.path.join("src", "llamafactory", "extras", "env.py"), "r", encoding="utf-8") as f:
        file_content = f.read()
        pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
        (version,) = re.findall(pattern, file_content)
        return version


def get_requires():
    with open("requirements.txt", "r", encoding="utf-8") as f:
        file_content = f.read()
        lines = [line.strip() for line in file_content.strip().split("\n") if not line.startswith("#")]
        return lines


extra_require = {
    "torch": ["torch>=1.13.1"],
    "torch-npu": ["torch==2.1.0", "torch-npu==2.1.0.post3", "decorator"],
    "metrics": ["nltk", "jieba", "rouge-chinese"],
    "deepspeed": ["deepspeed>=0.10.0"],
    "bitsandbytes": ["bitsandbytes>=0.39.0"],
    "hqq": ["hqq"],
    "eetq": ["eetq"],
    "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
    "awq": ["autoawq"],
    "aqlm": ["aqlm[gpu]>=1.1.0"],
    "vllm": ["vllm>=0.4.3"],
    "galore": ["galore-torch"],
    "badam": ["badam>=1.2.1"],
    "qwen": ["transformers_stream_generator"],
    "modelscope": ["modelscope"],
    "dev": ["ruff", "pytest"],
}


def main():
    setup(
        name="llamafactory",
        version=get_version(),
        author="hiyouga",
        author_email="hiyouga" "@" "buaa.edu.cn",
        description="Easy-to-use LLM fine-tuning framework",
        long_description=open("README.md", "r", encoding="utf-8").read(),
        long_description_content_type="text/markdown",
        keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
        license="Apache 2.0 License",
        url="https://github.com/hiyouga/LLaMA-Factory",
        package_dir={"": "src"},
        packages=find_packages("src"),
        python_requires=">=3.8.0",
        install_requires=get_requires(),
        extras_require=extra_require,
        entry_points={"console_scripts": ["llamafactory-cli = llamafactory.cli:main"]},
        classifiers=[
            "Development Status :: 4 - Beta",
            "Intended Audience :: Developers",
            "Intended Audience :: Education",
            "Intended Audience :: Science/Research",
            "License :: OSI Approved :: Apache Software License",
            "Operating System :: OS Independent",
            "Programming Language :: Python :: 3",
            "Programming Language :: Python :: 3.8",
            "Programming Language :: Python :: 3.9",
            "Programming Language :: Python :: 3.10",
            "Programming Language :: Python :: 3.11",
            "Topic :: Scientific/Engineering :: Artificial Intelligence",
        ],
    )


if __name__ == "__main__":
    main()
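As an aside (not part of the commit), get_version pulls the version string out of src/llamafactory/extras/env.py with a regex; the sample line below is only a plausible shape for that file, not its actual contents.

import re

# Hypothetical excerpt of src/llamafactory/extras/env.py -- the real file may differ.
file_content = 'VERSION = "0.8.3"\n'

pattern = r"{}\W*=\W*\"([^\"]+)\"".format("VERSION")
(version,) = re.findall(pattern, file_content)
print(version)  # 0.8.3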
LLaMA-Factory/src/api.py (new file, mode 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import uvicorn

from llamafactory.api.app import create_app
from llamafactory.chat import ChatModel


def main():
    chat_model = ChatModel()
    app = create_app(chat_model)
    api_host = os.environ.get("API_HOST", "0.0.0.0")
    api_port = int(os.environ.get("API_PORT", "8000"))
    print("Visit http://localhost:{}/docs for API document.".format(api_port))
    uvicorn.run(app, host=api_host, port=api_port)


if __name__ == "__main__":
    main()
LLaMA-Factory/src/llamafactory/__init__.py (new file, mode 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Efficient fine-tuning of large language models.

Level:
  api, webui > chat, eval, train > data, model > hparams > extras

Dependency graph:
  main:
    transformers>=4.41.2
    datasets>=2.16.0
    accelerate>=0.30.1
    peft>=0.11.1
    trl>=0.8.6
  attention:
    transformers>=4.42.4 (gemma+fa2)
  longlora:
    transformers>=4.41.2,<=4.42.4
  packing:
    transformers>=4.41.2,<=4.42.4
  patcher:
    transformers==4.41.2 (chatglm)
"""

from .cli import VERSION


__version__ = VERSION
LLaMA-Factory/src/llamafactory/api/__init__.py (new empty file, mode 100644)
LLaMA-Factory/src/llamafactory/api/app.py (new file, mode 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from contextlib import asynccontextmanager
from typing import Optional

from typing_extensions import Annotated

from ..chat import ChatModel
from ..extras.misc import torch_gc
from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
from .chat import (
    create_chat_completion_response,
    create_score_evaluation_response,
    create_stream_chat_completion_response,
)
from .protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ModelCard,
    ModelList,
    ScoreEvaluationRequest,
    ScoreEvaluationResponse,
)


if is_fastapi_available():
    from fastapi import Depends, FastAPI, HTTPException, status
    from fastapi.middleware.cors import CORSMiddleware
    from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer


if is_starlette_available():
    from sse_starlette import EventSourceResponse


if is_uvicorn_available():
    import uvicorn


@asynccontextmanager
async def lifespan(app: "FastAPI"):  # collects GPU memory
    yield
    torch_gc()


def create_app(chat_model: "ChatModel") -> "FastAPI":
    app = FastAPI(lifespan=lifespan)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    api_key = os.environ.get("API_KEY")
    security = HTTPBearer(auto_error=False)

    async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
        if api_key and (auth is None or auth.credentials != api_key):
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")

    @app.get(
        "/v1/models",
        response_model=ModelList,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def list_models():
        model_card = ModelCard(id="gpt-3.5-turbo")
        return ModelList(data=[model_card])

    @app.post(
        "/v1/chat/completions",
        response_model=ChatCompletionResponse,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def create_chat_completion(request: ChatCompletionRequest):
        if not chat_model.engine.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        if request.stream:
            generate = create_stream_chat_completion_response(request, chat_model)
            return EventSourceResponse(generate, media_type="text/event-stream")
        else:
            return await create_chat_completion_response(request, chat_model)

    @app.post(
        "/v1/score/evaluation",
        response_model=ScoreEvaluationResponse,
        status_code=status.HTTP_200_OK,
        dependencies=[Depends(verify_api_key)],
    )
    async def create_score_evaluation(request: ScoreEvaluationRequest):
        if chat_model.engine.can_generate:
            raise HTTPException(status_code=status.HTTP_405_METHOD_NOT_ALLOWED, detail="Not allowed")

        return await create_score_evaluation_response(request, chat_model)

    return app


def run_api() -> None:
    chat_model = ChatModel()
    app = create_app(chat_model)
    api_host = os.environ.get("API_HOST", "0.0.0.0")
    api_port = int(os.environ.get("API_PORT", "8000"))
    print("Visit http://localhost:{}/docs for API document.".format(api_port))
    uvicorn.run(app, host=api_host, port=api_port)
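As an aside (not part of the commit), because these routes mimic the OpenAI schema, the server can be exercised with the openai client exactly as test_toolcall.py does; the host, port and API_KEY values below are assumptions matching the environment variables read above, and the model name is only echoed back in the response rather than validated.

# Sketch: call the /v1/chat/completions route served by create_app.
import os

from openai import OpenAI

client = OpenAI(
    api_key=os.environ.get("API_KEY", "0"),  # must match the server's API_KEY if one is set
    base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
)
result = client.chat.completions.create(
    model="test",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(result.choices[0].message.content)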
LLaMA-Factory/src/llamafactory/api/chat.py (new file, mode 100644)
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import io
import json
import os
import uuid
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

from ..data import Role as DataRole
from ..extras.logging import get_logger
from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
from .common import dictify, jsonify
from .protocol import (
    ChatCompletionMessage,
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseUsage,
    ChatCompletionStreamResponse,
    ChatCompletionStreamResponseChoice,
    Finish,
    Function,
    FunctionCall,
    Role,
    ScoreEvaluationResponse,
)


if is_fastapi_available():
    from fastapi import HTTPException, status


if is_pillow_available():
    from PIL import Image


if is_requests_available():
    import requests


if TYPE_CHECKING:
    from numpy.typing import NDArray

    from ..chat import ChatModel
    from .protocol import ChatCompletionRequest, ScoreEvaluationRequest


logger = get_logger(__name__)
ROLE_MAPPING = {
    Role.USER: DataRole.USER.value,
    Role.ASSISTANT: DataRole.ASSISTANT.value,
    Role.SYSTEM: DataRole.SYSTEM.value,
    Role.FUNCTION: DataRole.FUNCTION.value,
    Role.TOOL: DataRole.OBSERVATION.value,
}


def _process_request(
    request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
    logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))

    if len(request.messages) == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

    if request.messages[0].role == Role.SYSTEM:
        system = request.messages.pop(0).content
    else:
        system = None

    if len(request.messages) % 2 == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")

    input_messages = []
    image = None
    for i, message in enumerate(request.messages):
        if i % 2 == 0 and message.role not in [Role.USER, Role.TOOL]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")
        elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

        if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
            tool_calls = [
                {"name": tool_call.function.name, "arguments": tool_call.function.arguments}
                for tool_call in message.tool_calls
            ]
            content = json.dumps(tool_calls, ensure_ascii=False)
            input_messages.append({"role": ROLE_MAPPING[Role.FUNCTION], "content": content})
        elif isinstance(message.content, list):
            for input_item in message.content:
                if input_item.type == "text":
                    input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
                else:
                    image_url = input_item.image_url.url
                    if image_url.startswith("data:image"):  # base64 image
                        image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
                        image_path = io.BytesIO(image_data)
                    elif os.path.isfile(image_url):  # local file
                        image_path = open(image_url, "rb")
                    else:  # web uri
                        image_path = requests.get(image_url, stream=True).raw

                    image = Image.open(image_path).convert("RGB")
        else:
            input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})

    tool_list = request.tools
    if isinstance(tool_list, list) and len(tool_list):
        try:
            tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
        except json.JSONDecodeError:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
    else:
        tools = None

    return input_messages, system, tools, image


def _create_stream_chat_completion_chunk(
    completion_id: str,
    model: str,
    delta: "ChatCompletionMessage",
    index: Optional[int] = 0,
    finish_reason: Optional["Finish"] = None,
) -> str:
    choice_data = ChatCompletionStreamResponseChoice(index=index, delta=delta, finish_reason=finish_reason)
    chunk = ChatCompletionStreamResponse(id=completion_id, model=model, choices=[choice_data])
    return jsonify(chunk)


async def create_chat_completion_response(
    request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> "ChatCompletionResponse":
    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
    input_messages, system, tools, image = _process_request(request)
    responses = await chat_model.achat(
        input_messages,
        system,
        tools,
        image,
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_p=request.top_p,
        max_new_tokens=request.max_tokens,
        num_return_sequences=request.n,
        stop=request.stop,
    )

    prompt_length, response_length = 0, 0
    choices = []
    for i, response in enumerate(responses):
        if tools:
            result = chat_model.engine.template.extract_tool(response.response_text)
        else:
            result = response.response_text

        if isinstance(result, list):
            tool_calls = []
            for tool in result:
                function = Function(name=tool[0], arguments=tool[1])
                tool_calls.append(FunctionCall(id="call_{}".format(uuid.uuid4().hex), function=function))

            response_message = ChatCompletionMessage(role=Role.ASSISTANT, tool_calls=tool_calls)
            finish_reason = Finish.TOOL
        else:
            response_message = ChatCompletionMessage(role=Role.ASSISTANT, content=result)
            finish_reason = Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH

        choices.append(ChatCompletionResponseChoice(index=i, message=response_message, finish_reason=finish_reason))
        prompt_length = response.prompt_length
        response_length += response.response_length

    usage = ChatCompletionResponseUsage(
        prompt_tokens=prompt_length,
        completion_tokens=response_length,
        total_tokens=prompt_length + response_length,
    )

    return ChatCompletionResponse(id=completion_id, model=request.model, choices=choices, usage=usage)


async def create_stream_chat_completion_response(
    request: "ChatCompletionRequest", chat_model: "ChatModel"
) -> AsyncGenerator[str, None]:
    completion_id = "chatcmpl-{}".format(uuid.uuid4().hex)
    input_messages, system, tools, image = _process_request(request)
    if tools:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream function calls.")

    if request.n > 1:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot stream multiple responses.")

    yield _create_stream_chat_completion_chunk(
        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(role=Role.ASSISTANT, content="")
    )
    async for new_token in chat_model.astream_chat(
        input_messages,
        system,
        tools,
        image,
        do_sample=request.do_sample,
        temperature=request.temperature,
        top_p=request.top_p,
        max_new_tokens=request.max_tokens,
        stop=request.stop,
    ):
        if len(new_token) != 0:
            yield _create_stream_chat_completion_chunk(
                completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(content=new_token)
            )

    yield _create_stream_chat_completion_chunk(
        completion_id=completion_id, model=request.model, delta=ChatCompletionMessage(), finish_reason=Finish.STOP
    )
    yield "[DONE]"


async def create_score_evaluation_response(
    request: "ScoreEvaluationRequest", chat_model: "ChatModel"
) -> "ScoreEvaluationResponse":
    if len(request.messages) == 0:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")

    scores = await chat_model.aget_scores(request.messages, max_length=request.max_length)
    return ScoreEvaluationResponse(model=request.model, scores=scores)
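As an aside (not part of the commit), _process_request accepts multimodal user turns in the OpenAI-style list form, splitting text items from a single image item (data URI, local file path, or URL). A request body in that shape might look like the dict below; the field names are inferred from the parsing branches above, and the truncated base64 payload is a placeholder.

# Hypothetical request payload matching the branches in _process_request.
payload = {
    "model": "test",
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this picture?"},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0..."}},
            ],
        }
    ],
}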