Project: gaoqiong/lm-evaluation-harness

Commit dced6e96, authored Jun 24, 2023 by haileyschoelkopf
Parent: e3960fa0

    add batched generation

1 changed file with 64 additions and 17 deletions:
    lm_eval/models/huggingface.py  (+64, -17)
lm_eval/models/huggingface.py  (view file @ dced6e96)

@@ -15,6 +15,7 @@ from lm_eval.api.registry import register_model
 from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 
 from accelerate import Accelerator
+from typing import List, Union
 
 
 @register_model("hf-auto", "hf", "huggingface")
@@ -99,6 +100,7 @@ class HFLM(LM):
         )
         self.vocab_size = self.tokenizer.vocab_size
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
 
         self._max_length = max_length
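Context for the `pad_token_id = eos_token_id` line above: decoder-only checkpoints such as GPT-2 ship without a padding token, so asking the tokenizer to pad a batch fails until one is set. A minimal sketch, assuming the `gpt2` tokenizer (which is not part of this diff):

import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")   # gpt2 has no pad token by default
# tok(["short", "a longer prompt"], padding="longest", return_tensors="pt")
# -> raises ValueError: asking to pad but the tokenizer does not have a padding token
tok.pad_token_id = tok.eos_token_id           # reuse EOS for padding, as the commit does
batch = tok(["short", "a longer prompt"], padding="longest", return_tensors="pt")
print(batch["input_ids"].shape)               # torch.Size([2, <longest prompt length>])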
@@ -204,6 +206,33 @@ class HFLM(LM):
 
         return encoding
 
+    def tok_batch_encode(
+        self, strings: List[str], padding_side="left", left_truncate_len=None
+    ):
+        # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
+        old_padding_side = self.tokenizer.padding_side
+        self.tokenizer.padding_side = padding_side
+
+        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+            add_special_tokens = False
+        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+            add_special_tokens = True
+
+        encoding = self.tokenizer(
+            strings,
+            padding="longest",
+            return_tensors="pt",
+            add_special_tokens=add_special_tokens,
+        )
+        if left_truncate_len:
+            encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
+            encoding["attention_mask"] = encoding["attention_mask"][
+                :, -left_truncate_len:
+            ]
+        self.tokenizer.padding_side = old_padding_side
+
+        return encoding["input_ids"], encoding["attention_mask"]
+
     def tok_decode(self, tokens):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
             return self.tokenizer.decode(tokens)
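A hypothetical usage sketch of the new `tok_batch_encode` helper above; `lm` stands in for an already-constructed `HFLM` instance, and the example prompts are made up:

prompts = ["Question: 2 + 2 =", "Question: name the capital of France."]

# Left-pads to the longest prompt, optionally left-truncates to max_length, and
# returns (input_ids, attention_mask) as torch tensors of identical shape.
input_ids, attn_masks = lm.tok_batch_encode(prompts, left_truncate_len=lm.max_length)

assert input_ids.shape == attn_masks.shape   # (batch_size, longest_prompt_len)
# Because padding_side="left", the final real token of every prompt sits in the last
# column, which is what the batched generation loop below relies on when it slices
# the prompt off the generated sequence.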
@@ -495,13 +524,21 @@ class HFLM(LM):
 
         def _collate(x):
             toks = self.tok_encode(x[0])
-            return len(toks), x[0]
+            return -len(toks), x[0]
 
         re_ord = utils.Reorderer([req.args for req in requests], _collate)
 
-        for context, gen_kwargs in tqdm(
-            re_ord.get_reordered(), disable=(self.rank != 0)
-        ):
+        for chunk in utils.chunks(
+            tqdm(re_ord.get_reordered(), disable=(self.rank != 0)),
+            self.batch_size,
+        ):
+            contexts, all_gen_kwargs = zip(*chunk)
+            gen_kwargs = all_gen_kwargs[0]
+            # TODO: handle case where not all gen kwargs are same
             until = None
             if isinstance(gen_kwargs, dict):
                 kwargs = copy.deepcopy(gen_kwargs)
                 # edge case for repeats > 1
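The batching pattern above leans on `utils.chunks` plus `zip(*chunk)` to split each chunk of (context, gen_kwargs) pairs back into parallel tuples. A self-contained sketch with a toy stand-in for `utils.chunks` (the real helper lives in `lm_eval.utils`; this re-implementation and the sample requests are assumptions for illustration only):

from itertools import islice

def chunks(iterable, n):
    # toy stand-in for lm_eval.utils.chunks: yield successive lists of up to n items
    it = iter(iterable)
    while batch := list(islice(it, n)):
        yield batch

requests = [
    ("Translate: bonjour", {"until": ["\n"]}),
    ("Translate: merci", {"until": ["\n"]}),
    ("Translate: salut", {"until": ["\n"]}),
]

for chunk in chunks(requests, 2):
    contexts, all_gen_kwargs = zip(*chunk)   # e.g. ("Translate: bonjour", "Translate: merci")
    gen_kwargs = all_gen_kwargs[0]           # the loop assumes every item in a chunk shares kwargs
    print(list(contexts), gen_kwargs)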
@@ -534,32 +571,42 @@ class HFLM(LM):
                 # max len for inputs = encoder's whole max_length
                 max_ctx_len = self.max_length
 
-            context_enc = torch.tensor(
-                [self.tok_encode(context, left_truncate_len=max_ctx_len)],
-                device=self.device,
-            )
+            context_enc, attn_masks = self.tok_batch_encode(
+                contexts, left_truncate_len=max_ctx_len
+            )
+            context_enc = context_enc.to(self.device)
+            attn_masks = attn_masks.to(self.device)
+            # [self.tok_encode(context, left_truncate_len=max_ctx_len)],
+            # device=self.device,
+            # ) for context in contexts]
+            # padding_len = max([context.shape[1] for context in context_enc])
+            # self.tokenizer.batch_encod
+            # context_enc = utils.pad_and_concat(padding_len, context_enc, padding_side="left")
 
             cont = self._model_generate(
                 context=context_enc,
+                attention_mask=attn_masks,
                 max_length=context_enc.shape[1] + max_gen_toks,
                 stop=primary_until,
                 **kwargs,
             )
 
-            cont_toks_list = cont[0].tolist()
-            # discard context toks if using causal decoder-only LM
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                cont_toks_list = cont_toks_list[context_enc.shape[1] :]
+            cont_toks_list = cont.tolist()
+
+            for cont_toks, context in zip(cont_toks_list, contexts):
+                # discard context + left-padding toks if using causal decoder-only LM
+                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                    cont_toks = cont_toks[context_enc.shape[1] :]
 
-            s = self.tok_decode(cont_toks_list)
+                s = self.tok_decode(cont_toks)
 
-            # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
-            for term in until:
-                if len(term) > 0:
-                    # ignore '' separator, for seq2seq case where
-                    s = s.split(term)[0]
+                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
+                for term in until:
+                    if len(term) > 0:
+                        # ignore '' separator, for seq2seq case where
+                        s = s.split(term)[0]
 
-            res.append(s)
+                res.append(s)
 
-            self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
+                self.cache_hook.add_partial("greedy_until", (context, gen_kwargs), s)
 
         return re_ord.get_original(res)
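To make the per-sequence post-processing above concrete, here is a sketch on fabricated token ids, a fabricated decoded string, and made-up stop sequences (none of these values come from the diff):

context_len = 4                                    # context_enc.shape[1]: left-padded prompt width
cont_toks_list = [[0, 0, 11, 42, 7, 99, 3]]        # one row of cont.tolist() from a causal LM

for cont_toks in cont_toks_list:
    cont_toks = cont_toks[context_len:]            # drop prompt + left padding -> [7, 99, 3]
    s = "Paris.\nQuestion: next prompt"            # stand-in for self.tok_decode(cont_toks)
    for term in ["\n", "Question:"]:               # secondary stop sequences from gen_kwargs
        if len(term) > 0:                          # skip the '' separator edge case
            s = s.split(term)[0]                   # cut off text past the first stop string
    print(repr(s))                                 # 'Paris.'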