gaoqiong / lm-evaluation-harness · Commits

Commit 6ccd520f (Unverified)
Authored Dec 19, 2024 by Baber Abbasi; committed by GitHub Dec 20, 2024

add warning for truncation (#2585)

* add warning for truncation

Parent: 2b75b110
Showing 1 changed file with 16 additions and 0 deletions
lm_eval/models/huggingface.py (+16, -0)
@@ -818,6 +818,12 @@ class HFLM(TemplateLM):
             **add_special_tokens,
         )
         if left_truncate_len:
+            original_lengths = encoding["input_ids"].size(1)
+            if original_lengths > left_truncate_len:
+                eval_logger.warn(
+                    f"Left truncation applied. Original sequence length was {original_lengths}, "
+                    f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
+                )
             encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
             encoding["attention_mask"] = encoding["attention_mask"][
                 :, -left_truncate_len:
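For illustration, a minimal standalone sketch of the left-truncation behavior this hunk warns about. The encoding dict and example lengths below are assumptions standing in for a Hugging Face tokenizer's batch output; only the length check and the slicing mirror the patch.

import torch

left_truncate_len = 4
encoding = {
    "input_ids": torch.arange(10).unsqueeze(0),              # shape (1, 10)
    "attention_mask": torch.ones(1, 10, dtype=torch.long),   # shape (1, 10)
}

original_lengths = encoding["input_ids"].size(1)
if original_lengths > left_truncate_len:
    # Mirrors the warning added in the patch.
    print(
        f"Left truncation applied. Original sequence length was {original_lengths}, "
        f"truncating to last {left_truncate_len} tokens. Some content will be lost."
    )

# Keep only the rightmost left_truncate_len tokens, matching the patched slice.
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
encoding["attention_mask"] = encoding["attention_mask"][:, -left_truncate_len:]
print(encoding["input_ids"])  # tensor([[6, 7, 8, 9]])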
@@ -1096,6 +1102,13 @@ class HFLM(TemplateLM):
             # when too long to fit in context, truncate from the left
             if self.backend == "causal":
+                total_length = len(context_enc) + len(continuation_enc)
+                if total_length > self.max_length + 1:
+                    eval_logger.warn(
+                        f"Combined length of context ({len(context_enc)}) and continuation ({len(continuation_enc)}) "
+                        f"exceeds model's maximum length ({self.max_length}). "
+                        f"Truncating {total_length - self.max_length + 1} tokens from the left."
+                    )
                 inp = torch.tensor(
                     (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                     dtype=torch.long,
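For illustration, a minimal standalone sketch of the loglikelihood truncation path this hunk reports on. Here context_enc and continuation_enc stand in for tokenized ID lists and max_length replaces self.max_length, all assumed example values. The slice keeps the last max_length + 1 tokens and then drops the final one (the model predicts it rather than consuming it), so the number actually removed from the left is total_length - (max_length + 1).

import torch

max_length = 8
context_enc = list(range(7))        # 7 context token IDs
continuation_enc = [100, 101, 102]  # 3 continuation token IDs

total_length = len(context_enc) + len(continuation_enc)
if total_length > max_length + 1:
    # The slice below keeps the last max_length + 1 tokens, so the count
    # dropped from the left is total_length - (max_length + 1).
    dropped = total_length - (max_length + 1)
    print(f"Truncating {dropped} token(s) from the left.")

# Keep the last max_length + 1 tokens, then drop the final one: it is the
# label the model must predict, not an input it consumes.
inp = torch.tensor(
    (context_enc + continuation_enc)[-(max_length + 1):][:-1],
    dtype=torch.long,
)
print(inp.shape)  # torch.Size([8]) -- exactly max_length inputs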
@@ -1303,6 +1316,9 @@ class HFLM(TemplateLM):
             if self.backend == "causal":
                 # max len for inputs = max length, minus room to generate the max new tokens
                 max_ctx_len = self.max_length - max_gen_toks
+                assert (
+                    max_ctx_len > 0
+                ), f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) must be less than model's maximum sequence length ({self.max_length})."
             elif self.backend == "seq2seq":
                 # max len for inputs = encoder's whole max_length
                 max_ctx_len = self.max_length
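For illustration, a minimal sketch of the context-budget assertion added in this hunk, with assumed example values standing in for self.max_length and the request's max_gen_toks.

max_length = 2048
max_gen_toks = 256

# For a causal model, the prompt may use at most the tokens left over
# after reserving room for generation.
max_ctx_len = max_length - max_gen_toks
assert max_ctx_len > 0, (
    f"Invalid configuration: requested max tokens to generate ({max_gen_toks}) "
    f"must be less than model's maximum sequence length ({max_length})."
)
print(max_ctx_len)  # 1792 tokens of prompt budget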