Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
0a3b8069
Commit
0a3b8069
authored
Jun 22, 2023
by
lintangsutawika
Browse files
resolved merge conflict
parent
1de7e4a5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
0 deletions
+24
-0
lm_eval/utils.py
lm_eval/utils.py
+24
-0
No files found.
lm_eval/utils.py
View file @
0a3b8069
...
...
@@ -429,3 +429,27 @@ def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
else
:
_torch_dtype
=
dtype
return
_torch_dtype
def
pad_and_concat
(
max_length
:
int
,
tensors
:
List
[
torch
.
Tensor
]):
"""
Method for padding a list of tensors given the maximum tensor
length in the batch. Used for batching inputs and continuations in
seq2seq models.
"""
for
i
,
tensor
in
enumerate
(
tensors
):
tensor_len
=
tensor
.
shape
[
0
]
if
tensor_len
<
max_length
:
tensors
[
i
]
=
torch
.
cat
(
[
tensor
,
# [seq]
torch
.
zeros
(
max_length
-
tensor_len
,
dtype
=
torch
.
long
).
to
(
tensor
.
device
),
# [padding_length - seq]
],
dim
=
0
,
).
unsqueeze
(
0
)
else
:
tensors
[
i
]
=
tensor
.
unsqueeze
(
0
)
return
torch
.
cat
(
tensors
,
dim
=
0
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment