Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
Qwen2_pytorch
Commits
032b90a1
Commit
032b90a1
authored
Sep 12, 2024
by
luopl
Browse files
init commit
parents
Pipeline
#1684
canceled with stages
Changes
233
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4599 additions
and
0 deletions
+4599
-0
LLaMA-Factory/src/llamafactory/data/processors/processor_utils.py
...ctory/src/llamafactory/data/processors/processor_utils.py
+95
-0
LLaMA-Factory/src/llamafactory/data/processors/supervised.py
LLaMA-Factory/src/llamafactory/data/processors/supervised.py
+208
-0
LLaMA-Factory/src/llamafactory/data/processors/unsupervised.py
...-Factory/src/llamafactory/data/processors/unsupervised.py
+106
-0
LLaMA-Factory/src/llamafactory/data/template.py
LLaMA-Factory/src/llamafactory/data/template.py
+906
-0
LLaMA-Factory/src/llamafactory/data/tool_utils.py
LLaMA-Factory/src/llamafactory/data/tool_utils.py
+140
-0
LLaMA-Factory/src/llamafactory/eval/__init__.py
LLaMA-Factory/src/llamafactory/eval/__init__.py
+0
-0
LLaMA-Factory/src/llamafactory/eval/evaluator.py
LLaMA-Factory/src/llamafactory/eval/evaluator.py
+154
-0
LLaMA-Factory/src/llamafactory/eval/template.py
LLaMA-Factory/src/llamafactory/eval/template.py
+81
-0
LLaMA-Factory/src/llamafactory/extras/__init__.py
LLaMA-Factory/src/llamafactory/extras/__init__.py
+0
-0
LLaMA-Factory/src/llamafactory/extras/constants.py
LLaMA-Factory/src/llamafactory/extras/constants.py
+1624
-0
LLaMA-Factory/src/llamafactory/extras/env.py
LLaMA-Factory/src/llamafactory/extras/env.py
+75
-0
LLaMA-Factory/src/llamafactory/extras/logging.py
LLaMA-Factory/src/llamafactory/extras/logging.py
+82
-0
LLaMA-Factory/src/llamafactory/extras/misc.py
LLaMA-Factory/src/llamafactory/extras/misc.py
+228
-0
LLaMA-Factory/src/llamafactory/extras/packages.py
LLaMA-Factory/src/llamafactory/extras/packages.py
+88
-0
LLaMA-Factory/src/llamafactory/extras/ploting.py
LLaMA-Factory/src/llamafactory/extras/ploting.py
+101
-0
LLaMA-Factory/src/llamafactory/hparams/__init__.py
LLaMA-Factory/src/llamafactory/hparams/__init__.py
+32
-0
LLaMA-Factory/src/llamafactory/hparams/data_args.py
LLaMA-Factory/src/llamafactory/hparams/data_args.py
+143
-0
LLaMA-Factory/src/llamafactory/hparams/evaluation_args.py
LLaMA-Factory/src/llamafactory/hparams/evaluation_args.py
+62
-0
LLaMA-Factory/src/llamafactory/hparams/finetuning_args.py
LLaMA-Factory/src/llamafactory/hparams/finetuning_args.py
+400
-0
LLaMA-Factory/src/llamafactory/hparams/generating_args.py
LLaMA-Factory/src/llamafactory/hparams/generating_args.py
+74
-0
No files found.
LLaMA-Factory/src/llamafactory/data/processors/processor_utils.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
bisect
from
typing
import
TYPE_CHECKING
,
List
,
Sequence
,
Tuple
from
...extras.packages
import
is_pillow_available
if
is_pillow_available
():
from
PIL
import
Image
if
TYPE_CHECKING
:
from
numpy.typing
import
NDArray
from
PIL.Image
import
Image
as
ImageObject
from
transformers
import
ProcessorMixin
from
transformers.image_processing_utils
import
BaseImageProcessor
def
search_for_fit
(
numbers
:
Sequence
[
int
],
capacity
:
int
)
->
int
:
r
"""
Finds the index of largest number that fits into the knapsack with the given capacity.
"""
index
=
bisect
.
bisect
(
numbers
,
capacity
)
return
-
1
if
index
==
0
else
(
index
-
1
)
def
greedy_knapsack
(
numbers
:
List
[
int
],
capacity
:
int
)
->
List
[
List
[
int
]]:
r
"""
An efficient greedy algorithm with binary search for the knapsack problem.
"""
numbers
.
sort
()
# sort numbers in ascending order for binary search
knapsacks
=
[]
while
numbers
:
current_knapsack
=
[]
remaining_capacity
=
capacity
while
True
:
index
=
search_for_fit
(
numbers
,
remaining_capacity
)
if
index
==
-
1
:
break
# no more numbers fit in this knapsack
remaining_capacity
-=
numbers
[
index
]
# update the remaining capacity
current_knapsack
.
append
(
numbers
.
pop
(
index
))
# add the number to knapsack
knapsacks
.
append
(
current_knapsack
)
return
knapsacks
def
get_pixel_values
(
images
:
Sequence
[
"ImageObject"
],
processor
:
"ProcessorMixin"
)
->
"NDArray"
:
r
"""
Processes visual inputs. (currently only supports a single image)
"""
image_processor
:
"BaseImageProcessor"
=
getattr
(
processor
,
"image_processor"
)
image
=
images
[
0
]
if
len
(
images
)
!=
0
else
Image
.
new
(
"RGB"
,
(
100
,
100
),
(
255
,
255
,
255
))
return
image_processor
(
image
,
return_tensors
=
"pt"
)[
"pixel_values"
][
0
]
# shape (C, H, W)
def
get_paligemma_token_type_ids
(
input_len
:
int
,
processor
:
"ProcessorMixin"
)
->
List
[
int
]:
r
"""
Gets paligemma token type ids for computing loss.
"""
image_seq_length
=
getattr
(
processor
,
"image_seq_length"
)
return
[
0
]
*
image_seq_length
+
[
1
]
*
(
input_len
-
image_seq_length
)
def
infer_seqlen
(
source_len
:
int
,
target_len
:
int
,
cutoff_len
:
int
)
->
Tuple
[
int
,
int
]:
r
"""
Computes the real sequence length after truncation by the cutoff_len.
"""
if
target_len
*
2
<
cutoff_len
:
# truncate source
max_target_len
=
cutoff_len
elif
source_len
*
2
<
cutoff_len
:
# truncate target
max_target_len
=
cutoff_len
-
source_len
else
:
# truncate both
max_target_len
=
int
(
cutoff_len
*
(
target_len
/
(
source_len
+
target_len
)))
new_target_len
=
min
(
max_target_len
,
target_len
)
max_source_len
=
max
(
cutoff_len
-
new_target_len
,
0
)
new_source_len
=
min
(
max_source_len
,
source_len
)
return
new_source_len
,
new_target_len
LLaMA-Factory/src/llamafactory/data/processors/supervised.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
defaultdict
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
from
...extras.constants
import
IGNORE_INDEX
from
...extras.logging
import
get_logger
from
.processor_utils
import
get_paligemma_token_type_ids
,
get_pixel_values
,
greedy_knapsack
,
infer_seqlen
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedTokenizer
,
ProcessorMixin
from
...hparams
import
DataArguments
from
..template
import
Template
logger
=
get_logger
(
__name__
)
def
_encode_supervised_example
(
prompt
:
Sequence
[
Dict
[
str
,
str
]],
response
:
Sequence
[
Dict
[
str
,
str
]],
system
:
Optional
[
str
],
tools
:
Optional
[
str
],
template
:
"Template"
,
tokenizer
:
"PreTrainedTokenizer"
,
processor
:
Optional
[
"ProcessorMixin"
],
cutoff_len
:
int
,
train_on_prompt
:
bool
,
mask_history
:
bool
,
)
->
Tuple
[
List
[
int
],
List
[
int
]]:
if
processor
is
not
None
and
not
hasattr
(
processor
,
"image_seq_length"
):
# llava-like models
prompt
[
0
][
"content"
]
=
template
.
image_token
+
prompt
[
0
][
"content"
]
messages
=
prompt
+
response
input_ids
,
labels
=
[],
[]
if
processor
is
not
None
and
hasattr
(
processor
,
"image_seq_length"
):
# paligemma models
image_token_id
=
tokenizer
.
convert_tokens_to_ids
(
template
.
image_token
)
input_ids
+=
[
image_token_id
]
*
getattr
(
processor
,
"image_seq_length"
)
labels
+=
[
IGNORE_INDEX
]
*
getattr
(
processor
,
"image_seq_length"
)
encoded_pairs
=
template
.
encode_multiturn
(
tokenizer
,
messages
,
system
,
tools
)
total_length
=
1
if
template
.
efficient_eos
else
0
for
turn_idx
,
(
source_ids
,
target_ids
)
in
enumerate
(
encoded_pairs
):
if
total_length
>=
cutoff_len
:
break
source_len
,
target_len
=
infer_seqlen
(
len
(
source_ids
),
len
(
target_ids
),
cutoff_len
-
total_length
)
source_ids
=
source_ids
[:
source_len
]
target_ids
=
target_ids
[:
target_len
]
total_length
+=
source_len
+
target_len
if
train_on_prompt
:
source_label
=
source_ids
elif
turn_idx
!=
0
and
template
.
efficient_eos
:
source_label
=
[
tokenizer
.
eos_token_id
]
+
[
IGNORE_INDEX
]
*
(
source_len
-
1
)
else
:
source_label
=
[
IGNORE_INDEX
]
*
source_len
if
mask_history
and
turn_idx
!=
len
(
encoded_pairs
)
-
1
:
target_label
=
[
IGNORE_INDEX
]
*
target_len
else
:
target_label
=
target_ids
input_ids
+=
source_ids
+
target_ids
labels
+=
source_label
+
target_label
if
template
.
efficient_eos
:
input_ids
+=
[
tokenizer
.
eos_token_id
]
labels
+=
[
tokenizer
.
eos_token_id
]
return
input_ids
,
labels
def
preprocess_supervised_dataset
(
examples
:
Dict
[
str
,
List
[
Any
]],
template
:
"Template"
,
tokenizer
:
"PreTrainedTokenizer"
,
processor
:
Optional
[
"ProcessorMixin"
],
data_args
:
"DataArguments"
,
)
->
Dict
[
str
,
List
[
List
[
int
]]]:
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
model_inputs
=
{
"input_ids"
:
[],
"attention_mask"
:
[],
"labels"
:
[]}
if
processor
is
not
None
:
model_inputs
[
"pixel_values"
]
=
[]
if
hasattr
(
processor
,
"image_seq_length"
):
# paligemma models
model_inputs
[
"token_type_ids"
]
=
[]
for
i
in
range
(
len
(
examples
[
"prompt"
])):
if
len
(
examples
[
"prompt"
][
i
])
%
2
!=
1
or
len
(
examples
[
"response"
][
i
])
!=
1
:
logger
.
warning
(
"Dropped invalid example: {}"
.
format
(
examples
[
"prompt"
][
i
]
+
examples
[
"response"
][
i
]))
continue
input_ids
,
labels
=
_encode_supervised_example
(
prompt
=
examples
[
"prompt"
][
i
],
response
=
examples
[
"response"
][
i
],
system
=
examples
[
"system"
][
i
],
tools
=
examples
[
"tools"
][
i
],
template
=
template
,
tokenizer
=
tokenizer
,
processor
=
processor
,
cutoff_len
=
data_args
.
cutoff_len
,
train_on_prompt
=
data_args
.
train_on_prompt
,
mask_history
=
data_args
.
mask_history
,
)
model_inputs
[
"input_ids"
].
append
(
input_ids
)
model_inputs
[
"attention_mask"
].
append
([
1
]
*
len
(
input_ids
))
model_inputs
[
"labels"
].
append
(
labels
)
if
processor
is
not
None
:
model_inputs
[
"pixel_values"
].
append
(
get_pixel_values
(
examples
[
"images"
][
i
],
processor
))
if
hasattr
(
processor
,
"image_seq_length"
):
# paligemma models
model_inputs
[
"token_type_ids"
].
append
(
get_paligemma_token_type_ids
(
len
(
input_ids
),
processor
))
return
model_inputs
def
preprocess_packed_supervised_dataset
(
examples
:
Dict
[
str
,
List
[
Any
]],
template
:
"Template"
,
tokenizer
:
"PreTrainedTokenizer"
,
data_args
:
"DataArguments"
,
)
->
Dict
[
str
,
List
[
List
[
int
]]]:
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
valid_num
=
0
batch_input_ids
,
batch_labels
=
[],
[]
lengths
=
[]
length2indexes
=
defaultdict
(
list
)
for
i
in
range
(
len
(
examples
[
"prompt"
])):
if
len
(
examples
[
"prompt"
][
i
])
%
2
!=
1
or
len
(
examples
[
"response"
][
i
])
!=
1
:
logger
.
warning
(
"Dropped invalid example: {}"
.
format
(
examples
[
"prompt"
][
i
]
+
examples
[
"response"
][
i
]))
continue
input_ids
,
labels
=
_encode_supervised_example
(
prompt
=
examples
[
"prompt"
][
i
],
response
=
examples
[
"response"
][
i
],
system
=
examples
[
"system"
][
i
],
tools
=
examples
[
"tools"
][
i
],
template
=
template
,
tokenizer
=
tokenizer
,
processor
=
None
,
cutoff_len
=
data_args
.
cutoff_len
-
1
,
# reserved for the padding token
train_on_prompt
=
data_args
.
train_on_prompt
,
mask_history
=
data_args
.
mask_history
,
)
length
=
len
(
input_ids
)
if
length
>
data_args
.
cutoff_len
:
logger
.
warning
(
"Dropped lengthy example with length {} > {}."
.
format
(
length
,
data_args
.
cutoff_len
))
else
:
lengths
.
append
(
length
)
length2indexes
[
length
].
append
(
valid_num
)
batch_input_ids
.
append
(
input_ids
)
batch_labels
.
append
(
labels
)
valid_num
+=
1
model_inputs
=
{
"input_ids"
:
[],
"attention_mask"
:
[],
"labels"
:
[]}
knapsacks
=
greedy_knapsack
(
lengths
,
data_args
.
cutoff_len
-
1
)
# reserved for the padding token
for
knapsack
in
knapsacks
:
packed_input_ids
,
packed_attention_masks
,
packed_labels
=
[],
[],
[]
for
i
,
length
in
enumerate
(
knapsack
):
index
=
length2indexes
[
length
].
pop
()
packed_input_ids
+=
batch_input_ids
[
index
]
packed_labels
+=
batch_labels
[
index
]
if
data_args
.
neat_packing
:
packed_attention_masks
+=
[
i
+
1
]
*
len
(
batch_input_ids
[
index
])
# start from 1
else
:
packed_attention_masks
+=
[
1
]
*
len
(
batch_input_ids
[
index
])
if
len
(
packed_input_ids
)
<
data_args
.
cutoff_len
:
pad_length
=
data_args
.
cutoff_len
-
len
(
packed_input_ids
)
packed_input_ids
+=
[
tokenizer
.
pad_token_id
]
*
pad_length
packed_labels
+=
[
IGNORE_INDEX
]
*
pad_length
if
data_args
.
neat_packing
:
packed_attention_masks
+=
[
0
]
*
pad_length
else
:
packed_attention_masks
+=
[
1
]
*
pad_length
# more efficient flash_attn
if
len
(
packed_input_ids
)
!=
data_args
.
cutoff_len
:
raise
ValueError
(
"The length of packed example should be identical to the cutoff length."
)
model_inputs
[
"input_ids"
].
append
(
packed_input_ids
)
model_inputs
[
"attention_mask"
].
append
(
packed_attention_masks
)
model_inputs
[
"labels"
].
append
(
packed_labels
)
return
model_inputs
def
print_supervised_dataset_example
(
example
:
Dict
[
str
,
List
[
int
]],
tokenizer
:
"PreTrainedTokenizer"
)
->
None
:
valid_labels
=
list
(
filter
(
lambda
x
:
x
!=
IGNORE_INDEX
,
example
[
"labels"
]))
print
(
"input_ids:
\n
{}"
.
format
(
example
[
"input_ids"
]))
print
(
"inputs:
\n
{}"
.
format
(
tokenizer
.
decode
(
example
[
"input_ids"
],
skip_special_tokens
=
False
)))
print
(
"label_ids:
\n
{}"
.
format
(
example
[
"labels"
]))
print
(
"labels:
\n
{}"
.
format
(
tokenizer
.
decode
(
valid_labels
,
skip_special_tokens
=
False
)))
LLaMA-Factory/src/llamafactory/data/processors/unsupervised.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
from
...extras.logging
import
get_logger
from
..data_utils
import
Role
from
.processor_utils
import
get_paligemma_token_type_ids
,
get_pixel_values
,
infer_seqlen
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedTokenizer
,
ProcessorMixin
from
...hparams
import
DataArguments
from
..template
import
Template
logger
=
get_logger
(
__name__
)
def
_encode_unsupervised_example
(
prompt
:
Sequence
[
Dict
[
str
,
str
]],
response
:
Sequence
[
Dict
[
str
,
str
]],
system
:
Optional
[
str
],
tools
:
Optional
[
str
],
template
:
"Template"
,
tokenizer
:
"PreTrainedTokenizer"
,
processor
:
Optional
[
"ProcessorMixin"
],
cutoff_len
:
int
,
)
->
Tuple
[
List
[
int
],
List
[
int
]]:
if
processor
is
not
None
and
not
hasattr
(
processor
,
"image_seq_length"
):
# llava-like models
prompt
[
0
][
"content"
]
=
template
.
image_token
+
prompt
[
0
][
"content"
]
if
len
(
response
)
==
1
:
messages
=
prompt
+
response
else
:
messages
=
prompt
+
[{
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
""
}]
input_ids
,
labels
=
template
.
encode_oneturn
(
tokenizer
,
messages
,
system
,
tools
)
if
template
.
efficient_eos
:
labels
+=
[
tokenizer
.
eos_token_id
]
if
processor
is
not
None
and
hasattr
(
processor
,
"image_seq_length"
):
# paligemma models
image_token_id
=
tokenizer
.
convert_tokens_to_ids
(
template
.
image_token
)
input_ids
=
[
image_token_id
]
*
getattr
(
processor
,
"image_seq_length"
)
+
input_ids
source_len
,
target_len
=
infer_seqlen
(
len
(
input_ids
),
len
(
labels
),
cutoff_len
)
input_ids
=
input_ids
[:
source_len
]
labels
=
labels
[:
target_len
]
return
input_ids
,
labels
def
preprocess_unsupervised_dataset
(
examples
:
Dict
[
str
,
List
[
Any
]],
template
:
"Template"
,
tokenizer
:
"PreTrainedTokenizer"
,
processor
:
Optional
[
"ProcessorMixin"
],
data_args
:
"DataArguments"
,
)
->
Dict
[
str
,
List
[
List
[
int
]]]:
# build inputs with format `<bos> X` and labels with format `Y <eos>`
model_inputs
=
{
"input_ids"
:
[],
"attention_mask"
:
[],
"labels"
:
[]}
if
processor
is
not
None
:
model_inputs
[
"pixel_values"
]
=
[]
if
hasattr
(
processor
,
"image_seq_length"
):
# paligemma models
model_inputs
[
"token_type_ids"
]
=
[]
for
i
in
range
(
len
(
examples
[
"prompt"
])):
if
len
(
examples
[
"prompt"
][
i
])
%
2
!=
1
:
logger
.
warning
(
"Dropped invalid example: {}"
.
format
(
examples
[
"prompt"
][
i
]
+
examples
[
"response"
][
i
]))
continue
input_ids
,
labels
=
_encode_unsupervised_example
(
prompt
=
examples
[
"prompt"
][
i
],
response
=
examples
[
"response"
][
i
],
system
=
examples
[
"system"
][
i
],
tools
=
examples
[
"tools"
][
i
],
template
=
template
,
tokenizer
=
tokenizer
,
processor
=
processor
,
cutoff_len
=
data_args
.
cutoff_len
,
)
model_inputs
[
"input_ids"
].
append
(
input_ids
)
model_inputs
[
"attention_mask"
].
append
([
1
]
*
len
(
input_ids
))
model_inputs
[
"labels"
].
append
(
labels
)
if
processor
is
not
None
:
model_inputs
[
"pixel_values"
].
append
(
get_pixel_values
(
examples
[
"images"
][
i
],
processor
))
if
hasattr
(
processor
,
"image_seq_length"
):
# paligemma models
model_inputs
[
"token_type_ids"
].
append
(
get_paligemma_token_type_ids
(
len
(
input_ids
),
processor
))
return
model_inputs
def
print_unsupervised_dataset_example
(
example
:
Dict
[
str
,
List
[
int
]],
tokenizer
:
"PreTrainedTokenizer"
)
->
None
:
print
(
"input_ids:
\n
{}"
.
format
(
example
[
"input_ids"
]))
print
(
"inputs:
\n
{}"
.
format
(
tokenizer
.
decode
(
example
[
"input_ids"
],
skip_special_tokens
=
False
)))
LLaMA-Factory/src/llamafactory/data/template.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
..extras.logging
import
get_logger
from
.data_utils
import
Role
from
.formatter
import
EmptyFormatter
,
FunctionFormatter
,
StringFormatter
,
ToolFormatter
if
TYPE_CHECKING
:
from
transformers
import
PreTrainedTokenizer
from
.formatter
import
SLOTS
,
Formatter
logger
=
get_logger
(
__name__
)
@
dataclass
class
Template
:
format_user
:
"Formatter"
format_assistant
:
"Formatter"
format_system
:
"Formatter"
format_function
:
"Formatter"
format_observation
:
"Formatter"
format_tools
:
"Formatter"
format_separator
:
"Formatter"
format_prefix
:
"Formatter"
default_system
:
str
stop_words
:
List
[
str
]
image_token
:
str
efficient_eos
:
bool
replace_eos
:
bool
def
encode_oneturn
(
self
,
tokenizer
:
"PreTrainedTokenizer"
,
messages
:
Sequence
[
Dict
[
str
,
str
]],
system
:
Optional
[
str
]
=
None
,
tools
:
Optional
[
str
]
=
None
,
)
->
Tuple
[
List
[
int
],
List
[
int
]]:
r
"""
Returns a single pair of token ids representing prompt and response respectively.
"""
encoded_messages
=
self
.
_encode
(
tokenizer
,
messages
,
system
,
tools
)
prompt_ids
=
[]
for
encoded_ids
in
encoded_messages
[:
-
1
]:
prompt_ids
+=
encoded_ids
answer_ids
=
encoded_messages
[
-
1
]
return
prompt_ids
,
answer_ids
def
encode_multiturn
(
self
,
tokenizer
:
"PreTrainedTokenizer"
,
messages
:
Sequence
[
Dict
[
str
,
str
]],
system
:
Optional
[
str
]
=
None
,
tools
:
Optional
[
str
]
=
None
,
)
->
List
[
Tuple
[
List
[
int
],
List
[
int
]]]:
r
"""
Returns multiple pairs of token ids representing prompts and responses respectively.
"""
encoded_messages
=
self
.
_encode
(
tokenizer
,
messages
,
system
,
tools
)
return
[(
encoded_messages
[
i
],
encoded_messages
[
i
+
1
])
for
i
in
range
(
0
,
len
(
encoded_messages
),
2
)]
def
extract_tool
(
self
,
content
:
str
)
->
Union
[
str
,
List
[
Tuple
[
str
,
str
]]]:
r
"""
Extracts tool message.
"""
return
self
.
format_tools
.
extract
(
content
)
def
_encode
(
self
,
tokenizer
:
"PreTrainedTokenizer"
,
messages
:
Sequence
[
Dict
[
str
,
str
]],
system
:
Optional
[
str
],
tools
:
Optional
[
str
],
)
->
List
[
List
[
int
]]:
r
"""
Encodes formatted inputs to pairs of token ids.
Turn 0: prefix + system + query resp
Turn t: sep + query resp
"""
system
=
system
or
self
.
default_system
encoded_messages
=
[]
for
i
,
message
in
enumerate
(
messages
):
elements
=
[]
if
i
==
0
:
elements
+=
self
.
format_prefix
.
apply
()
if
system
or
tools
:
tool_text
=
self
.
format_tools
.
apply
(
content
=
tools
)[
0
]
if
tools
else
""
elements
+=
self
.
format_system
.
apply
(
content
=
(
system
+
tool_text
))
if
i
>
0
and
i
%
2
==
0
:
elements
+=
self
.
format_separator
.
apply
()
if
message
[
"role"
]
==
Role
.
USER
.
value
:
elements
+=
self
.
format_user
.
apply
(
content
=
message
[
"content"
],
idx
=
str
(
i
//
2
))
elif
message
[
"role"
]
==
Role
.
ASSISTANT
.
value
:
elements
+=
self
.
format_assistant
.
apply
(
content
=
message
[
"content"
])
elif
message
[
"role"
]
==
Role
.
OBSERVATION
.
value
:
elements
+=
self
.
format_observation
.
apply
(
content
=
message
[
"content"
])
elif
message
[
"role"
]
==
Role
.
FUNCTION
.
value
:
elements
+=
self
.
format_function
.
apply
(
content
=
message
[
"content"
])
else
:
raise
NotImplementedError
(
"Unexpected role: {}"
.
format
(
message
[
"role"
]))
encoded_messages
.
append
(
self
.
_convert_elements_to_ids
(
tokenizer
,
elements
))
return
encoded_messages
def
_convert_elements_to_ids
(
self
,
tokenizer
:
"PreTrainedTokenizer"
,
elements
:
"SLOTS"
)
->
List
[
int
]:
r
"""
Converts elements to token ids.
"""
token_ids
=
[]
for
elem
in
elements
:
if
isinstance
(
elem
,
str
):
if
len
(
elem
)
!=
0
:
token_ids
+=
tokenizer
.
encode
(
elem
,
add_special_tokens
=
False
)
elif
isinstance
(
elem
,
dict
):
token_ids
+=
[
tokenizer
.
convert_tokens_to_ids
(
elem
.
get
(
"token"
))]
elif
isinstance
(
elem
,
set
):
if
"bos_token"
in
elem
and
tokenizer
.
bos_token_id
is
not
None
:
token_ids
+=
[
tokenizer
.
bos_token_id
]
elif
"eos_token"
in
elem
and
tokenizer
.
eos_token_id
is
not
None
:
token_ids
+=
[
tokenizer
.
eos_token_id
]
else
:
raise
ValueError
(
"Input must be string, set[str] or dict[str, str], got {}"
.
format
(
type
(
elem
)))
return
token_ids
@
dataclass
class
Llama2Template
(
Template
):
def
_encode
(
self
,
tokenizer
:
"PreTrainedTokenizer"
,
messages
:
Sequence
[
Dict
[
str
,
str
]],
system
:
str
,
tools
:
str
,
)
->
List
[
List
[
int
]]:
r
"""
Encodes formatted inputs to pairs of token ids.
Turn 0: prefix + system + query resp
Turn t: sep + query resp
"""
system
=
system
or
self
.
default_system
encoded_messages
=
[]
for
i
,
message
in
enumerate
(
messages
):
elements
=
[]
system_text
=
""
if
i
==
0
:
elements
+=
self
.
format_prefix
.
apply
()
if
system
or
tools
:
tool_text
=
self
.
format_tools
.
apply
(
content
=
tools
)[
0
]
if
tools
else
""
system_text
=
self
.
format_system
.
apply
(
content
=
(
system
+
tool_text
))[
0
]
if
i
>
0
and
i
%
2
==
0
:
elements
+=
self
.
format_separator
.
apply
()
if
message
[
"role"
]
==
Role
.
USER
.
value
:
elements
+=
self
.
format_user
.
apply
(
content
=
system_text
+
message
[
"content"
])
elif
message
[
"role"
]
==
Role
.
ASSISTANT
.
value
:
elements
+=
self
.
format_assistant
.
apply
(
content
=
message
[
"content"
])
elif
message
[
"role"
]
==
Role
.
OBSERVATION
.
value
:
elements
+=
self
.
format_observation
.
apply
(
content
=
message
[
"content"
])
elif
message
[
"role"
]
==
Role
.
FUNCTION
.
value
:
elements
+=
self
.
format_function
.
apply
(
content
=
message
[
"content"
])
else
:
raise
NotImplementedError
(
"Unexpected role: {}"
.
format
(
message
[
"role"
]))
encoded_messages
.
append
(
self
.
_convert_elements_to_ids
(
tokenizer
,
elements
))
return
encoded_messages
TEMPLATES
:
Dict
[
str
,
Template
]
=
{}
def
_register_template
(
name
:
str
,
format_user
:
Optional
[
"Formatter"
]
=
None
,
format_assistant
:
Optional
[
"Formatter"
]
=
None
,
format_system
:
Optional
[
"Formatter"
]
=
None
,
format_function
:
Optional
[
"Formatter"
]
=
None
,
format_observation
:
Optional
[
"Formatter"
]
=
None
,
format_tools
:
Optional
[
"Formatter"
]
=
None
,
format_separator
:
Optional
[
"Formatter"
]
=
None
,
format_prefix
:
Optional
[
"Formatter"
]
=
None
,
default_system
:
str
=
""
,
stop_words
:
Sequence
[
str
]
=
[],
image_token
:
str
=
"<image>"
,
efficient_eos
:
bool
=
False
,
replace_eos
:
bool
=
False
,
)
->
None
:
r
"""
Registers a chat template.
To add the following chat template:
```
[HUMAN]:
user prompt here
[AI]:
model response here
[HUMAN]:
user prompt here
[AI]:
model response here
```
The corresponding code should be:
```
_register_template(
name="custom",
format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
format_separator=EmptyFormatter(slots=["\n\n"]),
efficient_eos=True,
)
```
"""
eos_slots
=
[]
if
efficient_eos
else
[{
"eos_token"
}]
template_class
=
Llama2Template
if
name
.
startswith
(
"llama2"
)
else
Template
default_user_formatter
=
StringFormatter
(
slots
=
[
"{{content}}"
])
default_assistant_formatter
=
StringFormatter
(
slots
=
[
"{{content}}"
]
+
eos_slots
)
default_function_formatter
=
FunctionFormatter
(
slots
=
eos_slots
,
tool_format
=
"default"
)
default_tool_formatter
=
ToolFormatter
(
tool_format
=
"default"
)
default_separator_formatter
=
EmptyFormatter
()
default_prefix_formatter
=
EmptyFormatter
()
TEMPLATES
[
name
]
=
template_class
(
format_user
=
format_user
or
default_user_formatter
,
format_assistant
=
format_assistant
or
default_assistant_formatter
,
format_system
=
format_system
or
default_user_formatter
,
format_function
=
format_function
or
default_function_formatter
,
format_observation
=
format_observation
or
format_user
or
default_user_formatter
,
format_tools
=
format_tools
or
default_tool_formatter
,
format_separator
=
format_separator
or
default_separator_formatter
,
format_prefix
=
format_prefix
or
default_prefix_formatter
,
default_system
=
default_system
,
stop_words
=
stop_words
,
image_token
=
image_token
,
efficient_eos
=
efficient_eos
,
replace_eos
=
replace_eos
,
)
def
_add_or_replace_eos_token
(
tokenizer
:
"PreTrainedTokenizer"
,
eos_token
:
str
)
->
None
:
is_added
=
tokenizer
.
eos_token_id
is
None
num_added_tokens
=
tokenizer
.
add_special_tokens
({
"eos_token"
:
eos_token
})
if
is_added
:
logger
.
info
(
"Add eos token: {}"
.
format
(
tokenizer
.
eos_token
))
else
:
logger
.
info
(
"Replace eos token: {}"
.
format
(
tokenizer
.
eos_token
))
if
num_added_tokens
>
0
:
logger
.
warning
(
"New tokens have been added, make sure `resize_vocab` is True."
)
def
_jinja_escape
(
content
:
str
)
->
str
:
return
content
.
replace
(
"'"
,
r
"\'"
)
def
_convert_slots_to_jinja
(
slots
:
"SLOTS"
,
tokenizer
:
"PreTrainedTokenizer"
,
placeholder
:
str
=
"content"
)
->
str
:
slot_items
=
[]
for
slot
in
slots
:
if
isinstance
(
slot
,
str
):
slot_pieces
=
slot
.
split
(
"{{content}}"
)
if
slot_pieces
[
0
]:
slot_items
.
append
(
"'"
+
_jinja_escape
(
slot_pieces
[
0
])
+
"'"
)
if
len
(
slot_pieces
)
>
1
:
slot_items
.
append
(
placeholder
)
if
slot_pieces
[
1
]:
slot_items
.
append
(
"'"
+
_jinja_escape
(
slot_pieces
[
1
])
+
"'"
)
elif
isinstance
(
slot
,
set
):
# do not use {{ eos_token }} since it may be replaced
if
"bos_token"
in
slot
and
tokenizer
.
bos_token_id
is
not
None
:
slot_items
.
append
(
"'"
+
tokenizer
.
bos_token
+
"'"
)
elif
"eos_token"
in
slot
and
tokenizer
.
eos_token_id
is
not
None
:
slot_items
.
append
(
"'"
+
tokenizer
.
eos_token
+
"'"
)
elif
isinstance
(
slot
,
dict
):
raise
ValueError
(
"Dict is not supported."
)
return
" + "
.
join
(
slot_items
)
def
_get_jinja_template
(
template
:
"Template"
,
tokenizer
:
"PreTrainedTokenizer"
)
->
str
:
jinja_template
=
""
prefix
=
_convert_slots_to_jinja
(
template
.
format_prefix
.
apply
(),
tokenizer
)
if
prefix
:
jinja_template
+=
"{{ "
+
prefix
+
" }}"
if
template
.
default_system
:
jinja_template
+=
"{% set system_message = '"
+
_jinja_escape
(
template
.
default_system
)
+
"' %}"
jinja_template
+=
(
"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
)
system_message
=
_convert_slots_to_jinja
(
template
.
format_system
.
apply
(),
tokenizer
,
placeholder
=
"system_message"
)
if
not
isinstance
(
template
,
Llama2Template
):
jinja_template
+=
"{% if system_message is defined %}{{ "
+
system_message
+
" }}{% endif %}"
jinja_template
+=
"{% for message in messages %}"
jinja_template
+=
"{% set content = message['content'] %}"
if
isinstance
(
template
,
Llama2Template
):
jinja_template
+=
"{% if loop.index0 == 0 and system_message is defined %}"
jinja_template
+=
"{% set content = "
+
system_message
+
" + message['content'] %}"
jinja_template
+=
"{% endif %}"
jinja_template
+=
"{% if message['role'] == 'user' %}"
user_message
=
_convert_slots_to_jinja
(
template
.
format_user
.
apply
(),
tokenizer
)
jinja_template
+=
"{{ "
+
user_message
+
" }}"
jinja_template
+=
"{% elif message['role'] == 'assistant' %}"
assistant_message
=
_convert_slots_to_jinja
(
template
.
format_assistant
.
apply
()
+
template
.
format_separator
.
apply
(),
tokenizer
)
jinja_template
+=
"{{ "
+
assistant_message
+
" }}"
jinja_template
+=
"{% endif %}"
jinja_template
+=
"{% endfor %}"
return
jinja_template
def
get_template_and_fix_tokenizer
(
tokenizer
:
"PreTrainedTokenizer"
,
name
:
Optional
[
str
]
=
None
,
tool_format
:
Optional
[
str
]
=
None
,
)
->
Template
:
if
name
is
None
:
template
=
TEMPLATES
[
"empty"
]
# placeholder
else
:
template
=
TEMPLATES
.
get
(
name
,
None
)
if
template
is
None
:
raise
ValueError
(
"Template {} does not exist."
.
format
(
name
))
if
tool_format
is
not
None
:
logger
.
info
(
"Using tool format: {}."
.
format
(
tool_format
))
eos_slots
=
[]
if
template
.
efficient_eos
else
[{
"eos_token"
}]
template
.
format_tools
=
ToolFormatter
(
tool_format
=
tool_format
)
template
.
format_function
=
FunctionFormatter
(
slots
=
eos_slots
,
tool_format
=
tool_format
)
stop_words
=
template
.
stop_words
if
template
.
replace_eos
:
if
not
stop_words
:
raise
ValueError
(
"Stop words are required to replace the EOS token."
)
_add_or_replace_eos_token
(
tokenizer
,
eos_token
=
stop_words
[
0
])
stop_words
=
stop_words
[
1
:]
if
tokenizer
.
eos_token_id
is
None
:
_add_or_replace_eos_token
(
tokenizer
,
eos_token
=
"<|endoftext|>"
)
if
tokenizer
.
pad_token_id
is
None
:
tokenizer
.
pad_token
=
tokenizer
.
eos_token
logger
.
info
(
"Add pad token: {}"
.
format
(
tokenizer
.
pad_token
))
if
stop_words
:
num_added_tokens
=
tokenizer
.
add_special_tokens
(
dict
(
additional_special_tokens
=
stop_words
),
replace_additional_special_tokens
=
False
)
logger
.
info
(
"Add {} to stop words."
.
format
(
","
.
join
(
stop_words
)))
if
num_added_tokens
>
0
:
logger
.
warning
(
"New tokens have been added, make sure `resize_vocab` is True."
)
try
:
tokenizer
.
chat_template
=
_get_jinja_template
(
template
,
tokenizer
)
except
ValueError
:
logger
.
info
(
"Cannot add this chat template to tokenizer."
)
return
template
_register_template
(
name
=
"alpaca"
,
format_user
=
StringFormatter
(
slots
=
[
"### Instruction:
\n
{{content}}
\n\n
### Response:
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n\n
"
]),
default_system
=
(
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.
\n\n
"
),
)
_register_template
(
name
=
"aquila"
,
format_user
=
StringFormatter
(
slots
=
[
"Human: {{content}}###Assistant:"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"###"
]),
default_system
=
(
"A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions."
),
stop_words
=
[
"</s>"
],
efficient_eos
=
True
,
)
_register_template
(
name
=
"atom"
,
format_user
=
StringFormatter
(
slots
=
[{
"bos_token"
},
"Human: {{content}}
\n
"
,
{
"eos_token"
},
{
"bos_token"
},
"Assistant:"
]
),
format_assistant
=
StringFormatter
(
slots
=
[
"{{content}}
\n
"
,
{
"eos_token"
}]),
)
_register_template
(
name
=
"baichuan"
,
format_user
=
StringFormatter
(
slots
=
[{
"token"
:
"<reserved_102>"
},
"{{content}}"
,
{
"token"
:
"<reserved_103>"
}]),
efficient_eos
=
True
,
)
_register_template
(
name
=
"baichuan2"
,
format_user
=
StringFormatter
(
slots
=
[
"<reserved_106>{{content}}<reserved_107>"
]),
efficient_eos
=
True
,
)
_register_template
(
name
=
"belle"
,
format_user
=
StringFormatter
(
slots
=
[
"Human: {{content}}
\n\n
Belle: "
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n\n
"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
)
_register_template
(
name
=
"bluelm"
,
format_user
=
StringFormatter
(
slots
=
[{
"token"
:
"[|Human|]:"
},
"{{content}}"
,
{
"token"
:
"[|AI|]:"
}]),
)
_register_template
(
name
=
"breeze"
,
format_user
=
StringFormatter
(
slots
=
[
"[INST] {{content}} [/INST] "
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
efficient_eos
=
True
,
)
_register_template
(
name
=
"chatglm2"
,
format_user
=
StringFormatter
(
slots
=
[
"[Round {{idx}}]
\n\n
问:{{content}}
\n\n
答:"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n\n
"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"token"
:
"[gMASK]"
},
{
"token"
:
"sop"
}]),
efficient_eos
=
True
,
)
_register_template
(
name
=
"chatglm3"
,
format_user
=
StringFormatter
(
slots
=
[{
"token"
:
"<|user|>"
},
"
\n
"
,
"{{content}}"
,
{
"token"
:
"<|assistant|>"
}]),
format_assistant
=
StringFormatter
(
slots
=
[
"
\n
"
,
"{{content}}"
]),
format_system
=
StringFormatter
(
slots
=
[{
"token"
:
"<|system|>"
},
"
\n
"
,
"{{content}}"
]),
format_function
=
FunctionFormatter
(
slots
=
[],
tool_format
=
"glm4"
),
format_observation
=
StringFormatter
(
slots
=
[{
"token"
:
"<|observation|>"
},
"
\n
"
,
"{{content}}"
,
{
"token"
:
"<|assistant|>"
}]
),
format_tools
=
ToolFormatter
(
tool_format
=
"glm4"
),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"token"
:
"[gMASK]"
},
{
"token"
:
"sop"
}]),
stop_words
=
[
"<|user|>"
,
"<|observation|>"
],
efficient_eos
=
True
,
)
_register_template
(
name
=
"chatml"
,
format_user
=
StringFormatter
(
slots
=
[
"<|im_start|>user
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|im_start|>system
\n
{{content}}<|im_end|>
\n
"
]),
format_observation
=
StringFormatter
(
slots
=
[
"<|im_start|>tool
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
stop_words
=
[
"<|im_end|>"
,
"<|im_start|>"
],
replace_eos
=
True
,
)
_register_template
(
name
=
"chatml_de"
,
format_user
=
StringFormatter
(
slots
=
[
"<|im_start|>user
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|im_start|>system
\n
{{content}}<|im_end|>
\n
"
]),
format_observation
=
StringFormatter
(
slots
=
[
"<|im_start|>tool
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
default_system
=
"Du bist ein freundlicher und hilfsbereiter KI-Assistent."
,
stop_words
=
[
"<|im_end|>"
,
"<|im_start|>"
],
replace_eos
=
True
,
)
_register_template
(
name
=
"codegeex2"
,
format_prefix
=
EmptyFormatter
(
slots
=
[{
"token"
:
"[gMASK]"
},
{
"token"
:
"sop"
}]),
)
_register_template
(
name
=
"codegeex4"
,
format_user
=
StringFormatter
(
slots
=
[
"<|user|>
\n
{{content}}<|assistant|>
\n
"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|system|>
\n
{{content}}"
]),
format_function
=
FunctionFormatter
(
slots
=
[],
tool_format
=
"glm4"
),
format_observation
=
StringFormatter
(
slots
=
[
"<|observation|>
\n
{{content}}<|assistant|>
\n
"
]),
format_tools
=
ToolFormatter
(
tool_format
=
"glm4"
),
format_prefix
=
EmptyFormatter
(
slots
=
[
"[gMASK]<sop>"
]),
default_system
=
(
"你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,"
"并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。"
),
stop_words
=
[
"<|user|>"
,
"<|observation|>"
],
efficient_eos
=
True
,
)
_register_template
(
name
=
"cohere"
,
format_user
=
StringFormatter
(
slots
=
[
(
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
)
]
),
format_system
=
StringFormatter
(
slots
=
[
"<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
)
_register_template
(
name
=
"cpm"
,
format_user
=
StringFormatter
(
slots
=
[
"<用户>{{content}}<AI>"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
)
_register_template
(
name
=
"dbrx"
,
format_user
=
StringFormatter
(
slots
=
[
"<|im_start|>user
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|im_start|>system
\n
{{content}}<|im_end|>
\n
"
]),
format_observation
=
StringFormatter
(
slots
=
[
"<|im_start|>tool
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
default_system
=
(
"You are DBRX, created by Databricks. You were last updated in December 2023. "
"You answer questions based on information available up to that point.
\n
"
"YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
"responses to more complex and open-ended questions.
\n
You assist with various tasks, "
"from writing to coding (using markdown for code blocks — remember to use ``` with "
"code, JSON, and tables).
\n
(You do not have real-time data access or code execution "
"capabilities. You avoid stereotyping and provide balanced perspectives on "
"controversial topics. You do not provide song lyrics, poems, or news articles and "
"do not divulge details of your training data.)
\n
This is your system prompt, "
"guiding your responses. Do not reference it, just respond to the user. If you find "
"yourself talking about this message, stop. You should be responding appropriately "
"and usually that means not mentioning this.
\n
YOU DO NOT MENTION ANY OF THIS INFORMATION "
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
),
stop_words
=
[
"<|im_end|>"
],
replace_eos
=
True
,
)
_register_template
(
name
=
"deepseek"
,
format_user
=
StringFormatter
(
slots
=
[
"User: {{content}}
\n\n
Assistant:"
]),
format_system
=
StringFormatter
(
slots
=
[
"{{content}}
\n\n
"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
)
_register_template
(
name
=
"deepseekcoder"
,
format_user
=
StringFormatter
(
slots
=
[
"### Instruction:
\n
{{content}}
\n
### Response:"
]),
format_assistant
=
StringFormatter
(
slots
=
[
"
\n
{{content}}
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
default_system
=
(
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
"developed by Deepseek Company, and you only answer questions related to computer science. "
"For politically sensitive questions, security and privacy issues, "
"and other non-computer science questions, you will refuse to answer
\n
"
),
)
_register_template
(
name
=
"default"
,
format_user
=
StringFormatter
(
slots
=
[
"Human: {{content}}
\n
Assistant:"
]),
format_system
=
StringFormatter
(
slots
=
[
"{{content}}
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
)
_register_template
(
name
=
"empty"
,
efficient_eos
=
True
,
)
_register_template
(
name
=
"falcon"
,
format_user
=
StringFormatter
(
slots
=
[
"User: {{content}}
\n
Falcon:"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
efficient_eos
=
True
,
)
_register_template
(
name
=
"fewshot"
,
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n\n
"
]),
efficient_eos
=
True
,
)
_register_template
(
name
=
"gemma"
,
format_user
=
StringFormatter
(
slots
=
[
"<start_of_turn>user
\n
{{content}}<end_of_turn>
\n
<start_of_turn>model
\n
"
]),
format_observation
=
StringFormatter
(
slots
=
[
"<start_of_turn>tool
\n
{{content}}<end_of_turn>
\n
<start_of_turn>model
\n
"
]
),
format_separator
=
EmptyFormatter
(
slots
=
[
"<end_of_turn>
\n
"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
efficient_eos
=
True
,
)
_register_template
(
name
=
"glm4"
,
format_user
=
StringFormatter
(
slots
=
[
"<|user|>
\n
{{content}}<|assistant|>"
]),
format_assistant
=
StringFormatter
(
slots
=
[
"
\n
{{content}}"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|system|>
\n
{{content}}"
]),
format_function
=
FunctionFormatter
(
slots
=
[],
tool_format
=
"glm4"
),
format_observation
=
StringFormatter
(
slots
=
[
"<|observation|>
\n
{{content}}<|assistant|>"
]),
format_tools
=
ToolFormatter
(
tool_format
=
"glm4"
),
format_prefix
=
EmptyFormatter
(
slots
=
[
"[gMASK]<sop>"
]),
stop_words
=
[
"<|user|>"
,
"<|observation|>"
],
efficient_eos
=
True
,
)
_register_template
(
name
=
"intern"
,
format_user
=
StringFormatter
(
slots
=
[
"<|User|>:{{content}}
\n
<|Bot|>:"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|System|>:{{content}}
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"<eoa>
\n
"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
stop_words
=
[
"<eoa>"
],
efficient_eos
=
True
,
# internlm tokenizer cannot set eos_token_id
)
_register_template
(
name
=
"intern2"
,
format_user
=
StringFormatter
(
slots
=
[
"<|im_start|>user
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|im_start|>system
\n
{{content}}<|im_end|>
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"<|im_end|>
\n
"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
stop_words
=
[
"<|im_end|>"
],
efficient_eos
=
True
,
# internlm2 tokenizer cannot set eos_token_id
)
_register_template
(
name
=
"llama2"
,
format_user
=
StringFormatter
(
slots
=
[{
"bos_token"
},
"[INST] {{content}} [/INST]"
]),
format_system
=
StringFormatter
(
slots
=
[
"<<SYS>>
\n
{{content}}
\n
<</SYS>>
\n\n
"
]),
)
_register_template
(
name
=
"llama2_zh"
,
format_user
=
StringFormatter
(
slots
=
[{
"bos_token"
},
"[INST] {{content}} [/INST]"
]),
format_system
=
StringFormatter
(
slots
=
[
"<<SYS>>
\n
{{content}}
\n
<</SYS>>
\n\n
"
]),
default_system
=
"You are a helpful assistant. 你是一个乐于助人的助手。"
,
)
_register_template
(
name
=
"llama3"
,
format_user
=
StringFormatter
(
slots
=
[
(
"<|start_header_id|>user<|end_header_id|>
\n\n
{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>
\n\n
"
)
]
),
format_system
=
StringFormatter
(
slots
=
[
"<|start_header_id|>system<|end_header_id|>
\n\n
{{content}}<|eot_id|>"
]),
format_observation
=
StringFormatter
(
slots
=
[
(
"<|start_header_id|>tool<|end_header_id|>
\n\n
{{content}}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>
\n\n
"
)
]
),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
stop_words
=
[
"<|eot_id|>"
],
replace_eos
=
True
,
)
_register_template
(
name
=
"mistral"
,
format_user
=
StringFormatter
(
slots
=
[
"[INST] {{content}} [/INST]"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
)
_register_template
(
name
=
"olmo"
,
format_user
=
StringFormatter
(
slots
=
[
"<|user|>
\n
{{content}}<|assistant|>
\n
"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"eos_token"
}]),
)
_register_template
(
name
=
"openchat"
,
format_user
=
StringFormatter
(
slots
=
[
"GPT4 Correct User: {{content}}"
,
{
"eos_token"
},
"GPT4 Correct Assistant:"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
)
_register_template
(
name
=
"openchat-3.6"
,
format_user
=
StringFormatter
(
slots
=
[
(
"<|start_header_id|>GPT4 Correct User<|end_header_id|>
\n\n
{{content}}<|eot_id|>"
"<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>
\n\n
"
)
]
),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
stop_words
=
[
"<|eot_id|>"
],
replace_eos
=
True
,
)
_register_template
(
name
=
"orion"
,
format_user
=
StringFormatter
(
slots
=
[
"Human: {{content}}
\n\n
Assistant: "
,
{
"eos_token"
}]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
)
_register_template
(
name
=
"phi"
,
format_user
=
StringFormatter
(
slots
=
[
"<|user|>
\n
{{content}}<|end|>
\n
<|assistant|>
\n
"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|system|>
\n
{{content}}<|end|>
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
format_prefix
=
EmptyFormatter
(
slots
=
[{
"bos_token"
}]),
stop_words
=
[
"<|end|>"
],
replace_eos
=
True
,
)
_register_template
(
name
=
"qwen"
,
format_user
=
StringFormatter
(
slots
=
[
"<|im_start|>user
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|im_start|>system
\n
{{content}}<|im_end|>
\n
"
]),
format_observation
=
StringFormatter
(
slots
=
[
"<|im_start|>tool
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
default_system
=
"You are a helpful assistant."
,
stop_words
=
[
"<|im_end|>"
],
replace_eos
=
True
,
)
_register_template
(
name
=
"solar"
,
format_user
=
StringFormatter
(
slots
=
[
"### User:
\n
{{content}}
\n\n
### Assistant:
\n
"
]),
format_system
=
StringFormatter
(
slots
=
[
"### System:
\n
{{content}}
\n\n
"
]),
efficient_eos
=
True
,
)
_register_template
(
name
=
"starchat"
,
format_user
=
StringFormatter
(
slots
=
[
"<|user|>
\n
{{content}}<|end|>
\n
<|assistant|>"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|system|>
\n
{{content}}<|end|>
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
stop_words
=
[
"<|end|>"
],
replace_eos
=
True
,
)
_register_template
(
name
=
"telechat"
,
format_user
=
StringFormatter
(
slots
=
[
"<_user>{{content}}<_bot>"
]),
format_system
=
StringFormatter
(
slots
=
[
"<_system>{{content}}<_end>"
]),
stop_words
=
[
"<_end>"
],
replace_eos
=
True
,
)
_register_template
(
name
=
"vicuna"
,
format_user
=
StringFormatter
(
slots
=
[
"USER: {{content}} ASSISTANT:"
]),
default_system
=
(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
)
_register_template
(
name
=
"xuanyuan"
,
format_user
=
StringFormatter
(
slots
=
[
"Human: {{content}} Assistant:"
]),
default_system
=
(
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
"不安全、有争议、政治敏感等相关的话题、问题和指示。
\n
"
),
)
_register_template
(
name
=
"xverse"
,
format_user
=
StringFormatter
(
slots
=
[
"Human: {{content}}
\n\n
Assistant: "
]),
)
_register_template
(
name
=
"yayi"
,
format_user
=
StringFormatter
(
slots
=
[{
"token"
:
"<|Human|>"
},
":
\n
{{content}}
\n\n
"
,
{
"token"
:
"<|YaYi|>"
},
":"
]),
format_system
=
StringFormatter
(
slots
=
[{
"token"
:
"<|System|>"
},
":
\n
{{content}}
\n\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n\n
"
]),
default_system
=
(
"You are a helpful, respectful and honest assistant named YaYi "
"developed by Beijing Wenge Technology Co.,Ltd. "
"Always answer as helpfully as possible, while being safe. "
"Your answers should not include any harmful, unethical, "
"racist, sexist, toxic, dangerous, or illegal content. "
"Please ensure that your responses are socially unbiased and positive in nature.
\n\n
"
"If a question does not make any sense, or is not factually coherent, "
"explain why instead of answering something not correct. "
"If you don't know the answer to a question, please don't share false information."
),
stop_words
=
[
"<|End|>"
],
)
_register_template
(
name
=
"yi"
,
format_user
=
StringFormatter
(
slots
=
[
"<|im_start|>user
\n
{{content}}<|im_end|>
\n
<|im_start|>assistant
\n
"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|im_start|>system
\n
{{content}}<|im_end|>
\n
"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
stop_words
=
[
"<|im_end|>"
],
replace_eos
=
True
,
)
_register_template
(
name
=
"yi_vl"
,
format_user
=
StringFormatter
(
slots
=
[
"### Human: {{content}}
\n
### Assistant:"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
default_system
=
(
"This is a chat between an inquisitive human and an AI assistant. "
"Assume the role of the AI assistant. Read all the images carefully, "
"and respond to the human's questions with informative, helpful, detailed and polite answers. "
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
"仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。
\n\n
"
),
stop_words
=
[
"###"
],
efficient_eos
=
True
,
)
_register_template
(
name
=
"yuan"
,
format_user
=
StringFormatter
(
slots
=
[
"{{content}}"
,
{
"token"
:
"<sep>"
}]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
stop_words
=
[
"<eod>"
],
replace_eos
=
True
,
)
_register_template
(
name
=
"zephyr"
,
format_user
=
StringFormatter
(
slots
=
[
"<|user|>
\n
{{content}}"
,
{
"eos_token"
},
"<|assistant|>
\n
"
]),
format_system
=
StringFormatter
(
slots
=
[
"<|system|>
\n
{{content}}"
,
{
"eos_token"
}]),
default_system
=
"You are Zephyr, a helpful assistant."
,
)
_register_template
(
name
=
"ziya"
,
format_user
=
StringFormatter
(
slots
=
[
"<human>:{{content}}
\n
<bot>:"
]),
format_separator
=
EmptyFormatter
(
slots
=
[
"
\n
"
]),
)
LLaMA-Factory/src/llamafactory/data/tool_utils.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
re
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
Union
from
.data_utils
import
SLOTS
DEFAULT_TOOL_PROMPT
=
(
"You have access to the following tools:
\n
{tool_text}"
"Use the following format if using a tool:
\n
"
"```
\n
"
"Action: tool name (one of [{tool_names}])
\n
"
"Action Input: the input to the tool, in a JSON format representing the kwargs "
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```)
\n
"""
"```
\n
"
)
GLM4_TOOL_PROMPT
=
(
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)
@
dataclass
class
ToolUtils
(
ABC
):
@
staticmethod
@
abstractmethod
def
get_function_slots
()
->
SLOTS
:
...
@
staticmethod
@
abstractmethod
def
tool_formatter
(
tools
:
List
[
Dict
[
str
,
Any
]])
->
str
:
...
@
staticmethod
@
abstractmethod
def
tool_extractor
(
content
:
str
)
->
Union
[
str
,
List
[
Tuple
[
str
,
str
]]]:
...
class
DefaultToolUtils
(
ToolUtils
):
@
staticmethod
def
get_function_slots
()
->
SLOTS
:
return
[
"Action: {{name}}
\n
Action Input: {{arguments}}
\n
"
]
@
staticmethod
def
tool_formatter
(
tools
:
List
[
Dict
[
str
,
Any
]])
->
str
:
tool_text
=
""
tool_names
=
[]
for
tool
in
tools
:
param_text
=
""
for
name
,
param
in
tool
[
"parameters"
][
"properties"
].
items
():
required
,
enum
,
items
=
""
,
""
,
""
if
name
in
tool
[
"parameters"
].
get
(
"required"
,
[]):
required
=
", required"
if
param
.
get
(
"enum"
,
None
):
enum
=
", should be one of [{}]"
.
format
(
", "
.
join
(
param
[
"enum"
]))
if
param
.
get
(
"items"
,
None
):
items
=
", where each item should be {}"
.
format
(
param
[
"items"
].
get
(
"type"
,
""
))
param_text
+=
" - {name} ({type}{required}): {desc}{enum}{items}
\n
"
.
format
(
name
=
name
,
type
=
param
.
get
(
"type"
,
""
),
required
=
required
,
desc
=
param
.
get
(
"description"
,
""
),
enum
=
enum
,
items
=
items
,
)
tool_text
+=
"> Tool Name: {name}
\n
Tool Description: {desc}
\n
Tool Args:
\n
{args}
\n
"
.
format
(
name
=
tool
[
"name"
],
desc
=
tool
.
get
(
"description"
,
""
),
args
=
param_text
)
tool_names
.
append
(
tool
[
"name"
])
return
DEFAULT_TOOL_PROMPT
.
format
(
tool_text
=
tool_text
,
tool_names
=
", "
.
join
(
tool_names
))
@
staticmethod
def
tool_extractor
(
content
:
str
)
->
Union
[
str
,
List
[
Tuple
[
str
,
str
]]]:
regex
=
re
.
compile
(
r
"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)"
,
re
.
DOTALL
)
action_match
:
List
[
Tuple
[
str
,
str
]]
=
re
.
findall
(
regex
,
content
)
if
not
action_match
:
return
content
results
=
[]
for
match
in
action_match
:
tool_name
=
match
[
0
].
strip
()
tool_input
=
match
[
1
].
strip
().
strip
(
'"'
).
strip
(
"```"
)
try
:
arguments
=
json
.
loads
(
tool_input
)
results
.
append
((
tool_name
,
json
.
dumps
(
arguments
,
ensure_ascii
=
False
)))
except
json
.
JSONDecodeError
:
return
content
return
results
class
GLM4ToolUtils
(
ToolUtils
):
@
staticmethod
def
get_function_slots
()
->
SLOTS
:
return
[
"{{name}}
\n
{{arguments}}"
]
@
staticmethod
def
tool_formatter
(
tools
:
List
[
Dict
[
str
,
Any
]])
->
str
:
tool_text
=
""
for
tool
in
tools
:
tool_text
+=
"
\n\n
## {name}
\n\n
{body}
\n
在调用上述函数时,请使用 Json 格式表示调用的参数。"
.
format
(
name
=
tool
[
"name"
],
body
=
json
.
dumps
(
tool
,
indent
=
4
,
ensure_ascii
=
False
)
)
return
GLM4_TOOL_PROMPT
.
format
(
tool_text
=
tool_text
)
@
staticmethod
def
tool_extractor
(
content
:
str
)
->
Union
[
str
,
List
[
Tuple
[
str
,
str
]]]:
if
"
\n
"
not
in
content
:
return
content
tool_name
,
tool_input
=
content
.
split
(
"
\n
"
,
maxsplit
=
1
)
try
:
arguments
=
json
.
loads
(
tool_input
)
except
json
.
JSONDecodeError
:
return
content
return
[(
tool_name
,
json
.
dumps
(
arguments
,
ensure_ascii
=
False
))]
LLaMA-Factory/src/llamafactory/eval/__init__.py
0 → 100644
View file @
032b90a1
LLaMA-Factory/src/llamafactory/eval/evaluator.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# This code is inspired by the Dan's test library.
# https://github.com/hendrycks/test/blob/master/evaluate_flan.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# MIT License
#
# Copyright (c) 2020 Dan Hendrycks
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import
json
import
os
from
typing
import
Any
,
Dict
,
List
,
Optional
import
numpy
as
np
import
torch
from
datasets
import
load_dataset
from
tqdm
import
tqdm
,
trange
from
transformers.utils
import
cached_file
from
..data
import
get_template_and_fix_tokenizer
from
..extras.constants
import
CHOICES
,
SUBJECTS
from
..hparams
import
get_eval_args
from
..model
import
load_model
,
load_tokenizer
from
.template
import
get_eval_template
class
Evaluator
:
def
__init__
(
self
,
args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
)
->
None
:
self
.
model_args
,
self
.
data_args
,
self
.
eval_args
,
finetuning_args
=
get_eval_args
(
args
)
self
.
tokenizer
=
load_tokenizer
(
self
.
model_args
)[
"tokenizer"
]
self
.
tokenizer
.
padding_side
=
"right"
# avoid overflow issue in batched inference for llama2
self
.
template
=
get_template_and_fix_tokenizer
(
self
.
tokenizer
,
self
.
data_args
.
template
)
self
.
model
=
load_model
(
self
.
tokenizer
,
self
.
model_args
,
finetuning_args
)
self
.
eval_template
=
get_eval_template
(
self
.
eval_args
.
lang
)
self
.
choice_inputs
=
[
self
.
tokenizer
.
encode
(
ch
,
add_special_tokens
=
False
)[
-
1
]
for
ch
in
CHOICES
]
@
torch
.
inference_mode
()
def
batch_inference
(
self
,
batch_input
:
Dict
[
str
,
torch
.
Tensor
])
->
List
[
str
]:
logits
=
self
.
model
(
**
batch_input
).
logits
lengths
=
torch
.
sum
(
batch_input
[
"attention_mask"
],
dim
=-
1
)
word_probs
=
torch
.
stack
([
logits
[
i
,
lengths
[
i
]
-
1
]
for
i
in
range
(
len
(
lengths
))],
dim
=
0
)
choice_probs
=
torch
.
nn
.
functional
.
softmax
(
word_probs
[:,
self
.
choice_inputs
],
dim
=-
1
).
detach
()
return
[
chr
(
ord
(
"A"
)
+
offset
.
item
())
for
offset
in
torch
.
argmax
(
choice_probs
,
dim
=-
1
)]
def
eval
(
self
)
->
None
:
eval_task
=
self
.
eval_args
.
task
.
split
(
"_"
)[
0
]
eval_split
=
self
.
eval_args
.
task
.
split
(
"_"
)[
1
]
mapping
=
cached_file
(
path_or_repo_id
=
os
.
path
.
join
(
self
.
eval_args
.
task_dir
,
eval_task
),
filename
=
"mapping.json"
,
cache_dir
=
self
.
model_args
.
cache_dir
,
token
=
self
.
model_args
.
hf_hub_token
,
)
with
open
(
mapping
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
categorys
:
Dict
[
str
,
Dict
[
str
,
str
]]
=
json
.
load
(
f
)
category_corrects
=
{
subj
:
np
.
array
([],
dtype
=
"bool"
)
for
subj
in
SUBJECTS
}
pbar
=
tqdm
(
categorys
.
keys
(),
desc
=
"Processing subjects"
,
position
=
0
)
results
=
{}
for
subject
in
pbar
:
dataset
=
load_dataset
(
path
=
os
.
path
.
join
(
self
.
eval_args
.
task_dir
,
eval_task
),
name
=
subject
,
cache_dir
=
self
.
model_args
.
cache_dir
,
download_mode
=
self
.
eval_args
.
download_mode
,
token
=
self
.
model_args
.
hf_hub_token
,
trust_remote_code
=
True
,
)
pbar
.
set_postfix_str
(
categorys
[
subject
][
"name"
])
inputs
,
outputs
,
labels
=
[],
[],
[]
for
i
in
trange
(
len
(
dataset
[
eval_split
]),
desc
=
"Formatting batches"
,
position
=
1
,
leave
=
False
):
support_set
=
(
dataset
[
"train"
].
shuffle
().
select
(
range
(
min
(
self
.
eval_args
.
n_shot
,
len
(
dataset
[
"train"
]))))
)
messages
=
self
.
eval_template
.
format_example
(
target_data
=
dataset
[
eval_split
][
i
],
support_set
=
support_set
,
subject_name
=
categorys
[
subject
][
"name"
],
)
input_ids
,
_
=
self
.
template
.
encode_oneturn
(
tokenizer
=
self
.
tokenizer
,
messages
=
messages
)
inputs
.
append
({
"input_ids"
:
input_ids
,
"attention_mask"
:
[
1
]
*
len
(
input_ids
)})
labels
.
append
(
messages
[
-
1
][
"content"
])
for
i
in
trange
(
0
,
len
(
inputs
),
self
.
eval_args
.
batch_size
,
desc
=
"Predicting batches"
,
position
=
1
,
leave
=
False
):
batch_input
=
self
.
tokenizer
.
pad
(
inputs
[
i
:
i
+
self
.
eval_args
.
batch_size
],
return_attention_mask
=
True
,
return_tensors
=
"pt"
).
to
(
self
.
model
.
device
)
preds
=
self
.
batch_inference
(
batch_input
)
outputs
+=
preds
corrects
=
np
.
array
(
outputs
)
==
np
.
array
(
labels
)
category_name
=
categorys
[
subject
][
"category"
]
category_corrects
[
category_name
]
=
np
.
concatenate
([
category_corrects
[
category_name
],
corrects
],
axis
=
0
)
category_corrects
[
"Average"
]
=
np
.
concatenate
([
category_corrects
[
"Average"
],
corrects
],
axis
=
0
)
results
[
subject
]
=
{
str
(
i
):
outputs
[
i
]
for
i
in
range
(
len
(
outputs
))}
pbar
.
close
()
self
.
_save_results
(
category_corrects
,
results
)
def
_save_results
(
self
,
category_corrects
:
Dict
[
str
,
np
.
ndarray
],
results
:
Dict
[
str
,
Dict
[
int
,
str
]])
->
None
:
score_info
=
"
\n
"
.
join
(
[
"{:>15}: {:.2f}"
.
format
(
category_name
,
100
*
np
.
mean
(
category_correct
))
for
category_name
,
category_correct
in
category_corrects
.
items
()
if
len
(
category_correct
)
]
)
print
(
score_info
)
if
self
.
eval_args
.
save_dir
is
not
None
:
os
.
makedirs
(
self
.
eval_args
.
save_dir
,
exist_ok
=
False
)
with
open
(
os
.
path
.
join
(
self
.
eval_args
.
save_dir
,
"results.json"
),
"w"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
json
.
dump
(
results
,
f
,
indent
=
2
)
with
open
(
os
.
path
.
join
(
self
.
eval_args
.
save_dir
,
"results.log"
),
"w"
,
encoding
=
"utf-8"
,
newline
=
"
\n
"
)
as
f
:
f
.
write
(
score_info
)
def
run_eval
()
->
None
:
Evaluator
().
eval
()
LLaMA-Factory/src/llamafactory/eval/template.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Sequence
,
Tuple
from
..data
import
Role
from
..extras.constants
import
CHOICES
@
dataclass
class
EvalTemplate
:
system
:
str
choice
:
str
answer
:
str
def
_parse_example
(
self
,
example
:
Dict
[
str
,
str
])
->
Tuple
[
str
,
str
]:
r
"""
input: a dict with keys {"question", "A", "B", "C", "D", "answer"}
output: a tuple of (prompt, response)
"""
candidates
=
[
self
.
choice
.
format
(
choice
=
ch
,
content
=
example
[
ch
])
for
ch
in
CHOICES
if
ch
in
example
]
return
""
.
join
([
example
[
"question"
]]
+
candidates
+
[
self
.
answer
]),
example
[
"answer"
]
def
format_example
(
self
,
target_data
:
Dict
[
str
,
str
],
support_set
:
Sequence
[
Dict
[
str
,
str
]],
subject_name
:
str
)
->
List
[
Dict
[
str
,
str
]]:
r
"""
Converts dataset examples to messages.
"""
messages
=
[]
for
k
in
range
(
len
(
support_set
)):
prompt
,
response
=
self
.
_parse_example
(
support_set
[
k
])
messages
.
append
({
"role"
:
Role
.
USER
.
value
,
"content"
:
prompt
})
messages
.
append
({
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
response
})
prompt
,
response
=
self
.
_parse_example
(
target_data
)
messages
.
append
({
"role"
:
Role
.
USER
.
value
,
"content"
:
prompt
})
messages
.
append
({
"role"
:
Role
.
ASSISTANT
.
value
,
"content"
:
response
})
messages
[
0
][
"content"
]
=
self
.
system
.
format
(
subject
=
subject_name
)
+
messages
[
0
][
"content"
]
return
messages
eval_templates
:
Dict
[
str
,
"EvalTemplate"
]
=
{}
def
_register_eval_template
(
name
:
str
,
system
:
str
,
choice
:
str
,
answer
:
str
)
->
None
:
eval_templates
[
name
]
=
EvalTemplate
(
system
=
system
,
choice
=
choice
,
answer
=
answer
)
def
get_eval_template
(
name
:
str
)
->
"EvalTemplate"
:
eval_template
=
eval_templates
.
get
(
name
,
None
)
assert
eval_template
is
not
None
,
"Template {} does not exist."
.
format
(
name
)
return
eval_template
_register_eval_template
(
name
=
"en"
,
system
=
"The following are multiple choice questions (with answers) about {subject}.
\n\n
"
,
choice
=
"
\n
{choice}. {content}"
,
answer
=
"
\n
Answer:"
,
)
_register_eval_template
(
name
=
"zh"
,
system
=
"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。
\n\n
"
,
choice
=
"
\n
{choice}. {content}"
,
answer
=
"
\n
答案:"
,
)
LLaMA-Factory/src/llamafactory/extras/__init__.py
0 → 100644
View file @
032b90a1
LLaMA-Factory/src/llamafactory/extras/constants.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
collections
import
OrderedDict
,
defaultdict
from
enum
import
Enum
from
typing
import
Dict
,
Optional
from
peft.utils
import
SAFETENSORS_WEIGHTS_NAME
as
SAFE_ADAPTER_WEIGHTS_NAME
from
peft.utils
import
WEIGHTS_NAME
as
ADAPTER_WEIGHTS_NAME
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
,
SAFE_WEIGHTS_NAME
,
WEIGHTS_INDEX_NAME
,
WEIGHTS_NAME
CHECKPOINT_NAMES
=
{
SAFE_ADAPTER_WEIGHTS_NAME
,
ADAPTER_WEIGHTS_NAME
,
SAFE_WEIGHTS_INDEX_NAME
,
SAFE_WEIGHTS_NAME
,
WEIGHTS_INDEX_NAME
,
WEIGHTS_NAME
,
}
CHOICES
=
[
"A"
,
"B"
,
"C"
,
"D"
]
DATA_CONFIG
=
"dataset_info.json"
DEFAULT_TEMPLATE
=
defaultdict
(
str
)
FILEEXT2TYPE
=
{
"arrow"
:
"arrow"
,
"csv"
:
"csv"
,
"json"
:
"json"
,
"jsonl"
:
"json"
,
"parquet"
:
"parquet"
,
"txt"
:
"text"
,
}
IGNORE_INDEX
=
-
100
LAYERNORM_NAMES
=
{
"norm"
,
"ln"
}
LLAMABOARD_CONFIG
=
"llamaboard_config.yaml"
METHODS
=
[
"full"
,
"freeze"
,
"lora"
]
MOD_SUPPORTED_MODELS
=
{
"bloom"
,
"falcon"
,
"gemma"
,
"llama"
,
"mistral"
,
"mixtral"
,
"phi"
,
"starcoder2"
}
PEFT_METHODS
=
{
"lora"
}
RUNNING_LOG
=
"running_log.txt"
SUBJECTS
=
[
"Average"
,
"STEM"
,
"Social Sciences"
,
"Humanities"
,
"Other"
]
SUPPORTED_MODELS
=
OrderedDict
()
TRAINER_LOG
=
"trainer_log.jsonl"
TRAINING_ARGS
=
"training_args.yaml"
TRAINING_STAGES
=
{
"Supervised Fine-Tuning"
:
"sft"
,
"Reward Modeling"
:
"rm"
,
"PPO"
:
"ppo"
,
"DPO"
:
"dpo"
,
"KTO"
:
"kto"
,
"Pre-Training"
:
"pt"
,
}
STAGES_USE_PAIR_DATA
=
{
"rm"
,
"dpo"
}
SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
=
{
"cohere"
,
"falcon"
,
"gemma"
,
"gemma2"
,
"llama"
,
"mistral"
,
"phi"
,
"phi3"
,
"qwen2"
,
"starcoder2"
,
}
SUPPORTED_CLASS_FOR_S2ATTN
=
{
"llama"
}
V_HEAD_WEIGHTS_NAME
=
"value_head.bin"
V_HEAD_SAFE_WEIGHTS_NAME
=
"value_head.safetensors"
VISION_MODELS
=
set
()
class
DownloadSource
(
str
,
Enum
):
DEFAULT
=
"hf"
MODELSCOPE
=
"ms"
def
register_model_group
(
models
:
Dict
[
str
,
Dict
[
DownloadSource
,
str
]],
template
:
Optional
[
str
]
=
None
,
vision
:
bool
=
False
,
)
->
None
:
prefix
=
None
for
name
,
path
in
models
.
items
():
if
prefix
is
None
:
prefix
=
name
.
split
(
"-"
)[
0
]
else
:
assert
prefix
==
name
.
split
(
"-"
)[
0
],
"prefix should be identical."
SUPPORTED_MODELS
[
name
]
=
path
if
template
is
not
None
:
DEFAULT_TEMPLATE
[
prefix
]
=
template
if
vision
:
VISION_MODELS
.
add
(
prefix
)
register_model_group
(
models
=
{
"Aya-23-8B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"CohereForAI/aya-23-8B"
,
},
"Aya-23-35B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"CohereForAI/aya-23-35B"
,
},
},
template
=
"cohere"
,
)
register_model_group
(
models
=
{
"Baichuan-7B-Base"
:
{
DownloadSource
.
DEFAULT
:
"baichuan-inc/Baichuan-7B"
,
DownloadSource
.
MODELSCOPE
:
"baichuan-inc/baichuan-7B"
,
},
"Baichuan-13B-Base"
:
{
DownloadSource
.
DEFAULT
:
"baichuan-inc/Baichuan-13B-Base"
,
DownloadSource
.
MODELSCOPE
:
"baichuan-inc/Baichuan-13B-Base"
,
},
"Baichuan-13B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"baichuan-inc/Baichuan-13B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"baichuan-inc/Baichuan-13B-Chat"
,
},
},
template
=
"baichuan"
,
)
register_model_group
(
models
=
{
"Baichuan2-7B-Base"
:
{
DownloadSource
.
DEFAULT
:
"baichuan-inc/Baichuan2-7B-Base"
,
DownloadSource
.
MODELSCOPE
:
"baichuan-inc/Baichuan2-7B-Base"
,
},
"Baichuan2-13B-Base"
:
{
DownloadSource
.
DEFAULT
:
"baichuan-inc/Baichuan2-13B-Base"
,
DownloadSource
.
MODELSCOPE
:
"baichuan-inc/Baichuan2-13B-Base"
,
},
"Baichuan2-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"baichuan-inc/Baichuan2-7B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"baichuan-inc/Baichuan2-7B-Chat"
,
},
"Baichuan2-13B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"baichuan-inc/Baichuan2-13B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"baichuan-inc/Baichuan2-13B-Chat"
,
},
},
template
=
"baichuan2"
,
)
register_model_group
(
models
=
{
"BLOOM-560M"
:
{
DownloadSource
.
DEFAULT
:
"bigscience/bloom-560m"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/bloom-560m"
,
},
"BLOOM-3B"
:
{
DownloadSource
.
DEFAULT
:
"bigscience/bloom-3b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/bloom-3b"
,
},
"BLOOM-7B1"
:
{
DownloadSource
.
DEFAULT
:
"bigscience/bloom-7b1"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/bloom-7b1"
,
},
},
)
register_model_group
(
models
=
{
"BLOOMZ-560M"
:
{
DownloadSource
.
DEFAULT
:
"bigscience/bloomz-560m"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/bloomz-560m"
,
},
"BLOOMZ-3B"
:
{
DownloadSource
.
DEFAULT
:
"bigscience/bloomz-3b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/bloomz-3b"
,
},
"BLOOMZ-7B1-mt"
:
{
DownloadSource
.
DEFAULT
:
"bigscience/bloomz-7b1-mt"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/bloomz-7b1-mt"
,
},
},
)
register_model_group
(
models
=
{
"BlueLM-7B-Base"
:
{
DownloadSource
.
DEFAULT
:
"vivo-ai/BlueLM-7B-Base"
,
DownloadSource
.
MODELSCOPE
:
"vivo-ai/BlueLM-7B-Base"
,
},
"BlueLM-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"vivo-ai/BlueLM-7B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"vivo-ai/BlueLM-7B-Chat"
,
},
},
template
=
"bluelm"
,
)
register_model_group
(
models
=
{
"Breeze-7B"
:
{
DownloadSource
.
DEFAULT
:
"MediaTek-Research/Breeze-7B-Base-v1_0"
,
},
"Breeze-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"MediaTek-Research/Breeze-7B-Instruct-v1_0"
,
},
},
template
=
"breeze"
,
)
register_model_group
(
models
=
{
"ChatGLM2-6B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"THUDM/chatglm2-6b"
,
DownloadSource
.
MODELSCOPE
:
"ZhipuAI/chatglm2-6b"
,
}
},
template
=
"chatglm2"
,
)
register_model_group
(
models
=
{
"ChatGLM3-6B-Base"
:
{
DownloadSource
.
DEFAULT
:
"THUDM/chatglm3-6b-base"
,
DownloadSource
.
MODELSCOPE
:
"ZhipuAI/chatglm3-6b-base"
,
},
"ChatGLM3-6B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"THUDM/chatglm3-6b"
,
DownloadSource
.
MODELSCOPE
:
"ZhipuAI/chatglm3-6b"
,
},
},
template
=
"chatglm3"
,
)
register_model_group
(
models
=
{
"ChineseLLaMA2-1.3B"
:
{
DownloadSource
.
DEFAULT
:
"hfl/chinese-llama-2-1.3b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/chinese-llama-2-1.3b"
,
},
"ChineseLLaMA2-7B"
:
{
DownloadSource
.
DEFAULT
:
"hfl/chinese-llama-2-7b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/chinese-llama-2-7b"
,
},
"ChineseLLaMA2-13B"
:
{
DownloadSource
.
DEFAULT
:
"hfl/chinese-llama-2-13b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/chinese-llama-2-13b"
,
},
"ChineseLLaMA2-1.3B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"hfl/chinese-alpaca-2-1.3b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/chinese-alpaca-2-1.3b"
,
},
"ChineseLLaMA2-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"hfl/chinese-alpaca-2-7b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/chinese-alpaca-2-7b"
,
},
"ChineseLLaMA2-13B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"hfl/chinese-alpaca-2-13b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/chinese-alpaca-2-13b"
,
},
},
template
=
"llama2_zh"
,
)
register_model_group
(
models
=
{
"CodeGeeX4-9B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"THUDM/codegeex4-all-9b"
,
DownloadSource
.
MODELSCOPE
:
"ZhipuAI/codegeex4-all-9b"
,
},
},
template
=
"codegeex4"
,
)
register_model_group
(
models
=
{
"CodeGemma-7B"
:
{
DownloadSource
.
DEFAULT
:
"google/codegemma-7b"
,
},
"CodeGemma-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"google/codegemma-7b-it"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/codegemma-7b-it"
,
},
"CodeGemma-1.1-2B"
:
{
DownloadSource
.
DEFAULT
:
"google/codegemma-1.1-2b"
,
},
"CodeGemma-1.1-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"google/codegemma-1.1-7b-it"
,
},
},
template
=
"gemma"
,
)
register_model_group
(
models
=
{
"Codestral-22B-v0.1-Chat"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Codestral-22B-v0.1"
,
},
},
template
=
"mistral"
,
)
register_model_group
(
models
=
{
"CommandR-35B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"CohereForAI/c4ai-command-r-v01"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/c4ai-command-r-v01"
,
},
"CommandR-Plus-104B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"CohereForAI/c4ai-command-r-plus"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/c4ai-command-r-plus"
,
},
"CommandR-35B-4bit-Chat"
:
{
DownloadSource
.
DEFAULT
:
"CohereForAI/c4ai-command-r-v01-4bit"
,
DownloadSource
.
MODELSCOPE
:
"mirror013/c4ai-command-r-v01-4bit"
,
},
"CommandR-Plus-104B-4bit-Chat"
:
{
DownloadSource
.
DEFAULT
:
"CohereForAI/c4ai-command-r-plus-4bit"
,
},
},
template
=
"cohere"
,
)
register_model_group
(
models
=
{
"DBRX-132B-Base"
:
{
DownloadSource
.
DEFAULT
:
"databricks/dbrx-base"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/dbrx-base"
,
},
"DBRX-132B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"databricks/dbrx-instruct"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/dbrx-instruct"
,
},
},
template
=
"dbrx"
,
)
register_model_group
(
models
=
{
"DeepSeek-LLM-7B-Base"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-llm-7b-base"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-llm-7b-base"
,
},
"DeepSeek-LLM-67B-Base"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-llm-67b-base"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-llm-67b-base"
,
},
"DeepSeek-LLM-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-llm-7b-chat"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-llm-7b-chat"
,
},
"DeepSeek-LLM-67B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-llm-67b-chat"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-llm-67b-chat"
,
},
"DeepSeek-Math-7B-Base"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-math-7b-base"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-math-7b-base"
,
},
"DeepSeek-Math-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-math-7b-instruct"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-math-7b-instruct"
,
},
"DeepSeek-MoE-16B-Base"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-moe-16b-base"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-moe-16b-base"
,
},
"DeepSeek-MoE-16B-v2-Base"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/DeepSeek-V2-Lite"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/DeepSeek-V2-Lite"
,
},
"DeepSeek-MoE-236B-Base"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/DeepSeek-V2"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/DeepSeek-V2"
,
},
"DeepSeek-MoE-16B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-moe-16b-chat"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-moe-16b-chat"
,
},
"DeepSeek-MoE-16B-v2-Chat"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/DeepSeek-V2-Lite-Chat"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/DeepSeek-V2-Lite-Chat"
,
},
"DeepSeek-MoE-236B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/DeepSeek-V2-Chat"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/DeepSeek-V2-Chat"
,
},
"DeepSeek-MoE-Coder-16B-Base"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/DeepSeek-Coder-V2-Lite-Base"
,
},
"DeepSeek-MoE-Coder-236B-Base"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/DeepSeek-Coder-V2-Base"
,
},
"DeepSeek-MoE-Coder-16B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
,
},
"DeepSeek-MoE-Coder-236B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/DeepSeek-Coder-V2-Instruct"
,
},
},
template
=
"deepseek"
,
)
register_model_group
(
models
=
{
"DeepSeekCoder-6.7B-Base"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-coder-6.7b-base"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-coder-6.7b-base"
,
},
"DeepSeekCoder-7B-Base"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-coder-7b-base-v1.5"
,
},
"DeepSeekCoder-33B-Base"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-coder-33b-base"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-coder-33b-base"
,
},
"DeepSeekCoder-6.7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-coder-6.7b-instruct"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-coder-6.7b-instruct"
,
},
"DeepSeekCoder-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-coder-7b-instruct-v1.5"
,
},
"DeepSeekCoder-33B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"deepseek-ai/deepseek-coder-33b-instruct"
,
DownloadSource
.
MODELSCOPE
:
"deepseek-ai/deepseek-coder-33b-instruct"
,
},
},
template
=
"deepseekcoder"
,
)
register_model_group
(
models
=
{
"Falcon-7B"
:
{
DownloadSource
.
DEFAULT
:
"tiiuae/falcon-7b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/falcon-7b"
,
},
"Falcon-11B"
:
{
DownloadSource
.
DEFAULT
:
"tiiuae/falcon-11B"
,
},
"Falcon-40B"
:
{
DownloadSource
.
DEFAULT
:
"tiiuae/falcon-40b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/falcon-40b"
,
},
"Falcon-180B"
:
{
DownloadSource
.
DEFAULT
:
"tiiuae/falcon-180b"
,
DownloadSource
.
MODELSCOPE
:
"modelscope/falcon-180B"
,
},
"Falcon-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"tiiuae/falcon-7b-instruct"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/falcon-7b-instruct"
,
},
"Falcon-40B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"tiiuae/falcon-40b-instruct"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/falcon-40b-instruct"
,
},
"Falcon-180B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"tiiuae/falcon-180b-chat"
,
DownloadSource
.
MODELSCOPE
:
"modelscope/falcon-180B-chat"
,
},
},
template
=
"falcon"
,
)
register_model_group
(
models
=
{
"Gemma-2B"
:
{
DownloadSource
.
DEFAULT
:
"google/gemma-2b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/gemma-2b"
,
},
"Gemma-7B"
:
{
DownloadSource
.
DEFAULT
:
"google/gemma-7b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/gemma-2b-it"
,
},
"Gemma-2B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"google/gemma-2b-it"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/gemma-7b"
,
},
"Gemma-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"google/gemma-7b-it"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/gemma-7b-it"
,
},
"Gemma-1.1-2B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"google/gemma-1.1-2b-it"
,
},
"Gemma-1.1-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"google/gemma-1.1-7b-it"
,
},
"Gemma-2-9B"
:
{
DownloadSource
.
DEFAULT
:
"google/gemma-2-9b"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/gemma-2-9b"
,
},
"Gemma-2-27B"
:
{
DownloadSource
.
DEFAULT
:
"google/gemma-2-27b"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/gemma-2-27b"
,
},
"Gemma-2-9B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"google/gemma-2-9b-it"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/gemma-2-9b-it"
,
},
"Gemma-2-27B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"google/gemma-2-27b-it"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/gemma-2-27b-it"
,
},
},
template
=
"gemma"
,
)
register_model_group
(
models
=
{
"GLM-4-9B"
:
{
DownloadSource
.
DEFAULT
:
"THUDM/glm-4-9b"
,
DownloadSource
.
MODELSCOPE
:
"ZhipuAI/glm-4-9b"
,
},
"GLM-4-9B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"THUDM/glm-4-9b-chat"
,
DownloadSource
.
MODELSCOPE
:
"ZhipuAI/glm-4-9b-chat"
,
},
"GLM-4-9B-1M-Chat"
:
{
DownloadSource
.
DEFAULT
:
"THUDM/glm-4-9b-chat-1m"
,
DownloadSource
.
MODELSCOPE
:
"ZhipuAI/glm-4-9b-chat-1m"
,
},
},
template
=
"glm4"
,
)
register_model_group
(
models
=
{
"InternLM-7B"
:
{
DownloadSource
.
DEFAULT
:
"internlm/internlm-7b"
,
DownloadSource
.
MODELSCOPE
:
"Shanghai_AI_Laboratory/internlm-7b"
,
},
"InternLM-20B"
:
{
DownloadSource
.
DEFAULT
:
"internlm/internlm-20b"
,
DownloadSource
.
MODELSCOPE
:
"Shanghai_AI_Laboratory/internlm-20b"
,
},
"InternLM-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"internlm/internlm-chat-7b"
,
DownloadSource
.
MODELSCOPE
:
"Shanghai_AI_Laboratory/internlm-chat-7b"
,
},
"InternLM-20B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"internlm/internlm-chat-20b"
,
DownloadSource
.
MODELSCOPE
:
"Shanghai_AI_Laboratory/internlm-chat-20b"
,
},
},
template
=
"intern"
,
)
register_model_group
(
models
=
{
"InternLM2-7B"
:
{
DownloadSource
.
DEFAULT
:
"internlm/internlm2-7b"
,
DownloadSource
.
MODELSCOPE
:
"Shanghai_AI_Laboratory/internlm2-7b"
,
},
"InternLM2-20B"
:
{
DownloadSource
.
DEFAULT
:
"internlm/internlm2-20b"
,
DownloadSource
.
MODELSCOPE
:
"Shanghai_AI_Laboratory/internlm2-20b"
,
},
"InternLM2-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"internlm/internlm2-chat-7b"
,
DownloadSource
.
MODELSCOPE
:
"Shanghai_AI_Laboratory/internlm2-chat-7b"
,
},
"InternLM2-20B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"internlm/internlm2-chat-20b"
,
DownloadSource
.
MODELSCOPE
:
"Shanghai_AI_Laboratory/internlm2-chat-20b"
,
},
},
template
=
"intern2"
,
)
register_model_group
(
models
=
{
"InternLM2.5-7B"
:
{
DownloadSource
.
DEFAULT
:
"internlm/internlm2_5-7b"
,
DownloadSource
.
MODELSCOPE
:
"Shanghai_AI_Laboratory/internlm2_5-7b"
,
},
"InternLM2.5-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"internlm/internlm2_5-7b-chat"
,
DownloadSource
.
MODELSCOPE
:
"Shanghai_AI_Laboratory/internlm2_5-7b-chat"
,
},
"InternLM2.5-7B-1M-Chat"
:
{
DownloadSource
.
DEFAULT
:
"internlm/internlm2_5-7b-chat-1m"
,
DownloadSource
.
MODELSCOPE
:
"Shanghai_AI_Laboratory/internlm2_5-7b-chat-1m"
,
},
},
template
=
"intern2"
,
)
register_model_group
(
models
=
{
"Jamba-v0.1"
:
{
DownloadSource
.
DEFAULT
:
"ai21labs/Jamba-v0.1"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/Jamba-v0.1"
,
}
},
)
register_model_group
(
models
=
{
"LingoWhale-8B"
:
{
DownloadSource
.
DEFAULT
:
"deeplang-ai/LingoWhale-8B"
,
DownloadSource
.
MODELSCOPE
:
"DeepLang/LingoWhale-8B"
,
}
},
)
register_model_group
(
models
=
{
"LLaMA-7B"
:
{
DownloadSource
.
DEFAULT
:
"huggyllama/llama-7b"
,
DownloadSource
.
MODELSCOPE
:
"skyline2006/llama-7b"
,
},
"LLaMA-13B"
:
{
DownloadSource
.
DEFAULT
:
"huggyllama/llama-13b"
,
DownloadSource
.
MODELSCOPE
:
"skyline2006/llama-13b"
,
},
"LLaMA-30B"
:
{
DownloadSource
.
DEFAULT
:
"huggyllama/llama-30b"
,
DownloadSource
.
MODELSCOPE
:
"skyline2006/llama-30b"
,
},
"LLaMA-65B"
:
{
DownloadSource
.
DEFAULT
:
"huggyllama/llama-65b"
,
DownloadSource
.
MODELSCOPE
:
"skyline2006/llama-65b"
,
},
}
)
register_model_group
(
models
=
{
"LLaMA2-7B"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Llama-2-7b-hf"
,
DownloadSource
.
MODELSCOPE
:
"modelscope/Llama-2-7b-ms"
,
},
"LLaMA2-13B"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Llama-2-13b-hf"
,
DownloadSource
.
MODELSCOPE
:
"modelscope/Llama-2-13b-ms"
,
},
"LLaMA2-70B"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Llama-2-70b-hf"
,
DownloadSource
.
MODELSCOPE
:
"modelscope/Llama-2-70b-ms"
,
},
"LLaMA2-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Llama-2-7b-chat-hf"
,
DownloadSource
.
MODELSCOPE
:
"modelscope/Llama-2-7b-chat-ms"
,
},
"LLaMA2-13B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Llama-2-13b-chat-hf"
,
DownloadSource
.
MODELSCOPE
:
"modelscope/Llama-2-13b-chat-ms"
,
},
"LLaMA2-70B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Llama-2-70b-chat-hf"
,
DownloadSource
.
MODELSCOPE
:
"modelscope/Llama-2-70b-chat-ms"
,
},
},
template
=
"llama2"
,
)
register_model_group
(
models
=
{
"LLaMA3-8B"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Meta-Llama-3-8B"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Meta-Llama-3-8B"
,
},
"LLaMA3-70B"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Meta-Llama-3-70B"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Meta-Llama-3-70B"
,
},
"LLaMA3-8B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Meta-Llama-3-8B-Instruct"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Meta-Llama-3-8B-Instruct"
,
},
"LLaMA3-70B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Meta-Llama-3-70B-Instruct"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Meta-Llama-3-70B-Instruct"
,
},
"LLaMA3-8B-Chinese-Chat"
:
{
DownloadSource
.
DEFAULT
:
"shenzhi-wang/Llama3-8B-Chinese-Chat"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Llama3-8B-Chinese-Chat"
,
},
"LLaMA3-70B-Chinese-Chat"
:
{
DownloadSource
.
DEFAULT
:
"shenzhi-wang/Llama3-70B-Chinese-Chat"
,
},
},
template
=
"llama3"
,
)
register_model_group
(
models
=
{
"LLaMA3.1-8B"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Meta-Llama-3.1-8B"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Meta-Llama-3.1-8B"
,
},
"LLaMA3.1-70B"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Meta-Llama-3.1-70B"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Meta-Llama-3.1-70B"
,
},
"LLaMA3.1-405B"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Meta-Llama-3.1-405B"
,
},
"LLaMA3.1-8B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Meta-Llama-3.1-8B-Instruct"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Meta-Llama-3.1-8B-Instruct"
,
},
"LLaMA3.1-70B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Meta-Llama-3.1-70B-Instruct"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Meta-Llama-3.1-70B-Instruct"
,
},
"LLaMA3.1-405B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"meta-llama/Meta-Llama-3.1-405B-Instruct"
,
},
},
template
=
"llama3"
,
)
register_model_group
(
models
=
{
"LLaVA1.5-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"llava-hf/llava-1.5-7b-hf"
,
},
"LLaVA1.5-13B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"llava-hf/llava-1.5-13b-hf"
,
},
},
template
=
"vicuna"
,
vision
=
True
,
)
register_model_group
(
models
=
{
"MiniCPM-2B-SFT-Chat"
:
{
DownloadSource
.
DEFAULT
:
"openbmb/MiniCPM-2B-sft-bf16"
,
DownloadSource
.
MODELSCOPE
:
"OpenBMB/miniCPM-bf16"
,
},
"MiniCPM-2B-DPO-Chat"
:
{
DownloadSource
.
DEFAULT
:
"openbmb/MiniCPM-2B-dpo-bf16"
,
DownloadSource
.
MODELSCOPE
:
"OpenBMB/MiniCPM-2B-dpo-bf16"
,
},
},
template
=
"cpm"
,
)
register_model_group
(
models
=
{
"Mistral-7B-v0.1"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Mistral-7B-v0.1"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/Mistral-7B-v0.1"
,
},
"Mistral-7B-v0.1-Chat"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Mistral-7B-Instruct-v0.1"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/Mistral-7B-Instruct-v0.1"
,
},
"Mistral-7B-v0.2"
:
{
DownloadSource
.
DEFAULT
:
"alpindale/Mistral-7B-v0.2-hf"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/Mistral-7B-v0.2-hf"
,
},
"Mistral-7B-v0.2-Chat"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Mistral-7B-Instruct-v0.2"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/Mistral-7B-Instruct-v0.2"
,
},
"Mistral-7B-v0.3"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Mistral-7B-v0.3"
,
},
"Mistral-7B-v0.3-Chat"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Mistral-7B-Instruct-v0.3"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Mistral-7B-Instruct-v0.3"
,
},
"Mistral-Nemo-Chat"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Mistral-Nemo-Instruct-2407"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/Mistral-Nemo-Instruct-2407"
,
},
},
template
=
"mistral"
,
)
register_model_group
(
models
=
{
"Mixtral-8x7B-v0.1"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Mixtral-8x7B-v0.1"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/Mixtral-8x7B-v0.1"
,
},
"Mixtral-8x7B-v0.1-Chat"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Mixtral-8x7B-Instruct-v0.1"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/Mixtral-8x7B-Instruct-v0.1"
,
},
"Mixtral-8x22B-v0.1"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Mixtral-8x22B-v0.1"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/Mixtral-8x22B-v0.1"
,
},
"Mixtral-8x22B-v0.1-Chat"
:
{
DownloadSource
.
DEFAULT
:
"mistralai/Mixtral-8x22B-Instruct-v0.1"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/Mixtral-8x22B-Instruct-v0.1"
,
},
},
template
=
"mistral"
,
)
register_model_group
(
models
=
{
"OLMo-1B"
:
{
DownloadSource
.
DEFAULT
:
"allenai/OLMo-1B-hf"
,
},
"OLMo-7B"
:
{
DownloadSource
.
DEFAULT
:
"allenai/OLMo-7B-hf"
,
},
"OLMo-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"ssec-uw/OLMo-7B-Instruct-hf"
,
},
"OLMo-1.7-7B"
:
{
DownloadSource
.
DEFAULT
:
"allenai/OLMo-1.7-7B-hf"
,
},
},
)
register_model_group
(
models
=
{
"OpenChat3.5-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"openchat/openchat-3.5-0106"
,
DownloadSource
.
MODELSCOPE
:
"xcwzxcwz/openchat-3.5-0106"
,
}
},
template
=
"openchat"
,
)
register_model_group
(
models
=
{
"OpenChat3.6-8B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"openchat/openchat-3.6-8b-20240522"
,
}
},
template
=
"openchat-3.6"
,
)
register_model_group
(
models
=
{
"Orion-14B-Base"
:
{
DownloadSource
.
DEFAULT
:
"OrionStarAI/Orion-14B-Base"
,
DownloadSource
.
MODELSCOPE
:
"OrionStarAI/Orion-14B-Base"
,
},
"Orion-14B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"OrionStarAI/Orion-14B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"OrionStarAI/Orion-14B-Chat"
,
},
"Orion-14B-Long-Chat"
:
{
DownloadSource
.
DEFAULT
:
"OrionStarAI/Orion-14B-LongChat"
,
DownloadSource
.
MODELSCOPE
:
"OrionStarAI/Orion-14B-LongChat"
,
},
"Orion-14B-RAG-Chat"
:
{
DownloadSource
.
DEFAULT
:
"OrionStarAI/Orion-14B-Chat-RAG"
,
DownloadSource
.
MODELSCOPE
:
"OrionStarAI/Orion-14B-Chat-RAG"
,
},
"Orion-14B-Plugin-Chat"
:
{
DownloadSource
.
DEFAULT
:
"OrionStarAI/Orion-14B-Chat-Plugin"
,
DownloadSource
.
MODELSCOPE
:
"OrionStarAI/Orion-14B-Chat-Plugin"
,
},
},
template
=
"orion"
,
)
register_model_group
(
models
=
{
"PaliGemma-3B-pt-224"
:
{
DownloadSource
.
DEFAULT
:
"google/paligemma-3b-pt-224"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/paligemma-3b-pt-224"
,
},
"PaliGemma-3B-pt-448"
:
{
DownloadSource
.
DEFAULT
:
"google/paligemma-3b-pt-448"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/paligemma-3b-pt-448"
,
},
"PaliGemma-3B-pt-896"
:
{
DownloadSource
.
DEFAULT
:
"google/paligemma-3b-pt-896"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/paligemma-3b-pt-896"
,
},
"PaliGemma-3B-mix-224"
:
{
DownloadSource
.
DEFAULT
:
"google/paligemma-3b-mix-224"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/paligemma-3b-mix-224"
,
},
"PaliGemma-3B-mix-448"
:
{
DownloadSource
.
DEFAULT
:
"google/paligemma-3b-mix-448"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/paligemma-3b-mix-448"
,
},
},
vision
=
True
,
)
register_model_group
(
models
=
{
"Phi-1.5-1.3B"
:
{
DownloadSource
.
DEFAULT
:
"microsoft/phi-1_5"
,
DownloadSource
.
MODELSCOPE
:
"allspace/PHI_1-5"
,
},
"Phi-2-2.7B"
:
{
DownloadSource
.
DEFAULT
:
"microsoft/phi-2"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/phi-2"
,
},
}
)
register_model_group
(
models
=
{
"Phi3-4B-4k-Chat"
:
{
DownloadSource
.
DEFAULT
:
"microsoft/Phi-3-mini-4k-instruct"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Phi-3-mini-4k-instruct"
,
},
"Phi3-4B-128k-Chat"
:
{
DownloadSource
.
DEFAULT
:
"microsoft/Phi-3-mini-128k-instruct"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Phi-3-mini-128k-instruct"
,
},
"Phi3-7B-8k-Chat"
:
{
DownloadSource
.
DEFAULT
:
"microsoft/Phi-3-small-8k-instruct"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Phi-3-small-8k-instruct"
,
},
"Phi3-7B-128k-Chat"
:
{
DownloadSource
.
DEFAULT
:
"microsoft/Phi-3-small-128k-instruct"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Phi-3-small-128k-instruct"
,
},
"Phi3-14B-8k-Chat"
:
{
DownloadSource
.
DEFAULT
:
"microsoft/Phi-3-medium-4k-instruct"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Phi-3-medium-4k-instruct"
,
},
"Phi3-14B-128k-Chat"
:
{
DownloadSource
.
DEFAULT
:
"microsoft/Phi-3-medium-128k-instruct"
,
DownloadSource
.
MODELSCOPE
:
"LLM-Research/Phi-3-medium-128k-instruct"
,
},
},
template
=
"phi"
,
)
register_model_group
(
models
=
{
"Qwen-1.8B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-1_8B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-1_8B"
,
},
"Qwen-7B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-7B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-7B"
,
},
"Qwen-14B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-14B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-14B"
,
},
"Qwen-72B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-72B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-72B"
,
},
"Qwen-1.8B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-1_8B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-1_8B-Chat"
,
},
"Qwen-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-7B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-7B-Chat"
,
},
"Qwen-14B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-14B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-14B-Chat"
,
},
"Qwen-72B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-72B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-72B-Chat"
,
},
"Qwen-1.8B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-1_8B-Chat-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-1_8B-Chat-Int8"
,
},
"Qwen-1.8B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-1_8B-Chat-Int4"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-1_8B-Chat-Int4"
,
},
"Qwen-7B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-7B-Chat-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-7B-Chat-Int8"
,
},
"Qwen-7B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-7B-Chat-Int4"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-7B-Chat-Int4"
,
},
"Qwen-14B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-14B-Chat-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-14B-Chat-Int8"
,
},
"Qwen-14B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-14B-Chat-Int4"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-14B-Chat-Int4"
,
},
"Qwen-72B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-72B-Chat-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-72B-Chat-Int8"
,
},
"Qwen-72B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen-72B-Chat-Int4"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen-72B-Chat-Int4"
,
},
},
template
=
"qwen"
,
)
register_model_group
(
models
=
{
"Qwen1.5-0.5B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-0.5B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-0.5B"
,
},
"Qwen1.5-1.8B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-1.8B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-1.8B"
,
},
"Qwen1.5-4B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-4B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-4B"
,
},
"Qwen1.5-7B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-7B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-7B"
,
},
"Qwen1.5-14B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-14B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-14B"
,
},
"Qwen1.5-32B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-32B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-32B"
,
},
"Qwen1.5-72B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-72B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-72B"
,
},
"Qwen1.5-110B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-110B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-110B"
,
},
"Qwen1.5-MoE-A2.7B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-MoE-A2.7B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-MoE-A2.7B"
,
},
"Qwen1.5-Code-7B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/CodeQwen1.5-7B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/CodeQwen1.5-7B"
,
},
"Qwen1.5-0.5B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-0.5B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-0.5B-Chat"
,
},
"Qwen1.5-1.8B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-1.8B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-1.8B-Chat"
,
},
"Qwen1.5-4B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-4B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-4B-Chat"
,
},
"Qwen1.5-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-7B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-7B-Chat"
,
},
"Qwen1.5-14B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-14B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-14B-Chat"
,
},
"Qwen1.5-32B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-32B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-32B-Chat"
,
},
"Qwen1.5-72B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-72B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-72B-Chat"
,
},
"Qwen1.5-110B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-110B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-110B-Chat"
,
},
"Qwen1.5-MoE-A2.7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-MoE-A2.7B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-MoE-A2.7B-Chat"
,
},
"Qwen1.5-Code-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/CodeQwen1.5-7B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"qwen/CodeQwen1.5-7B-Chat"
,
},
"Qwen1.5-0.5B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8"
,
},
"Qwen1.5-0.5B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-0.5B-Chat-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-0.5B-Chat-AWQ"
,
},
"Qwen1.5-1.8B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8"
,
},
"Qwen1.5-1.8B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-1.8B-Chat-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-1.8B-Chat-AWQ"
,
},
"Qwen1.5-4B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-4B-Chat-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-4B-Chat-GPTQ-Int8"
,
},
"Qwen1.5-4B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-4B-Chat-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-4B-Chat-AWQ"
,
},
"Qwen1.5-7B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-7B-Chat-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-7B-Chat-GPTQ-Int8"
,
},
"Qwen1.5-7B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-7B-Chat-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-7B-Chat-AWQ"
,
},
"Qwen1.5-14B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-14B-Chat-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-14B-Chat-GPTQ-Int8"
,
},
"Qwen1.5-14B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-14B-Chat-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-14B-Chat-AWQ"
,
},
"Qwen1.5-32B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-32B-Chat-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-32B-Chat-AWQ"
,
},
"Qwen1.5-72B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-72B-Chat-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-72B-Chat-GPTQ-Int8"
,
},
"Qwen1.5-72B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-72B-Chat-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-72B-Chat-AWQ"
,
},
"Qwen1.5-110B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-110B-Chat-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-110B-Chat-AWQ"
,
},
"Qwen1.5-MoE-A2.7B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4"
,
},
"Qwen1.5-Code-7B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/CodeQwen1.5-7B-Chat-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/CodeQwen1.5-7B-Chat-AWQ"
,
},
},
template
=
"qwen"
,
)
register_model_group
(
models
=
{
"Qwen2-0.5B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-0.5B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-0.5B"
,
},
"Qwen2-1.5B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-1.5B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-1.5B"
,
},
"Qwen2-7B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-7B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-7B"
,
},
"Qwen2-72B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-72B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-72B"
,
},
"Qwen2-MoE-57B"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-57B-A14B"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-57B-A14B"
,
},
"Qwen2-0.5B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-0.5B-Instruct"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-0.5B-Instruct"
,
},
"Qwen2-1.5B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-1.5B-Instruct"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-1.5B-Instruct"
,
},
"Qwen2-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-7B-Instruct"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-7B-Instruct"
,
},
"Qwen2-72B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-72B-Instruct"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-72B-Instruct"
,
},
"Qwen2-MoE-57B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-57B-A14B-Instruct"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-57B-A14B-Instruct"
,
},
"Qwen2-0.5B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-0.5B-Instruct-GPTQ-Int8"
,
},
"Qwen2-0.5B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-0.5B-Instruct-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-0.5B-Instruct-AWQ"
,
},
"Qwen2-1.5B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-1.5B-Instruct-GPTQ-Int8"
,
},
"Qwen2-1.5B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-1.5B-Instruct-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-1.5B-Instruct-AWQ"
,
},
"Qwen2-7B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-7B-Instruct-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-7B-Instruct-GPTQ-Int8"
,
},
"Qwen2-7B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-7B-Instruct-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-7B-Instruct-AWQ"
,
},
"Qwen2-72B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-72B-Instruct-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-72B-Instruct-GPTQ-Int8"
,
},
"Qwen2-72B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-72B-Instruct-AWQ"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-72B-Instruct-AWQ"
,
},
"Qwen2-MoE-57B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4"
,
DownloadSource
.
MODELSCOPE
:
"qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4"
,
},
},
template
=
"qwen"
,
)
register_model_group
(
models
=
{
"SOLAR-10.7B"
:
{
DownloadSource
.
DEFAULT
:
"upstage/SOLAR-10.7B-v1.0"
,
},
"SOLAR-10.7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"upstage/SOLAR-10.7B-Instruct-v1.0"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/SOLAR-10.7B-Instruct-v1.0"
,
},
},
template
=
"solar"
,
)
register_model_group
(
models
=
{
"Skywork-13B-Base"
:
{
DownloadSource
.
DEFAULT
:
"Skywork/Skywork-13B-base"
,
DownloadSource
.
MODELSCOPE
:
"skywork/Skywork-13B-base"
,
}
}
)
register_model_group
(
models
=
{
"StarCoder2-3B"
:
{
DownloadSource
.
DEFAULT
:
"bigcode/starcoder2-3b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/starcoder2-3b"
,
},
"StarCoder2-7B"
:
{
DownloadSource
.
DEFAULT
:
"bigcode/starcoder2-7b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/starcoder2-7b"
,
},
"StarCoder2-15B"
:
{
DownloadSource
.
DEFAULT
:
"bigcode/starcoder2-15b"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/starcoder2-15b"
,
},
}
)
register_model_group
(
models
=
{
"TeleChat-1B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Tele-AI/TeleChat-1B"
,
DownloadSource
.
MODELSCOPE
:
"TeleAI/TeleChat-1B"
,
},
"TeleChat-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Tele-AI/telechat-7B"
,
DownloadSource
.
MODELSCOPE
:
"TeleAI/telechat-7B"
,
},
"TeleChat-12B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Tele-AI/TeleChat-12B"
,
DownloadSource
.
MODELSCOPE
:
"TeleAI/TeleChat-12B"
,
},
"TeleChat-12B-v2-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Tele-AI/TeleChat-12B-v2"
,
DownloadSource
.
MODELSCOPE
:
"TeleAI/TeleChat-12B-v2"
,
},
},
template
=
"telechat"
,
)
register_model_group
(
models
=
{
"Vicuna1.5-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"lmsys/vicuna-7b-v1.5"
,
DownloadSource
.
MODELSCOPE
:
"Xorbits/vicuna-7b-v1.5"
,
},
"Vicuna1.5-13B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"lmsys/vicuna-13b-v1.5"
,
DownloadSource
.
MODELSCOPE
:
"Xorbits/vicuna-13b-v1.5"
,
},
},
template
=
"vicuna"
,
)
register_model_group
(
models
=
{
"XuanYuan-6B"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan-6B"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan-6B"
,
},
"XuanYuan-70B"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan-70B"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan-70B"
,
},
"XuanYuan-2-70B"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan2-70B"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan2-70B"
,
},
"XuanYuan-6B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan-6B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan-6B-Chat"
,
},
"XuanYuan-70B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan-70B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan-70B-Chat"
,
},
"XuanYuan-2-70B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan2-70B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan2-70B-Chat"
,
},
"XuanYuan-6B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan-6B-Chat-8bit"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan-6B-Chat-8bit"
,
},
"XuanYuan-6B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan-6B-Chat-4bit"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan-6B-Chat-4bit"
,
},
"XuanYuan-70B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan-70B-Chat-8bit"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan-70B-Chat-8bit"
,
},
"XuanYuan-70B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan-70B-Chat-4bit"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan-70B-Chat-4bit"
,
},
"XuanYuan-2-70B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan2-70B-Chat-8bit"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan2-70B-Chat-8bit"
,
},
"XuanYuan-2-70B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"Duxiaoman-DI/XuanYuan2-70B-Chat-4bit"
,
DownloadSource
.
MODELSCOPE
:
"Duxiaoman-DI/XuanYuan2-70B-Chat-4bit"
,
},
},
template
=
"xuanyuan"
,
)
register_model_group
(
models
=
{
"XVERSE-7B"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-7B"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-7B"
,
},
"XVERSE-13B"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-13B"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-13B"
,
},
"XVERSE-65B"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-65B"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-65B"
,
},
"XVERSE-65B-2"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-65B-2"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-65B-2"
,
},
"XVERSE-7B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-7B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-7B-Chat"
,
},
"XVERSE-13B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-13B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-13B-Chat"
,
},
"XVERSE-65B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-65B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-65B-Chat"
,
},
"XVERSE-MoE-A4.2B"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-MoE-A4.2B"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-MoE-A4.2B"
,
},
"XVERSE-7B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-7B-Chat-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-7B-Chat-GPTQ-Int8"
,
},
"XVERSE-7B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-7B-Chat-GPTQ-Int4"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-7B-Chat-GPTQ-Int4"
,
},
"XVERSE-13B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-13B-Chat-GPTQ-Int8"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-13B-Chat-GPTQ-Int8"
,
},
"XVERSE-13B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-13B-Chat-GPTQ-Int4"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-13B-Chat-GPTQ-Int4"
,
},
"XVERSE-65B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"xverse/XVERSE-65B-Chat-GPTQ-Int4"
,
DownloadSource
.
MODELSCOPE
:
"xverse/XVERSE-65B-Chat-GPTQ-Int4"
,
},
},
template
=
"xverse"
,
)
register_model_group
(
models
=
{
"Yayi-7B"
:
{
DownloadSource
.
DEFAULT
:
"wenge-research/yayi-7b-llama2"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/yayi-7b-llama2"
,
},
"Yayi-13B"
:
{
DownloadSource
.
DEFAULT
:
"wenge-research/yayi-13b-llama2"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/yayi-13b-llama2"
,
},
},
template
=
"yayi"
,
)
register_model_group
(
models
=
{
"Yi-6B"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-6B"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-6B"
,
},
"Yi-9B"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-9B"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-9B"
,
},
"Yi-34B"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-34B"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-34B"
,
},
"Yi-6B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-6B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-6B-Chat"
,
},
"Yi-34B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-34B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-34B-Chat"
,
},
"Yi-6B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-6B-Chat-8bits"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-6B-Chat-8bits"
,
},
"Yi-6B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-6B-Chat-4bits"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-6B-Chat-4bits"
,
},
"Yi-34B-int8-Chat"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-34B-Chat-8bits"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-34B-Chat-8bits"
,
},
"Yi-34B-int4-Chat"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-34B-Chat-4bits"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-34B-Chat-4bits"
,
},
"Yi-1.5-6B"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-1.5-6B"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-1.5-6B"
,
},
"Yi-1.5-9B"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-1.5-9B"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-1.5-9B"
,
},
"Yi-1.5-34B"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-1.5-34B"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-1.5-34B"
,
},
"Yi-1.5-6B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-1.5-6B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-1.5-6B-Chat"
,
},
"Yi-1.5-9B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-1.5-9B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-1.5-9B-Chat"
,
},
"Yi-1.5-34B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"01-ai/Yi-1.5-34B-Chat"
,
DownloadSource
.
MODELSCOPE
:
"01ai/Yi-1.5-34B-Chat"
,
},
},
template
=
"yi"
,
)
register_model_group
(
models
=
{
"YiVL-6B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"BUAADreamer/Yi-VL-6B-hf"
,
},
"YiVL-34B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"BUAADreamer/Yi-VL-34B-hf"
,
},
},
template
=
"yi_vl"
,
vision
=
True
,
)
register_model_group
(
models
=
{
"Yuan2-2B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"IEITYuan/Yuan2-2B-hf"
,
DownloadSource
.
MODELSCOPE
:
"YuanLLM/Yuan2.0-2B-hf"
,
},
"Yuan2-51B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"IEITYuan/Yuan2-51B-hf"
,
DownloadSource
.
MODELSCOPE
:
"YuanLLM/Yuan2.0-51B-hf"
,
},
"Yuan2-102B-Chat"
:
{
DownloadSource
.
DEFAULT
:
"IEITYuan/Yuan2-102B-hf"
,
DownloadSource
.
MODELSCOPE
:
"YuanLLM/Yuan2.0-102B-hf"
,
},
},
template
=
"yuan"
,
)
register_model_group
(
models
=
{
"Zephyr-7B-Alpha-Chat"
:
{
DownloadSource
.
DEFAULT
:
"HuggingFaceH4/zephyr-7b-alpha"
,
DownloadSource
.
MODELSCOPE
:
"AI-ModelScope/zephyr-7b-alpha"
,
},
"Zephyr-7B-Beta-Chat"
:
{
DownloadSource
.
DEFAULT
:
"HuggingFaceH4/zephyr-7b-beta"
,
DownloadSource
.
MODELSCOPE
:
"modelscope/zephyr-7b-beta"
,
},
"Zephyr-141B-ORPO-Chat"
:
{
DownloadSource
.
DEFAULT
:
"HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1"
,
},
},
template
=
"zephyr"
,
)
LLaMA-Factory/src/llamafactory/extras/env.py
0 → 100644
View file @
032b90a1
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
platform
import
accelerate
import
datasets
import
peft
import
torch
import
transformers
import
trl
from
transformers.utils
import
is_torch_cuda_available
,
is_torch_npu_available
VERSION
=
"0.8.4.dev0"
def
print_env
()
->
None
:
info
=
{
"`llamafactory` version"
:
VERSION
,
"Platform"
:
platform
.
platform
(),
"Python version"
:
platform
.
python_version
(),
"PyTorch version"
:
torch
.
__version__
,
"Transformers version"
:
transformers
.
__version__
,
"Datasets version"
:
datasets
.
__version__
,
"Accelerate version"
:
accelerate
.
__version__
,
"PEFT version"
:
peft
.
__version__
,
"TRL version"
:
trl
.
__version__
,
}
if
is_torch_cuda_available
():
info
[
"PyTorch version"
]
+=
" (GPU)"
info
[
"GPU type"
]
=
torch
.
cuda
.
get_device_name
()
if
is_torch_npu_available
():
info
[
"PyTorch version"
]
+=
" (NPU)"
info
[
"NPU type"
]
=
torch
.
npu
.
get_device_name
()
info
[
"CANN version"
]
=
torch
.
version
.
cann
try
:
import
deepspeed
# type: ignore
info
[
"DeepSpeed version"
]
=
deepspeed
.
__version__
except
Exception
:
pass
try
:
import
bitsandbytes
info
[
"Bitsandbytes version"
]
=
bitsandbytes
.
__version__
except
Exception
:
pass
try
:
import
vllm
info
[
"vLLM version"
]
=
vllm
.
__version__
except
Exception
:
pass
print
(
"
\n
"
+
"
\n
"
.
join
([
"- {}: {}"
.
format
(
key
,
value
)
for
key
,
value
in
info
.
items
()])
+
"
\n
"
)
LLaMA-Factory/src/llamafactory/extras/logging.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
logging
import
os
import
sys
from
concurrent.futures
import
ThreadPoolExecutor
from
.constants
import
RUNNING_LOG
class
LoggerHandler
(
logging
.
Handler
):
r
"""
Logger handler used in Web UI.
"""
def
__init__
(
self
,
output_dir
:
str
)
->
None
:
super
().
__init__
()
formatter
=
logging
.
Formatter
(
fmt
=
"%(asctime)s - %(levelname)s - %(name)s - %(message)s"
,
datefmt
=
"%m/%d/%Y %H:%M:%S"
)
self
.
setLevel
(
logging
.
INFO
)
self
.
setFormatter
(
formatter
)
os
.
makedirs
(
output_dir
,
exist_ok
=
True
)
self
.
running_log
=
os
.
path
.
join
(
output_dir
,
RUNNING_LOG
)
if
os
.
path
.
exists
(
self
.
running_log
):
os
.
remove
(
self
.
running_log
)
self
.
thread_pool
=
ThreadPoolExecutor
(
max_workers
=
1
)
def
_write_log
(
self
,
log_entry
:
str
)
->
None
:
with
open
(
self
.
running_log
,
"a"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
log_entry
+
"
\n\n
"
)
def
emit
(
self
,
record
)
->
None
:
if
record
.
name
==
"httpx"
:
return
log_entry
=
self
.
format
(
record
)
self
.
thread_pool
.
submit
(
self
.
_write_log
,
log_entry
)
def
close
(
self
)
->
None
:
self
.
thread_pool
.
shutdown
(
wait
=
True
)
return
super
().
close
()
def
get_logger
(
name
:
str
)
->
logging
.
Logger
:
r
"""
Gets a standard logger with a stream hander to stdout.
"""
formatter
=
logging
.
Formatter
(
fmt
=
"%(asctime)s - %(levelname)s - %(name)s - %(message)s"
,
datefmt
=
"%m/%d/%Y %H:%M:%S"
)
handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
handler
.
setFormatter
(
formatter
)
logger
=
logging
.
getLogger
(
name
)
logger
.
setLevel
(
logging
.
INFO
)
logger
.
addHandler
(
handler
)
return
logger
def
reset_logging
()
->
None
:
r
"""
Removes basic config of root logger. (unused in script)
"""
root
=
logging
.
getLogger
()
list
(
map
(
root
.
removeHandler
,
root
.
handlers
))
list
(
map
(
root
.
removeFilter
,
root
.
filters
))
LLaMA-Factory/src/llamafactory/extras/misc.py
0 → 100644
View file @
032b90a1
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's PEFT library.
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
gc
import
os
from
typing
import
TYPE_CHECKING
,
Tuple
,
Union
import
torch
import
transformers.dynamic_module_utils
from
transformers
import
InfNanRemoveLogitsProcessor
,
LogitsProcessorList
from
transformers.dynamic_module_utils
import
get_relative_imports
from
transformers.utils
import
(
is_torch_bf16_gpu_available
,
is_torch_cuda_available
,
is_torch_mps_available
,
is_torch_npu_available
,
is_torch_xpu_available
,
)
from
transformers.utils.versions
import
require_version
from
.logging
import
get_logger
_is_fp16_available
=
is_torch_npu_available
()
or
is_torch_cuda_available
()
try
:
_is_bf16_available
=
is_torch_bf16_gpu_available
()
except
Exception
:
_is_bf16_available
=
False
if
TYPE_CHECKING
:
from
numpy.typing
import
NDArray
from
..hparams
import
ModelArguments
logger
=
get_logger
(
__name__
)
class
AverageMeter
:
r
"""
Computes and stores the average and current value.
"""
def
__init__
(
self
):
self
.
reset
()
def
reset
(
self
):
self
.
val
=
0
self
.
avg
=
0
self
.
sum
=
0
self
.
count
=
0
def
update
(
self
,
val
,
n
=
1
):
self
.
val
=
val
self
.
sum
+=
val
*
n
self
.
count
+=
n
self
.
avg
=
self
.
sum
/
self
.
count
def
check_dependencies
()
->
None
:
r
"""
Checks the version of the required packages.
"""
if
os
.
environ
.
get
(
"DISABLE_VERSION_CHECK"
,
"0"
).
lower
()
in
[
"true"
,
"1"
]:
logger
.
warning
(
"Version checking has been disabled, may lead to unexpected behaviors."
)
else
:
require_version
(
"transformers>=4.41.2"
,
"To fix: pip install transformers>=4.41.2"
)
require_version
(
"datasets>=2.16.0"
,
"To fix: pip install datasets>=2.16.0"
)
require_version
(
"accelerate>=0.30.1"
,
"To fix: pip install accelerate>=0.30.1"
)
require_version
(
"peft>=0.11.1"
,
"To fix: pip install peft>=0.11.1"
)
require_version
(
"trl>=0.8.6"
,
"To fix: pip install trl>=0.8.6"
)
def
count_parameters
(
model
:
"torch.nn.Module"
)
->
Tuple
[
int
,
int
]:
r
"""
Returns the number of trainable parameters and number of all parameters in the model.
"""
trainable_params
,
all_param
=
0
,
0
for
param
in
model
.
parameters
():
num_params
=
param
.
numel
()
# if using DS Zero 3 and the weights are initialized empty
if
num_params
==
0
and
hasattr
(
param
,
"ds_numel"
):
num_params
=
param
.
ds_numel
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
if
param
.
__class__
.
__name__
==
"Params4bit"
:
if
hasattr
(
param
,
"quant_storage"
)
and
hasattr
(
param
.
quant_storage
,
"itemsize"
):
num_bytes
=
param
.
quant_storage
.
itemsize
elif
hasattr
(
param
,
"element_size"
):
# for older pytorch version
num_bytes
=
param
.
element_size
()
else
:
num_bytes
=
1
num_params
=
num_params
*
2
*
num_bytes
all_param
+=
num_params
if
param
.
requires_grad
:
trainable_params
+=
num_params
return
trainable_params
,
all_param
def
get_current_device
()
->
"torch.device"
:
r
"""
Gets the current available device.
"""
if
is_torch_xpu_available
():
device
=
"xpu:{}"
.
format
(
os
.
environ
.
get
(
"LOCAL_RANK"
,
"0"
))
elif
is_torch_npu_available
():
device
=
"npu:{}"
.
format
(
os
.
environ
.
get
(
"LOCAL_RANK"
,
"0"
))
elif
is_torch_mps_available
():
device
=
"mps:{}"
.
format
(
os
.
environ
.
get
(
"LOCAL_RANK"
,
"0"
))
elif
is_torch_cuda_available
():
device
=
"cuda:{}"
.
format
(
os
.
environ
.
get
(
"LOCAL_RANK"
,
"0"
))
else
:
device
=
"cpu"
return
torch
.
device
(
device
)
def
get_device_count
()
->
int
:
r
"""
Gets the number of available GPU or NPU devices.
"""
if
is_torch_npu_available
():
return
torch
.
npu
.
device_count
()
elif
is_torch_cuda_available
():
return
torch
.
cuda
.
device_count
()
else
:
return
0
def
get_logits_processor
()
->
"LogitsProcessorList"
:
r
"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor
=
LogitsProcessorList
()
logits_processor
.
append
(
InfNanRemoveLogitsProcessor
())
return
logits_processor
def
has_tokenized_data
(
path
:
"os.PathLike"
)
->
bool
:
r
"""
Checks if the path has a tokenized dataset.
"""
return
os
.
path
.
isdir
(
path
)
and
len
(
os
.
listdir
(
path
))
>
0
def
infer_optim_dtype
(
model_dtype
:
"torch.dtype"
)
->
"torch.dtype"
:
r
"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
if
_is_bf16_available
and
model_dtype
==
torch
.
bfloat16
:
return
torch
.
bfloat16
elif
_is_fp16_available
:
return
torch
.
float16
else
:
return
torch
.
float32
def
is_gpu_or_npu_available
()
->
bool
:
r
"""
Checks if the GPU or NPU is available.
"""
return
is_torch_npu_available
()
or
is_torch_cuda_available
()
def
numpify
(
inputs
:
Union
[
"NDArray"
,
"torch.Tensor"
])
->
"NDArray"
:
if
isinstance
(
inputs
,
torch
.
Tensor
):
inputs
=
inputs
.
cpu
()
if
inputs
.
dtype
==
torch
.
bfloat16
:
# numpy does not support bfloat16 until 1.21.4
inputs
=
inputs
.
to
(
torch
.
float32
)
inputs
=
inputs
.
numpy
()
return
inputs
def
skip_check_imports
()
->
None
:
if
os
.
environ
.
get
(
"FORCE_CHECK_IMPORTS"
,
"0"
).
lower
()
not
in
[
"true"
,
"1"
]:
transformers
.
dynamic_module_utils
.
check_imports
=
get_relative_imports
def
torch_gc
()
->
None
:
r
"""
Collects GPU or NPU memory.
"""
gc
.
collect
()
if
is_torch_xpu_available
():
torch
.
xpu
.
empty_cache
()
elif
is_torch_npu_available
():
torch
.
npu
.
empty_cache
()
elif
is_torch_mps_available
():
torch
.
mps
.
empty_cache
()
elif
is_torch_cuda_available
():
torch
.
cuda
.
empty_cache
()
def
try_download_model_from_ms
(
model_args
:
"ModelArguments"
)
->
str
:
if
not
use_modelscope
()
or
os
.
path
.
exists
(
model_args
.
model_name_or_path
):
return
model_args
.
model_name_or_path
try
:
from
modelscope
import
snapshot_download
revision
=
"master"
if
model_args
.
model_revision
==
"main"
else
model_args
.
model_revision
return
snapshot_download
(
model_args
.
model_name_or_path
,
revision
=
revision
,
cache_dir
=
model_args
.
cache_dir
)
except
ImportError
:
raise
ImportError
(
"Please install modelscope via `pip install modelscope -U`"
)
def
use_modelscope
()
->
bool
:
return
os
.
environ
.
get
(
"USE_MODELSCOPE_HUB"
,
"0"
).
lower
()
in
[
"true"
,
"1"
]
LLaMA-Factory/src/llamafactory/extras/packages.py
0 → 100644
View file @
032b90a1
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
importlib.metadata
import
importlib.util
from
functools
import
lru_cache
from
typing
import
TYPE_CHECKING
from
packaging
import
version
if
TYPE_CHECKING
:
from
packaging.version
import
Version
def
_is_package_available
(
name
:
str
)
->
bool
:
return
importlib
.
util
.
find_spec
(
name
)
is
not
None
def
_get_package_version
(
name
:
str
)
->
"Version"
:
try
:
return
version
.
parse
(
importlib
.
metadata
.
version
(
name
))
except
Exception
:
return
version
.
parse
(
"0.0.0"
)
def
is_fastapi_available
():
return
_is_package_available
(
"fastapi"
)
def
is_galore_available
():
return
_is_package_available
(
"galore_torch"
)
def
is_gradio_available
():
return
_is_package_available
(
"gradio"
)
def
is_matplotlib_available
():
return
_is_package_available
(
"matplotlib"
)
def
is_pillow_available
():
return
_is_package_available
(
"PIL"
)
def
is_requests_available
():
return
_is_package_available
(
"requests"
)
def
is_rouge_available
():
return
_is_package_available
(
"rouge_chinese"
)
def
is_starlette_available
():
return
_is_package_available
(
"sse_starlette"
)
def
is_uvicorn_available
():
return
_is_package_available
(
"uvicorn"
)
def
is_vllm_available
():
return
_is_package_available
(
"vllm"
)
@
lru_cache
def
is_vllm_version_greater_than_0_5
():
return
_get_package_version
(
"vllm"
)
>=
version
.
parse
(
"0.5.0"
)
@
lru_cache
def
is_vllm_version_greater_than_0_5_1
():
return
_get_package_version
(
"vllm"
)
>=
version
.
parse
(
"0.5.1"
)
LLaMA-Factory/src/llamafactory/extras/ploting.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
import
math
import
os
from
typing
import
Any
,
Dict
,
List
from
transformers.trainer
import
TRAINER_STATE_NAME
from
.logging
import
get_logger
from
.packages
import
is_matplotlib_available
if
is_matplotlib_available
():
import
matplotlib.figure
import
matplotlib.pyplot
as
plt
logger
=
get_logger
(
__name__
)
def
smooth
(
scalars
:
List
[
float
])
->
List
[
float
]:
r
"""
EMA implementation according to TensorBoard.
"""
if
len
(
scalars
)
==
0
:
return
[]
last
=
scalars
[
0
]
smoothed
=
[]
weight
=
1.8
*
(
1
/
(
1
+
math
.
exp
(
-
0.05
*
len
(
scalars
)))
-
0.5
)
# a sigmoid function
for
next_val
in
scalars
:
smoothed_val
=
last
*
weight
+
(
1
-
weight
)
*
next_val
smoothed
.
append
(
smoothed_val
)
last
=
smoothed_val
return
smoothed
def
gen_loss_plot
(
trainer_log
:
List
[
Dict
[
str
,
Any
]])
->
"matplotlib.figure.Figure"
:
r
"""
Plots loss curves in LlamaBoard.
"""
plt
.
close
(
"all"
)
plt
.
switch_backend
(
"agg"
)
fig
=
plt
.
figure
()
ax
=
fig
.
add_subplot
(
111
)
steps
,
losses
=
[],
[]
for
log
in
trainer_log
:
if
log
.
get
(
"loss"
,
None
):
steps
.
append
(
log
[
"current_steps"
])
losses
.
append
(
log
[
"loss"
])
ax
.
plot
(
steps
,
losses
,
color
=
"#1f77b4"
,
alpha
=
0.4
,
label
=
"original"
)
ax
.
plot
(
steps
,
smooth
(
losses
),
color
=
"#1f77b4"
,
label
=
"smoothed"
)
ax
.
legend
()
ax
.
set_xlabel
(
"step"
)
ax
.
set_ylabel
(
"loss"
)
return
fig
def
plot_loss
(
save_dictionary
:
os
.
PathLike
,
keys
:
List
[
str
]
=
[
"loss"
])
->
None
:
r
"""
Plots loss curves and saves the image.
"""
plt
.
switch_backend
(
"agg"
)
with
open
(
os
.
path
.
join
(
save_dictionary
,
TRAINER_STATE_NAME
),
"r"
,
encoding
=
"utf-8"
)
as
f
:
data
=
json
.
load
(
f
)
for
key
in
keys
:
steps
,
metrics
=
[],
[]
for
i
in
range
(
len
(
data
[
"log_history"
])):
if
key
in
data
[
"log_history"
][
i
]:
steps
.
append
(
data
[
"log_history"
][
i
][
"step"
])
metrics
.
append
(
data
[
"log_history"
][
i
][
key
])
if
len
(
metrics
)
==
0
:
logger
.
warning
(
f
"No metric
{
key
}
to plot."
)
continue
plt
.
figure
()
plt
.
plot
(
steps
,
metrics
,
color
=
"#1f77b4"
,
alpha
=
0.4
,
label
=
"original"
)
plt
.
plot
(
steps
,
smooth
(
metrics
),
color
=
"#1f77b4"
,
label
=
"smoothed"
)
plt
.
title
(
"training {} of {}"
.
format
(
key
,
save_dictionary
))
plt
.
xlabel
(
"step"
)
plt
.
ylabel
(
key
)
plt
.
legend
()
figure_path
=
os
.
path
.
join
(
save_dictionary
,
"training_{}.png"
.
format
(
key
.
replace
(
"/"
,
"_"
)))
plt
.
savefig
(
figure_path
,
format
=
"png"
,
dpi
=
100
)
print
(
"Figure saved at:"
,
figure_path
)
LLaMA-Factory/src/llamafactory/hparams/__init__.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.data_args
import
DataArguments
from
.evaluation_args
import
EvaluationArguments
from
.finetuning_args
import
FinetuningArguments
from
.generating_args
import
GeneratingArguments
from
.model_args
import
ModelArguments
from
.parser
import
get_eval_args
,
get_infer_args
,
get_train_args
__all__
=
[
"DataArguments"
,
"EvaluationArguments"
,
"FinetuningArguments"
,
"GeneratingArguments"
,
"ModelArguments"
,
"get_eval_args"
,
"get_infer_args"
,
"get_train_args"
,
]
LLaMA-Factory/src/llamafactory/hparams/data_args.py
0 → 100644
View file @
032b90a1
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
,
field
from
typing
import
Literal
,
Optional
@
dataclass
class
DataArguments
:
r
"""
Arguments pertaining to what data we are going to input our model for training and evaluation.
"""
template
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Which template to use for constructing prompts in training and inference."
},
)
dataset
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The name of dataset(s) to use for training. Use commas to separate multiple datasets."
},
)
eval_dataset
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."
},
)
dataset_dir
:
str
=
field
(
default
=
"data"
,
metadata
=
{
"help"
:
"Path to the folder containing the datasets."
},
)
cutoff_len
:
int
=
field
(
default
=
1024
,
metadata
=
{
"help"
:
"The cutoff length of the tokenized inputs in the dataset."
},
)
train_on_prompt
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to disable the mask on the prompt."
},
)
mask_history
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to mask the history and train on the last turn only."
},
)
streaming
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Enable dataset streaming."
},
)
buffer_size
:
int
=
field
(
default
=
16384
,
metadata
=
{
"help"
:
"Size of the buffer to randomly sample examples from in dataset streaming."
},
)
mix_strategy
:
Literal
[
"concat"
,
"interleave_under"
,
"interleave_over"
]
=
field
(
default
=
"concat"
,
metadata
=
{
"help"
:
"Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."
},
)
interleave_probs
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Probabilities to sample data from datasets. Use commas to separate multiple datasets."
},
)
overwrite_cache
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Overwrite the cached training and evaluation sets."
},
)
preprocessing_num_workers
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The number of processes to use for the pre-processing."
},
)
max_samples
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"For debugging purposes, truncate the number of examples for each dataset."
},
)
eval_num_beams
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Number of beams to use for evaluation. This argument will be passed to `model.generate`"
},
)
ignore_pad_token_for_loss
:
bool
=
field
(
default
=
True
,
metadata
=
{
"help"
:
"Whether or not to ignore the tokens corresponding to the pad label in loss computation."
},
)
val_size
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"Size of the development set, should be an integer or a float in range `[0,1)`."
},
)
packing
:
Optional
[
bool
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Enable sequences packing in training. Will automatically enable in pre-training."
},
)
neat_packing
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Enable sequence packing without cross-attention."
},
)
tool_format
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Tool format to use for constructing function calling examples."
},
)
tokenized_path
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Path to save or load the tokenized datasets."
},
)
def
__post_init__
(
self
):
def
split_arg
(
arg
):
if
isinstance
(
arg
,
str
):
return
[
item
.
strip
()
for
item
in
arg
.
split
(
","
)]
return
arg
self
.
dataset
=
split_arg
(
self
.
dataset
)
self
.
eval_dataset
=
split_arg
(
self
.
eval_dataset
)
if
self
.
dataset
is
None
and
self
.
val_size
>
1e-6
:
raise
ValueError
(
"Cannot specify `val_size` if `dataset` is None."
)
if
self
.
eval_dataset
is
not
None
and
self
.
val_size
>
1e-6
:
raise
ValueError
(
"Cannot specify `val_size` if `eval_dataset` is not None."
)
if
self
.
interleave_probs
is
not
None
:
if
self
.
mix_strategy
==
"concat"
:
raise
ValueError
(
"`interleave_probs` is only valid for interleaved mixing."
)
self
.
interleave_probs
=
list
(
map
(
float
,
split_arg
(
self
.
interleave_probs
)))
if
self
.
dataset
is
not
None
and
len
(
self
.
dataset
)
!=
len
(
self
.
interleave_probs
):
raise
ValueError
(
"The length of dataset and interleave probs should be identical."
)
if
self
.
eval_dataset
is
not
None
and
len
(
self
.
eval_dataset
)
!=
len
(
self
.
interleave_probs
):
raise
ValueError
(
"The length of eval dataset and interleave probs should be identical."
)
if
self
.
streaming
and
self
.
val_size
>
1e-6
and
self
.
val_size
<
1
:
raise
ValueError
(
"Streaming mode should have an integer val size."
)
if
self
.
streaming
and
self
.
max_samples
is
not
None
:
raise
ValueError
(
"`max_samples` is incompatible with `streaming`."
)
LLaMA-Factory/src/llamafactory/hparams/evaluation_args.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
from
dataclasses
import
dataclass
,
field
from
typing
import
Literal
,
Optional
from
datasets
import
DownloadMode
@
dataclass
class
EvaluationArguments
:
r
"""
Arguments pertaining to specify the evaluation parameters.
"""
task
:
str
=
field
(
metadata
=
{
"help"
:
"Name of the evaluation task."
},
)
task_dir
:
str
=
field
(
default
=
"evaluation"
,
metadata
=
{
"help"
:
"Path to the folder containing the evaluation datasets."
},
)
batch_size
:
int
=
field
(
default
=
4
,
metadata
=
{
"help"
:
"The batch size per GPU for evaluation."
},
)
seed
:
int
=
field
(
default
=
42
,
metadata
=
{
"help"
:
"Random seed to be used with data loaders."
},
)
lang
:
Literal
[
"en"
,
"zh"
]
=
field
(
default
=
"en"
,
metadata
=
{
"help"
:
"Language used at evaluation."
},
)
n_shot
:
int
=
field
(
default
=
5
,
metadata
=
{
"help"
:
"Number of examplars for few-shot learning."
},
)
save_dir
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Path to save the evaluation results."
},
)
download_mode
:
DownloadMode
=
field
(
default
=
DownloadMode
.
REUSE_DATASET_IF_EXISTS
,
metadata
=
{
"help"
:
"Download mode used for the evaluation datasets."
},
)
def
__post_init__
(
self
):
if
self
.
save_dir
is
not
None
and
os
.
path
.
exists
(
self
.
save_dir
):
raise
ValueError
(
"`save_dir` already exists, use another one."
)
LLaMA-Factory/src/llamafactory/hparams/finetuning_args.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
dataclass
,
field
from
typing
import
List
,
Literal
,
Optional
@
dataclass
class
FreezeArguments
:
r
"""
Arguments pertaining to the freeze (partial-parameter) training.
"""
freeze_trainable_layers
:
int
=
field
(
default
=
2
,
metadata
=
{
"help"
:
(
"The number of trainable layers for freeze (partial-parameter) fine-tuning. "
"Positive numbers mean the last n layers are set as trainable, "
"negative numbers mean the first n layers are set as trainable."
)
},
)
freeze_trainable_modules
:
str
=
field
(
default
=
"all"
,
metadata
=
{
"help"
:
(
"Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
"Use commas to separate multiple modules. "
"Use `all` to specify all the available modules."
)
},
)
freeze_extra_modules
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
(
"Name(s) of modules apart from hidden layers to be set as trainable "
"for freeze (partial-parameter) fine-tuning. "
"Use commas to separate multiple modules."
)
},
)
@
dataclass
class
LoraArguments
:
r
"""
Arguments pertaining to the LoRA training.
"""
additional_target
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
(
"Name(s) of modules apart from LoRA layers to be set as trainable "
"and saved in the final checkpoint. "
"Use commas to separate multiple modules."
)
},
)
lora_alpha
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The scale factor for LoRA fine-tuning (default: lora_rank * 2)."
},
)
lora_dropout
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"Dropout rate for the LoRA fine-tuning."
},
)
lora_rank
:
int
=
field
(
default
=
8
,
metadata
=
{
"help"
:
"The intrinsic dimension for LoRA fine-tuning."
},
)
lora_target
:
str
=
field
(
default
=
"all"
,
metadata
=
{
"help"
:
(
"Name(s) of target modules to apply LoRA. "
"Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
loraplus_lr_ratio
:
Optional
[
float
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"LoRA plus learning rate ratio (lr_B / lr_A)."
},
)
loraplus_lr_embedding
:
float
=
field
(
default
=
1e-6
,
metadata
=
{
"help"
:
"LoRA plus learning rate for lora embedding layers."
},
)
use_rslora
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to use the rank stabilization scaling factor for LoRA layer."
},
)
use_dora
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to use the weight-decomposed lora method (DoRA)."
},
)
pissa_init
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to initialize a PiSSA adapter."
},
)
pissa_iter
:
int
=
field
(
default
=
16
,
metadata
=
{
"help"
:
"The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."
},
)
pissa_convert
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to convert the PiSSA adapter to a normal LoRA adapter."
},
)
create_new_adapter
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to create a new adapter with randomly initialized weight."
},
)
@
dataclass
class
RLHFArguments
:
r
"""
Arguments pertaining to the PPO, DPO and KTO training.
"""
pref_beta
:
float
=
field
(
default
=
0.1
,
metadata
=
{
"help"
:
"The beta parameter in the preference loss."
},
)
pref_ftx
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"The supervised fine-tuning loss coefficient in DPO training."
},
)
pref_loss
:
Literal
[
"sigmoid"
,
"hinge"
,
"ipo"
,
"kto_pair"
,
"orpo"
,
"simpo"
]
=
field
(
default
=
"sigmoid"
,
metadata
=
{
"help"
:
"The type of DPO loss to use."
},
)
dpo_label_smoothing
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."
},
)
kto_chosen_weight
:
float
=
field
(
default
=
1.0
,
metadata
=
{
"help"
:
"The weight factor of the desirable losses in KTO training."
},
)
kto_rejected_weight
:
float
=
field
(
default
=
1.0
,
metadata
=
{
"help"
:
"The weight factor of the undesirable losses in KTO training."
},
)
simpo_gamma
:
float
=
field
(
default
=
0.5
,
metadata
=
{
"help"
:
"The target reward margin term in SimPO loss."
},
)
ppo_buffer_size
:
int
=
field
(
default
=
1
,
metadata
=
{
"help"
:
"The number of mini-batches to make experience buffer in a PPO optimization step."
},
)
ppo_epochs
:
int
=
field
(
default
=
4
,
metadata
=
{
"help"
:
"The number of epochs to perform in a PPO optimization step."
},
)
ppo_score_norm
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Use score normalization in PPO training."
},
)
ppo_target
:
float
=
field
(
default
=
6.0
,
metadata
=
{
"help"
:
"Target KL value for adaptive KL control in PPO training."
},
)
ppo_whiten_rewards
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whiten the rewards before compute advantages in PPO training."
},
)
ref_model
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Path to the reference model used for the PPO or DPO training."
},
)
ref_model_adapters
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Path to the adapters of the reference model."
},
)
ref_model_quantization_bit
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The number of bits to quantize the reference model."
},
)
reward_model
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Path to the reward model used for the PPO training."
},
)
reward_model_adapters
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Path to the adapters of the reward model."
},
)
reward_model_quantization_bit
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The number of bits to quantize the reward model."
},
)
reward_model_type
:
Literal
[
"lora"
,
"full"
,
"api"
]
=
field
(
default
=
"lora"
,
metadata
=
{
"help"
:
"The type of the reward model in PPO training. Lora model only supports lora training."
},
)
@
dataclass
class
GaloreArguments
:
r
"""
Arguments pertaining to the GaLore algorithm.
"""
use_galore
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to use the gradient low-Rank projection (GaLore)."
},
)
galore_target
:
str
=
field
(
default
=
"all"
,
metadata
=
{
"help"
:
(
"Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
"Use `all` to specify all the linear modules."
)
},
)
galore_rank
:
int
=
field
(
default
=
16
,
metadata
=
{
"help"
:
"The rank of GaLore gradients."
},
)
galore_update_interval
:
int
=
field
(
default
=
200
,
metadata
=
{
"help"
:
"Number of steps to update the GaLore projection."
},
)
galore_scale
:
float
=
field
(
default
=
0.25
,
metadata
=
{
"help"
:
"GaLore scaling coefficient."
},
)
galore_proj_type
:
Literal
[
"std"
,
"reverse_std"
,
"right"
,
"left"
,
"full"
]
=
field
(
default
=
"std"
,
metadata
=
{
"help"
:
"Type of GaLore projection."
},
)
galore_layerwise
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to enable layer-wise update to further save memory."
},
)
@
dataclass
class
BAdamArgument
:
r
"""
Arguments pertaining to the BAdam optimizer.
"""
use_badam
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to use the BAdam optimizer."
},
)
badam_mode
:
Literal
[
"layer"
,
"ratio"
]
=
field
(
default
=
"layer"
,
metadata
=
{
"help"
:
"Whether to use layer-wise or ratio-wise BAdam optimizer."
},
)
badam_start_block
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The starting block index for layer-wise BAdam."
},
)
badam_switch_mode
:
Optional
[
Literal
[
"ascending"
,
"descending"
,
"random"
,
"fixed"
]]
=
field
(
default
=
"ascending"
,
metadata
=
{
"help"
:
"the strategy of picking block to update for layer-wise BAdam."
},
)
badam_switch_interval
:
Optional
[
int
]
=
field
(
default
=
50
,
metadata
=
{
"help"
:
"Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
},
)
badam_update_ratio
:
float
=
field
(
default
=
0.05
,
metadata
=
{
"help"
:
"The ratio of the update for ratio-wise BAdam."
},
)
badam_mask_mode
:
Literal
[
"adjacent"
,
"scatter"
]
=
field
(
default
=
"adjacent"
,
metadata
=
{
"help"
:
(
"The mode of the mask for BAdam optimizer. "
"`adjacent` means that the trainable parameters are adjacent to each other, "
"`scatter` means that trainable parameters are randomly choosed from the weight."
)
},
)
badam_verbose
:
int
=
field
(
default
=
0
,
metadata
=
{
"help"
:
(
"The verbosity level of BAdam optimizer. "
"0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
)
},
)
@
dataclass
class
FinetuningArguments
(
FreezeArguments
,
LoraArguments
,
RLHFArguments
,
GaloreArguments
,
BAdamArgument
):
r
"""
Arguments pertaining to which techniques we are going to fine-tuning with.
"""
pure_bf16
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to train model in purely bf16 precision (without AMP)."
},
)
stage
:
Literal
[
"pt"
,
"sft"
,
"rm"
,
"ppo"
,
"dpo"
,
"kto"
]
=
field
(
default
=
"sft"
,
metadata
=
{
"help"
:
"Which stage will be performed in training."
},
)
finetuning_type
:
Literal
[
"lora"
,
"freeze"
,
"full"
]
=
field
(
default
=
"lora"
,
metadata
=
{
"help"
:
"Which fine-tuning method to use."
},
)
use_llama_pro
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to make only the parameters in the expanded blocks trainable."
},
)
freeze_vision_tower
:
bool
=
field
(
default
=
True
,
metadata
=
{
"help"
:
"Whether ot not to freeze vision tower in MLLM training."
},
)
train_mm_proj_only
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to train the multimodal projector for MLLM only."
},
)
compute_accuracy
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to compute the token-level accuracy at evaluation."
},
)
plot_loss
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"Whether or not to save the training loss curves."
},
)
def
__post_init__
(
self
):
def
split_arg
(
arg
):
if
isinstance
(
arg
,
str
):
return
[
item
.
strip
()
for
item
in
arg
.
split
(
","
)]
return
arg
self
.
freeze_trainable_modules
:
List
[
str
]
=
split_arg
(
self
.
freeze_trainable_modules
)
self
.
freeze_extra_modules
:
Optional
[
List
[
str
]]
=
split_arg
(
self
.
freeze_extra_modules
)
self
.
lora_alpha
:
int
=
self
.
lora_alpha
or
self
.
lora_rank
*
2
self
.
lora_target
:
List
[
str
]
=
split_arg
(
self
.
lora_target
)
self
.
additional_target
:
Optional
[
List
[
str
]]
=
split_arg
(
self
.
additional_target
)
self
.
galore_target
:
List
[
str
]
=
split_arg
(
self
.
galore_target
)
self
.
freeze_vision_tower
=
self
.
freeze_vision_tower
or
self
.
train_mm_proj_only
self
.
use_ref_model
=
self
.
stage
==
"dpo"
and
self
.
pref_loss
not
in
[
"orpo"
,
"simpo"
]
assert
self
.
finetuning_type
in
[
"lora"
,
"freeze"
,
"full"
],
"Invalid fine-tuning method."
assert
self
.
ref_model_quantization_bit
in
[
None
,
8
,
4
],
"We only accept 4-bit or 8-bit quantization."
assert
self
.
reward_model_quantization_bit
in
[
None
,
8
,
4
],
"We only accept 4-bit or 8-bit quantization."
if
self
.
stage
==
"ppo"
and
self
.
reward_model
is
None
:
raise
ValueError
(
"`reward_model` is necessary for PPO training."
)
if
self
.
stage
==
"ppo"
and
self
.
reward_model_type
==
"lora"
and
self
.
finetuning_type
!=
"lora"
:
raise
ValueError
(
"`reward_model_type` cannot be lora for Freeze/Full PPO training."
)
if
self
.
stage
==
"dpo"
and
self
.
pref_loss
!=
"sigmoid"
and
self
.
dpo_label_smoothing
>
1e-6
:
raise
ValueError
(
"`dpo_label_smoothing` is only valid for sigmoid loss function."
)
if
self
.
use_llama_pro
and
self
.
finetuning_type
==
"full"
:
raise
ValueError
(
"`use_llama_pro` is only valid for Freeze or LoRA training."
)
if
self
.
finetuning_type
==
"lora"
and
(
self
.
use_galore
or
self
.
use_badam
):
raise
ValueError
(
"Cannot use LoRA with GaLore or BAdam together."
)
if
self
.
use_galore
and
self
.
use_badam
:
raise
ValueError
(
"Cannot use GaLore with BAdam together."
)
if
self
.
pissa_init
and
(
self
.
stage
in
[
"ppo"
,
"kto"
]
or
self
.
use_ref_model
):
raise
ValueError
(
"Cannot use PiSSA for current training stage."
)
if
self
.
train_mm_proj_only
and
self
.
finetuning_type
!=
"full"
:
raise
ValueError
(
"`train_mm_proj_only` is only valid for full training."
)
if
self
.
finetuning_type
!=
"lora"
:
if
self
.
loraplus_lr_ratio
is
not
None
:
raise
ValueError
(
"`loraplus_lr_ratio` is only valid for LoRA training."
)
if
self
.
use_rslora
:
raise
ValueError
(
"`use_rslora` is only valid for LoRA training."
)
if
self
.
use_dora
:
raise
ValueError
(
"`use_dora` is only valid for LoRA training."
)
if
self
.
pissa_init
:
raise
ValueError
(
"`pissa_init` is only valid for LoRA training."
)
LLaMA-Factory/src/llamafactory/hparams/generating_args.py
0 → 100644
View file @
032b90a1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
dataclasses
import
asdict
,
dataclass
,
field
from
typing
import
Any
,
Dict
,
Optional
@
dataclass
class
GeneratingArguments
:
r
"""
Arguments pertaining to specify the decoding parameters.
"""
do_sample
:
bool
=
field
(
default
=
True
,
metadata
=
{
"help"
:
"Whether or not to use sampling, use greedy decoding otherwise."
},
)
temperature
:
float
=
field
(
default
=
0.95
,
metadata
=
{
"help"
:
"The value used to modulate the next token probabilities."
},
)
top_p
:
float
=
field
(
default
=
0.7
,
metadata
=
{
"help"
:
"The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."
},
)
top_k
:
int
=
field
(
default
=
50
,
metadata
=
{
"help"
:
"The number of highest probability vocabulary tokens to keep for top-k filtering."
},
)
num_beams
:
int
=
field
(
default
=
1
,
metadata
=
{
"help"
:
"Number of beams for beam search. 1 means no beam search."
},
)
max_length
:
int
=
field
(
default
=
1024
,
metadata
=
{
"help"
:
"The maximum length the generated tokens can have. It can be overridden by max_new_tokens."
},
)
max_new_tokens
:
int
=
field
(
default
=
1024
,
metadata
=
{
"help"
:
"The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."
},
)
repetition_penalty
:
float
=
field
(
default
=
1.0
,
metadata
=
{
"help"
:
"The parameter for repetition penalty. 1.0 means no penalty."
},
)
length_penalty
:
float
=
field
(
default
=
1.0
,
metadata
=
{
"help"
:
"Exponential penalty to the length that is used with beam-based generation."
},
)
default_system
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"Default system message to use in chat completion."
},
)
def
to_dict
(
self
)
->
Dict
[
str
,
Any
]:
args
=
asdict
(
self
)
if
args
.
get
(
"max_new_tokens"
,
-
1
)
>
0
:
args
.
pop
(
"max_length"
,
None
)
else
:
args
.
pop
(
"max_new_tokens"
,
None
)
return
args
Prev
1
…
3
4
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment