Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
51afaca2
Commit
51afaca2
authored
Mar 25, 2024
by
lintangsutawika
Browse files
seq2seq
parent
5a85f9bb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
15 additions
and
13 deletions
+15
-13
lm_eval/api/model.py
lm_eval/api/model.py
+3
-2
lm_eval/models/huggingface.py
lm_eval/models/huggingface.py
+12
-11
No files found.
lm_eval/api/model.py
View file @
51afaca2
...
@@ -305,8 +305,9 @@ class TemplateLM(LM):
...
@@ -305,8 +305,9 @@ class TemplateLM(LM):
continuation_enc
=
whole_enc
[
context_enc_len
:]
continuation_enc
=
whole_enc
[
context_enc_len
:]
elif
self
.
AUTO_MODEL_CLASS
==
transformers
.
AutoModelForSeq2SeqLM
:
elif
self
.
AUTO_MODEL_CLASS
==
transformers
.
AutoModelForSeq2SeqLM
:
context_enc
=
self
.
tok_encode
(
context
,
add_special_tokens
=
True
)
# The encoder may require context end with special tokens
continuation_enc
=
self
.
tok_encode
(
continuation
,
add_special_tokens
=
True
)
context_enc
=
self
.
tok_encode
(
context
)
continuation_enc
=
self
.
tok_encode
(
continuation
,
add_special_tokens
=
False
)
return
context_enc
,
continuation_enc
return
context_enc
,
continuation_enc
...
...
lm_eval/models/huggingface.py
View file @
51afaca2
...
@@ -664,14 +664,14 @@ class HFLM(TemplateLM):
...
@@ -664,14 +664,14 @@ class HFLM(TemplateLM):
self
,
string
:
str
,
left_truncate_len
=
None
,
add_special_tokens
=
None
self
,
string
:
str
,
left_truncate_len
=
None
,
add_special_tokens
=
None
)
->
List
[
int
]:
)
->
List
[
int
]:
""" """
""" """
if
add_special_tokens
is
None
:
if
self
.
AUTO_MODEL_CLASS
==
transformers
.
AutoModelForCausalLM
:
add_special_tokens
=
False
or
self
.
add_bos_token
elif
self
.
AUTO_MODEL_CLASS
==
transformers
.
AutoModelForSeq2SeqLM
:
# TODO: investigate best practices for enc-dec models + special tokens
add_special_tokens
=
True
encoding
=
self
.
tokenizer
.
encode
(
string
,
add_special_tokens
=
add_special_tokens
)
add_special_tokens
=
{}
if
self
.
AUTO_MODEL_CLASS
==
transformers
.
AutoModelForCausalLM
:
add_special_tokens
=
{
"add_special_tokens"
:
False
or
self
.
add_bos_token
}
encoding
=
self
.
tokenizer
.
encode
(
string
,
**
add_special_tokens
)
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
# left-truncate the encoded context to be at most `left_truncate_len` tokens long
if
left_truncate_len
:
if
left_truncate_len
:
...
@@ -690,17 +690,18 @@ class HFLM(TemplateLM):
...
@@ -690,17 +690,18 @@ class HFLM(TemplateLM):
old_padding_side
=
self
.
tokenizer
.
padding_side
old_padding_side
=
self
.
tokenizer
.
padding_side
self
.
tokenizer
.
padding_side
=
padding_side
self
.
tokenizer
.
padding_side
=
padding_side
add_special_tokens
=
{}
if
self
.
AUTO_MODEL_CLASS
==
transformers
.
AutoModelForCausalLM
:
if
self
.
AUTO_MODEL_CLASS
==
transformers
.
AutoModelForCausalLM
:
add_special_tokens
=
False
or
self
.
add_bos_token
add_special_tokens
=
{
elif
self
.
AUTO_MODEL_CLASS
==
transformers
.
AutoModelForSeq2SeqLM
:
"add_special_tokens"
:
False
or
self
.
add_bos_token
add_special_tokens
=
True
}
encoding
=
self
.
tokenizer
(
encoding
=
self
.
tokenizer
(
strings
,
strings
,
truncation
=
truncation
,
truncation
=
truncation
,
padding
=
"longest"
,
padding
=
"longest"
,
return_tensors
=
"pt"
,
return_tensors
=
"pt"
,
add_special_tokens
=
add_special_tokens
,
**
add_special_tokens
,
)
)
if
left_truncate_len
:
if
left_truncate_len
:
encoding
[
"input_ids"
]
=
encoding
[
"input_ids"
][:,
-
left_truncate_len
:]
encoding
[
"input_ids"
]
=
encoding
[
"input_ids"
][:,
-
left_truncate_len
:]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment