Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
fb8abcec
Commit
fb8abcec
authored
Jun 20, 2023
by
haileyschoelkopf
Committed by
lintangsutawika
Jun 22, 2023
Browse files
improve pad_and_concat
parent
6d36a9c9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
81 additions
and
15 deletions
+81
-15
lm_eval/utils.py
lm_eval/utils.py
+81
-15
No files found.
lm_eval/utils.py
View file @
fb8abcec
...
...
@@ -14,6 +14,7 @@ from typing import List, Union
import
gc
import
torch
import
transformers
from
omegaconf
import
OmegaConf
from
jinja2
import
BaseLoader
,
Environment
,
StrictUndefined
...
...
@@ -431,25 +432,90 @@ def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
return
_torch_dtype
def pad_and_concat(
    max_length: int,
    tensors: List[torch.Tensor],
    padding_side: str = "right",
):
    """
    Method for padding a list of tensors given the maximum tensor
    length in the batch. Used for batching inputs and continuations in
    seq2seq models.

    :param max_length: int
        Length to pad every tensor out to. Assumed to be >= the length
        of the longest tensor in `tensors` (shorter tensors are padded,
        longer ones are passed through unchanged).
    :param tensors: List[torch.Tensor]
        List of 1-D token-id tensors. NOTE: mutated in place — each
        entry is replaced by its padded `[1, max_length]` form.
    :param padding_side: str
        "right" (default) appends padding after the tokens;
        "left" prepends it.
    :return: torch.Tensor
        A `[batch, max_length]` tensor of the stacked inputs.
    """
    assert (
        padding_side == "left" or padding_side == "right"
    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

    for i, tensor in enumerate(tensors):
        tensor_len = tensor.shape[0]
        if tensor_len < max_length:
            # Pad with token id 0. dtype is hard-coded to torch.long
            # because inputs are token-id tensors (matches original).
            padding = torch.zeros(
                max_length - tensor_len,
                dtype=torch.long,
            ).to(tensor.device)
            if padding_side == "right":
                # right-pad: [seq] + [padding_length - seq]
                pieces = [tensor, padding]
            else:
                # left-pad: [padding_length - seq] + [seq]
                pieces = [padding, tensor]
            tensors[i] = torch.cat(pieces, dim=0).unsqueeze(0)
        else:
            tensors[i] = tensor.unsqueeze(0)

    return torch.cat(tensors, dim=0)
# Multi-token stopping criteria
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""

    def __init__(
        self,
        sequence: str,
        tokenizer: transformers.PreTrainedTokenizer,
        initial_decoder_input_length: int,
        batch_size: int,
    ):
        self.initial_decoder_input_length = initial_decoder_input_length
        # One completion flag per batch row; a row stays done once its
        # stop sequence has been observed.
        self.done_tracker = [False] * batch_size
        self.sequence = sequence
        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
        self.sequence_id_len = len(self.sequence_ids)
        self.tokenizer = tokenizer

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, we compare the last n tokens where n is the number of
        # tokens in the stop_sequence: strip the prompt, then keep only the
        # trailing window of generated ids.
        generated_ids = input_ids[:, self.initial_decoder_input_length :]
        lookback_ids_batch = generated_ids[:, -self.sequence_id_len :]
        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
        for row, already_done in enumerate(self.done_tracker):
            if not already_done:
                self.done_tracker[row] = self.sequence in lookback_tokens_batch[row]
        # Stop generation only once every row in the batch has finished.
        return all(self.done_tracker)
def stop_sequences_criteria(
    tokenizer: transformers.PreTrainedTokenizer,
    stop_sequences: List[str],
    initial_decoder_input_length: int,
    batch_size: int,
) -> transformers.StoppingCriteriaList:
    """Build a StoppingCriteriaList containing one MultiTokenEOSCriteria
    per stop string in `stop_sequences`."""
    criteria = [
        MultiTokenEOSCriteria(
            sequence, tokenizer, initial_decoder_input_length, batch_size
        )
        for sequence in stop_sequences
    ]
    return transformers.StoppingCriteriaList(criteria)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment