OpenDAS / opencompass · Commits · d4d1330a

Unverified commit d4d1330a, authored Nov 23, 2023 by Fengzhe Zhou, committed by GitHub on Nov 23, 2023.
Parent: 5329724b

[Sync] Fix cmnli, fix vicuna meta template, fix longbench postprocess and other minor fixes (#625)

Changes: 25 · Showing 5 changed files with 72 additions and 8 deletions (+72 -8)
Files changed:
  opencompass/datasets/longbench/longbench_trivia_qa.py       +7  -1
  opencompass/models/huggingface.py                            +49 -2
  opencompass/models/llama2.py                                 +8  -1
  opencompass/openicl/icl_inferencer/icl_gen_inferencer.py     +7  -3
  opencompass/runners/slurm_sequential.py                      +1  -1
opencompass/datasets/longbench/longbench_trivia_qa.py  (+7 -1)

 from datasets import Dataset, load_dataset
-from opencompass.registry import LOAD_DATASET
+from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
 from ..base import BaseDataset
...
@@ -24,3 +24,9 @@ class LongBenchtriviaqaDataset(BaseDataset):
             })
         dataset[split] = Dataset.from_list(raw_data)
         return dataset
+
+
+@TEXT_POSTPROCESSORS.register_module()
+def triviaqa_postprocess(text: str) -> str:
+    text = text.lstrip('\n').split('\n')[0]
+    return text
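The new postprocessor keeps only the first line of a generation, which is the part LongBench TriviaQA scores. A minimal standalone sketch of the same logic (the helper name and sample text below are illustrative, not part of the commit):

# Sketch of the postprocessing added above: strip leading newlines,
# then keep only the first line of the model output.
def first_line_only(text: str) -> str:  # mirrors triviaqa_postprocess
    return text.lstrip('\n').split('\n')[0]

sample = '\nParis\nExplanation: the capital of France is Paris.'
print(first_line_only(sample))  # -> 'Paris'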
opencompass/models/huggingface.py  (+49 -2)

...
@@ -46,6 +46,9 @@ class HuggingFace(BaseModel):
         mode (str, optional): The method of input truncation when input length
             exceeds max_seq_len. 'mid' represents the part of input to
             truncate. Defaults to 'none'.
+        use_fastchat_template (str, optional): Whether to use fastchat to get
+            the conversation template. If True, fastchat needs to be
+            implemented first. Defaults to False.

     Note:
         About ``extract_pred_after_decode``: Commonly, we should extract the
...
@@ -68,7 +71,8 @@ class HuggingFace(BaseModel):
                  extract_pred_after_decode: bool = False,
                  batch_padding: bool = False,
                  pad_token_id: Optional[int] = None,
-                 mode: str = 'none'):
+                 mode: str = 'none',
+                 use_fastchat_template: bool = False):
         super().__init__(path=path,
                          max_seq_len=max_seq_len,
                          tokenizer_only=tokenizer_only,
...
@@ -91,6 +95,7 @@ class HuggingFace(BaseModel):
                          model_kwargs=model_kwargs,
                          peft_path=peft_path)
         self.generation_kwargs = generation_kwargs
+        self.use_fastchat_template = use_fastchat_template

     def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
                         tokenizer_kwargs: dict):
...
@@ -220,6 +225,20 @@ class HuggingFace(BaseModel):
         if self.extract_pred_after_decode:
             prompt_lens = [len(input_) for input_ in inputs]

+        if self.use_fastchat_template:
+            try:
+                from fastchat.model import get_conversation_template
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    'Fastchat is not implemented. You can use '
+                    '\'pip install "fschat[model_worker,webui]"\' '
+                    'to implement fastchat.')
+            for i in range(len(inputs)):
+                conv = get_conversation_template('vicuna')
+                conv.append_message(conv.roles[0], inputs[i])
+                conv.append_message(conv.roles[1], None)
+                inputs[i] = conv.get_prompt()
+
         # step-1: tokenize the input with batch_encode_plus
         tokens = self.tokenizer.batch_encode_plus(inputs,
                                                   padding=True,
...
@@ -263,6 +282,19 @@ class HuggingFace(BaseModel):
         if self.extract_pred_after_decode:
             prompt_lens = [len(input_) for input_ in inputs]

+        if self.use_fastchat_template:
+            try:
+                from fastchat.model import get_conversation_template
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    'Fastchat is not implemented. You can use '
+                    '\'pip install "fschat[model_worker,webui]"\' '
+                    'to implement fastchat.')
+            conv = get_conversation_template('vicuna')
+            conv.append_message(conv.roles[0], inputs[0])
+            conv.append_message(conv.roles[1], None)
+            inputs = [conv.get_prompt()]
+
         if self.mode == 'mid':
             input_ids = self.tokenizer(inputs, truncation=False)['input_ids']
             input_ids = torch.tensor(input_ids, device=self.model.device)
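For reference, this is the wrapping the two hunks above perform before tokenization: the raw prompt becomes one user turn in fastchat's vicuna conversation template, with the assistant turn left open for the model to complete. A minimal sketch, assuming fschat is installed and using an illustrative prompt:

# Requires: pip install "fschat[model_worker,webui]"
from fastchat.model import get_conversation_template

conv = get_conversation_template('vicuna')
conv.append_message(conv.roles[0], 'What is the capital of France?')  # user turn
conv.append_message(conv.roles[1], None)  # leave the assistant turn open
print(conv.get_prompt())  # vicuna-style prompt for the model to complete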
...
@@ -491,7 +523,8 @@ class HuggingFaceChatGLM3(HuggingFace):
     def generate(self,
                  inputs: List[str or PromptList],
                  max_out_len: int = 512,
-                 temperature: float = 0.6) -> str:
+                 temperature: float = 0.6,
+                 skip_overlength=False) -> str:
         """Generate response from input prompt.

         Args:
...
@@ -518,6 +551,20 @@ class HuggingFaceChatGLM3(HuggingFace):
                     history.append(msg)
             user_content = history[-1]['content']
             history = history[:-1]
+            if skip_overlength:
+                # The model will report the following error
+                # if the sequence length is greater than the maximum length:
+                # "Input length of input_ids is {INPUT_IDS},
+                # but `max_length` is set to 8192.
+                # This can lead to unexpected behavior.
+                # You should consider increasing `max_new_tokens`."
+                # The following hardcode can fix this exception.
+                len_user_content = len(self.tokenizer.encode(user_content))
+                if len_user_content > 8192:
+                    responses.append('')
+                    continue
+
             try:
                 response, history = self.model.chat(self.tokenizer,
                                                     user_content,
...
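The skip_overlength guard measures the tokenized prompt before calling chat() and records an empty prediction when it would exceed the 8192-token window. A minimal sketch of the same check, using a hypothetical helper and tokenizer/model pair (not the commit's code):

MAX_LEN = 8192  # ChatGLM3 context limit assumed by the diff

def safe_chat(model, tokenizer, user_content, history, skip_overlength=True):
    # Skip prompts whose token count exceeds the context window
    # instead of letting model.chat() raise mid-run.
    if skip_overlength and len(tokenizer.encode(user_content)) > MAX_LEN:
        return '', history  # empty response, history unchanged
    return model.chat(tokenizer, user_content, history=history)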
opencompass/models/llama2.py  (+8 -1)

...
@@ -141,12 +141,19 @@ class Llama2Chat(BaseModel):
                     path: str,
                     max_seq_len: int,
                     max_batch_size: int,
-                    tokenizer_path: Optional[str] = None):
+                    tokenizer_path: Optional[str] = None,
+                    force_bf16=False):
         from llama import Llama
         self.generator = Llama.build(path, tokenizer_path, max_seq_len,
                                      max_batch_size)
         self.tokenizer = self.generator.tokenizer
         self.model = self.generator.model
+        if force_bf16:
+            # force set model to `bfloat16` to fix
+            # the exception of 'RuntimeError: probability tensor
+            # contains either `inf`, `nan` or element < 0',
+            # encountered during the inference of llama2-7b
+            self.model = self.model.bfloat16()

     def _load_tokenizer(self, tokenizer_path: str):
         from llama import Tokenizer
...
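Casting the whole model to bfloat16 is a single PyTorch call; a minimal sketch of the same switch on a toy module (the module here is a stand-in, only the flag and the .bfloat16() call mirror the diff):

import torch
import torch.nn as nn

# Cast all floating-point parameters to bfloat16, as the diff does on self.model.
model = nn.Linear(16, 16)  # stand-in for the Llama generator's model
force_bf16 = True
if force_bf16:
    model = model.bfloat16()
print(next(model.parameters()).dtype)  # torch.bfloat16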
opencompass/openicl/icl_inferencer/icl_gen_inferencer.py  (+7 -3)

...
@@ -108,9 +108,13 @@ class GenInferencer(BaseInferencer):
                                          'tmp_' + output_json_filename)
         if osp.exists(tmp_json_filepath):
             # TODO: move resume to output handler
-            tmp_result_dict = mmengine.load(tmp_json_filepath)
-            output_handler.results_dict = tmp_result_dict
-            index = len(tmp_result_dict)
+            try:
+                tmp_result_dict = mmengine.load(tmp_json_filepath)
+            except Exception:
+                pass
+            else:
+                output_handler.results_dict = tmp_result_dict
+                index = len(tmp_result_dict)

         # 4. Wrap prompts with Dataloader
         dataloader = self.get_dataloader(prompt_list[index:], self.batch_size)
...
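The change makes resuming tolerant of a partially written temporary file: if loading fails, inference restarts from index 0 instead of crashing. A standalone sketch of the same try/except/else pattern, using the standard json module and an illustrative file name rather than the inferencer's own state:

import json
import os.path as osp

index = 0
results = {}
tmp_json_filepath = 'tmp_results.json'  # illustrative path
if osp.exists(tmp_json_filepath):
    try:
        with open(tmp_json_filepath) as f:
            tmp_result_dict = json.load(f)
    except Exception:
        pass  # corrupted or partial file: ignore it and start over
    else:
        results = tmp_result_dict
        index = len(tmp_result_dict)  # resume after the last saved item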
opencompass/runners/slurm_sequential.py  (+1 -1)

...
@@ -96,7 +96,7 @@ class SlurmSequentialRunner(BaseRunner):
         try:
             parent_conns = []
-            num_workers = min(self.max_num_workers, len(tasks))
+            num_workers = max(min(self.max_num_workers, len(tasks)), 1)
             with Pool(processes=num_workers) as pool:
                 for task in tasks:
                     parent_conn, child_conn = Pipe()
...
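The one-line change clamps the worker count to at least one, since multiprocessing.Pool refuses a process count below 1. A small sketch of why the clamp matters (the variable values are illustrative):

from multiprocessing import Pool

max_num_workers = 32
tasks = []  # an empty task list would make min(...) evaluate to 0

# Old: min(32, 0) == 0, and Pool(processes=0) raises
# "ValueError: Number of processes must be at least 1".
# New: clamp to at least one worker so the pool can always be created.
num_workers = max(min(max_num_workers, len(tasks)), 1)
with Pool(processes=num_workers) as pool:
    pass  # pool is usable even when there are no tasks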