gaoqiong / lm-evaluation-harness · Commits

Commit ef09b104, authored Jun 20, 2023 by haileyschoelkopf

improve pad_and_concat

parent e8c84a38
Showing 1 changed file with 69 additions and 3 deletions.

lm_eval/utils.py (+69, -3)
@@ -19,6 +19,7 @@ from omegaconf import OmegaConf
 from jinja2 import BaseLoader, Environment, StrictUndefined
 from itertools import islice
 import torch
+import transformers
 from lm_eval.logger import eval_logger
@@ -415,21 +416,36 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     """
     return islice(raw_iterator, rank, limit, world_size)


-def pad_and_concat(max_length: int, tensors: List[torch.Tensor]):
+def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
     """
     Method for padding a list of tensors given the maximum tensor
     length in the batch. Used for batching inputs and continuations in
     seq2seq models.
     """
+    assert padding_side == "left" or padding_side == "right", f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
     for i, tensor in enumerate(tensors):
         tensor_len = tensor.shape[0]
         if tensor_len < max_length:
-            tensors[i] = torch.cat(
+            if padding_side == "right":
+                # right-pad
+                tensors[i] = torch.cat(
+                    [
+                        tensor,  # [seq]
+                        torch.zeros(
+                            max_length - tensor_len,
+                            dtype=torch.long,
+                        ).to(tensor.device),  # [padding_length - seq]
+                    ],
+                    dim=0,
+                ).unsqueeze(0)
+            else:
+                # left-pad
+                tensors[i] = torch.cat(
                     [
-                    tensor,  # [seq]
                         torch.zeros(
                             max_length - tensor_len,
                             dtype=torch.long,
                         ).to(tensor.device),  # [padding_length - seq]
+                        tensor,  # [seq]
                     ],
                     dim=0,
                 ).unsqueeze(0)
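For orientation, here is a minimal usage sketch of the updated helper. It is not part of the commit: it assumes pad_and_concat still finishes by concatenating the padded tensors along dim 0 (that line lies outside this hunk), and the tensor values are made up.

import torch
from lm_eval.utils import pad_and_concat

# Two token-id sequences of unequal length, padded out to max_length=4.
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])

# Default behaviour, unchanged from before this commit: pad on the right.
pad_and_concat(4, [a, b], padding_side="right")
# -> tensor([[1, 2, 3, 0],
#            [4, 5, 0, 0]])

# New in this commit: pad on the left, keeping the real tokens flush
# against the end of each row.
pad_and_concat(4, [a, b], padding_side="left")
# -> tensor([[0, 1, 2, 3],
#            [0, 0, 4, 5]])

Note that the function rebinds the list elements in place (tensors[i] = ...), so each call above is given a fresh list; the original 1-D tensors a and b are left untouched.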
@@ -442,3 +458,53 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor]):
 def clear_torch_cache():
     gc.collect()
     torch.cuda.empty_cache()
+
+
+# Multi-token stopping criteria
+class MultiTokenEOSCriteria(transformers.StoppingCriteria):
+    """Criteria to stop on the specified multi-token sequence."""
+
+    def __init__(
+        self,
+        sequence: str,
+        tokenizer: transformers.PreTrainedTokenizer,
+        initial_decoder_input_length: int,
+        batch_size: int,
+    ):
+        self.initial_decoder_input_length = initial_decoder_input_length
+        self.done_tracker = [False] * batch_size
+        self.sequence = sequence
+        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
+        self.sequence_id_len = len(self.sequence_ids)
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs) -> bool:
+        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
+        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
+            :, -self.sequence_id_len :
+        ]
+        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
+        for i, done in enumerate(self.done_tracker):
+            if not done:
+                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
+        return False not in self.done_tracker
+
+
+def stop_sequences_criteria(
+    tokenizer: transformers.PreTrainedTokenizer,
+    stop_sequences: List[str],
+    initial_decoder_input_length: int,
+    batch_size: int,
+) -> transformers.StoppingCriteriaList:
+    return transformers.StoppingCriteriaList(
+        [
+            *[
+                MultiTokenEOSCriteria(
+                    sequence, tokenizer, initial_decoder_input_length, batch_size
+                )
+                for sequence in stop_sequences
+            ],
+        ]
+    )
\ No newline at end of file
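A sketch of how the new stopping-criteria helpers could be wired into Hugging Face generation. This is not part of the commit; the model name, prompt, and stop strings are placeholders.

import transformers
from lm_eval.utils import stop_sequences_criteria

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Question: What is 2 + 2?\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt")

# One MultiTokenEOSCriteria is built per stop string. For a decoder-only model the
# prompt is part of input_ids, so initial_decoder_input_length is the prompt length
# and the criteria only inspect newly generated tokens.
stopping_criteria = stop_sequences_criteria(
    tokenizer,
    stop_sequences=["\n\n", "Question:"],
    initial_decoder_input_length=inputs["input_ids"].shape[1],
    batch_size=inputs["input_ids"].shape[0],
)

output = model.generate(
    **inputs,
    max_new_tokens=64,
    stopping_criteria=stopping_criteria,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:]))

With batch size 1 as above, generation halts as soon as the decoded lookback window contains one of the stop strings; for larger batches, each MultiTokenEOSCriteria tracks per-row completion in done_tracker and returns True only once every row has produced its stop string.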