OpenDAS / opencompass · Commit c8f1d513

[Fix] fix clp inferencer (#44)

Unverified commit c8f1d513, authored Jul 11, 2023 by Hubert, committed via GitHub on Jul 11, 2023. Parent: 50b658d2

Showing 1 changed file with 31 additions and 39 deletions:

opencompass/openicl/icl_inferencer/icl_clp_inferencer.py (+31, -39)
```diff
@@ -5,24 +5,22 @@ import os
-from functools import partial
 from typing import List, Optional
 
 import torch
 import torch.nn.functional as F
-from accelerate import Accelerator
 from tqdm import trange
 
 from opencompass.models import BaseModel
-from opencompass.openicl import PromptTemplate
-from opencompass.openicl.icl_inferencer.icl_base_inferencer import \
-    PPLInferencerOutputHandler
-from opencompass.openicl.icl_retriever import BaseRetriever
-from opencompass.openicl.utils.logging import get_logger
 from opencompass.registry import ICL_INFERENCERS
+
+from ..icl_prompt_template import PromptTemplate
+from ..icl_retriever import BaseRetriever
+from ..utils import get_logger
+from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler
 
 logger = get_logger(__name__)
 
 
 @ICL_INFERENCERS.register_module()
-class CLPInferencer:
+class CLPInferencer(BaseInferencer):
     """Conditional log probability based In-context Learning Inferencer.
 
     Calculate the log probability of each choices according the logits.
```
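For context on what the class computes, here is a minimal, self-contained sketch of single-token conditional log probability scoring, assuming every answer choice maps to exactly one token id. The toy vocabulary size, token ids, and logits below are made up for illustration and are not taken from the commit:

```python
import torch
import torch.nn.functional as F

# Toy setup: pretend the model emitted logits at the last prompt position
# over a 6-token vocabulary, and that the choices 'A'/'B'/'C' map to the
# single token ids 3, 4 and 5 (hypothetical values).
last_position_logits = torch.tensor([0.2, -1.0, 0.5, 2.0, 0.1, -0.3])
choice_ids = [3, 4, 5]

# Conditional log probability of each choice given the prompt.
log_probs = F.log_softmax(last_position_logits, dim=-1)
choice_scores = [log_probs[c].item() for c in choice_ids]

# The predicted choice is the one with the highest conditional log prob.
best = max(range(len(choice_ids)), key=lambda i: choice_scores[i])
print(choice_scores, '-> choice index', best)
```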
```diff
@@ -42,8 +40,6 @@ class CLPInferencer:
         max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
             the LM.
         batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`
-        accelerator (:obj:`Accelerator`, optional): An instance of the
-            `Accelerator` class, used for multiprocessing.
         output_json_filepath (:obj:`str`, optional): File path for output
             `JSON` file.
         output_json_filename (:obj:`str`, optional): File name for output
```
```diff
@@ -57,29 +53,20 @@ class CLPInferencer:
                  model: BaseModel,
                  max_seq_len: Optional[int] = None,
                  batch_size: Optional[int] = 1,
-                 accelerator: Optional[Accelerator] = None,
                  output_json_filepath: Optional[str] = './icl_inference_output',
                  output_json_filename: Optional[str] = 'predictions',
                  fix_id_list: Optional[List[int]] = None,
                  single_token: bool = True,
                  **kwargs) -> None:
-        self.model = model
-        self.accelerator = accelerator
-        self.is_main_process = (True if self.accelerator is None
-                                or self.accelerator.is_main_process else False)
-        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        if self.model is not None:
-            self.model.to(self.device)
-        self.max_seq_len = max_seq_len
-        self.batch_size = batch_size
-        self.output_json_filepath = output_json_filepath
-        self.output_json_filename = output_json_filename
-        if not os.path.exists(self.output_json_filepath):
-            os.makedirs(self.output_json_filepath)
+        super().__init__(
+            model=model,
+            max_seq_len=max_seq_len,
+            batch_size=batch_size,
+            output_json_filename=output_json_filename,
+            output_json_filepath=output_json_filepath,
+            **kwargs,
+        )
+
         self.fix_id_list = fix_id_list
 
         # TODO: support multiple token
         assert single_token, 'Only support single token choice currently.'
```
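A minimal sketch of constructing the inferencer with the new signature (no `accelerator` argument; common options are forwarded to `BaseInferencer` via `super().__init__`). The import path is derived from the file location shown above, and the option values are placeholders rather than values from the commit:

```python
from opencompass.models import BaseModel
from opencompass.openicl.icl_inferencer.icl_clp_inferencer import CLPInferencer


def build_clp_inferencer(model: BaseModel) -> CLPInferencer:
    """Build a CLPInferencer; the values below are illustrative only."""
    return CLPInferencer(
        model=model,
        max_seq_len=2048,               # example value, not from the commit
        batch_size=8,                    # example value, not from the commit
        output_json_filepath='./icl_inference_output',
        output_json_filename='predictions',
        single_token=True,               # only single-token choices are supported
    )
```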
```diff
@@ -111,8 +98,8 @@ class CLPInferencer:
         # 3. Generate in-context examples for testing inputs
         for idx in range(len(ice_idx_list)):
-            ice.append(retriever.generate_ice(ice_idx_list[idx],
-                                              ice_template=ice_template))
+            ice.append(
+                retriever.generate_ice(ice_idx_list[idx],
+                                       ice_template=ice_template))
         output_handler.save_ice(ice)
 
         # 4. Collect prompts and calculate conditional log probs
```
```diff
@@ -129,6 +116,9 @@ class CLPInferencer:
                 ]
             except ValueError:
                 choice_ids = [self.model.tokenizer.encode(c) for c in choices]
-                if self.model.tokenizer.add_bos_token:
-                    choice_ids = [c[1:] for c in choice_ids]
-                if self.model.tokenizer.add_eos_token:
+                if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer':  # noqa
+                    choice_ids = [c[2:] for c in choice_ids]
+                else:
+                    if self.model.tokenizer.add_bos_token:
+                        choice_ids = [c[1:] for c in choice_ids]
+                    if self.model.tokenizer.add_eos_token:
```
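The BOS/EOS handling matters because a plain `encode()` call usually wraps a single choice token in special tokens. A small, hypothetical illustration of the slicing logic; the token ids below are stand-ins, not output of the real tokenizers handled by the commit:

```python
# Hypothetical encodings of the one-character choice 'A':
#   a LLaMA-style tokenizer with add_bos_token=True might return [1, 319]
#   (BOS + choice token), while a ChatGLM-style tokenizer prepends two
#   special prefix tokens. All ids here are made up for illustration.
encoded_with_bos = [1, 319]
encoded_chatglm = [64790, 64792, 319]

# Keep only the token id that actually represents the choice.
choice_id_llama = encoded_with_bos[1:]     # drop BOS -> [319]
choice_id_chatglm = encoded_chatglm[2:]    # drop the two prefix tokens -> [319]

assert choice_id_llama == choice_id_chatglm == [319]
```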
```diff
@@ -175,7 +165,8 @@ class CLPInferencer:
                 choice_target_ids.append(prompt_token_num - 1)
 
             logger.info('Calculating conditional log probability for prompts.')
-            for idx in trange(0, len(prompt_list), self.batch_size,
-                              disable=not self.is_main_process):
+            for idx in trange(0,
+                              len(prompt_list),
+                              self.batch_size,
+                              disable=not self.is_main_process):
```
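The loop processes prompts in fixed-size slices. A minimal sketch of that batching pattern with `tqdm.trange`; the prompt list and batch size are placeholders:

```python
from tqdm import trange

prompt_list = [f'prompt {i}' for i in range(10)]   # placeholder prompts
batch_size = 4                                      # placeholder batch size

for idx in trange(0, len(prompt_list), batch_size):
    # Slice one batch of prompts; the last batch may be smaller.
    sub_batch = prompt_list[idx:idx + batch_size]
    print(len(sub_batch), 'prompts in this batch')
```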
```diff
@@ -209,10 +200,11 @@ class CLPInferencer:
                          choice_ids, mask_length=None):
         # TODO: support multiple tokens
-        try:
+        if hasattr(self.model, 'generator'):
             outputs, _ = self.model.generator.get_logits(input_texts)
-        except AttributeError:
+        else:
             outputs, _ = self.model.get_logits(input_texts)
 
         shift_logits = outputs[..., :-1, :].contiguous()
 
         shift_logits = F.log_softmax(shift_logits, dim=-1)
```
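To make the last two lines concrete, here is a small, self-contained sketch of how shifted, log-softmaxed logits can be turned into per-choice conditional log probabilities at a target position. The shapes, target positions, choice ids, and the gather step are illustrative assumptions, not code from this file:

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
outputs = torch.randn(batch, seq_len, vocab)      # stand-in for model logits

# Drop the last position so that position t scores the token at t + 1.
shift_logits = outputs[..., :-1, :].contiguous()
shift_logits = F.log_softmax(shift_logits, dim=-1)

# Hypothetical inputs: the position right before the answer slot for each
# prompt, and the single-token ids of the answer choices.
choice_target_ids = [3, 2]       # one target position per prompt in the batch
choice_ids = [4, 7, 9]           # e.g. token ids of choices 'A', 'B', 'C'

for i, pos in enumerate(choice_target_ids):
    # Conditional log prob of every choice, given prompt i.
    per_choice = shift_logits[i, pos, choice_ids]
    print('prompt', i, '-> best choice index', int(per_choice.argmax()))
```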