gaoqiong / lm-evaluation-harness · Commits

Commit 1e9c8b59, authored Nov 20, 2023 by baberabb (parent b22f3440):

    add typehints

Showing 1 changed file with 17 additions and 9 deletions.

lm_eval/models/huggingface.py (+17, -9), view file @ 1e9c8b59
@@ -16,13 +16,14 @@ from pathlib import Path
 import torch.nn.functional as F

 from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 from accelerate import Accelerator, find_executable_batch_size, DistributedType
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple

 eval_logger = utils.eval_logger
@@ -420,7 +421,9 @@ class HFLM(LM):
         utils.clear_torch_cache()
         return batch_size

-    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None):
+    def tok_encode(
+        self, string: str, left_truncate_len=None, add_special_tokens=None
+    ) -> List[int]:
         """ """
         if add_special_tokens is None:
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
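With the new annotation, tok_encode promises a plain list of token ids. A minimal usage sketch, assuming an instantiated HFLM; the checkpoint name "gpt2" is an illustrative choice, not something this commit specifies:

```python
from typing import List

from lm_eval.models.huggingface import HFLM

# Hypothetical instantiation; any Hugging Face checkpoint name would do.
lm = HFLM(pretrained="gpt2")

# Per the annotated signature: str in, List[int] out.
token_ids: List[int] = lm.tok_encode("The quick brown fox")
assert all(isinstance(t, int) for t in token_ids)
```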
@@ -442,7 +445,7 @@ class HFLM(LM):
         padding_side: str = "left",
         left_truncate_len: int = None,
         truncation: bool = False,
-    ):
+    ) -> Tuple[List[int], List[int]]:
         # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
         old_padding_side = self.tokenizer.padding_side
         self.tokenizer.padding_side = padding_side
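The context lines show the pattern this batch encoder uses: save tokenizer.padding_side, override it for the batch, then restore it afterwards. A generic sketch of the same idiom with a Hugging Face tokenizer; the checkpoint name and pad-token fallback are assumptions for illustration, not part of the diff:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

old_padding_side = tokenizer.padding_side
tokenizer.padding_side = "left"  # left-pad so prompts line up at the right edge
try:
    batch = tokenizer(["short", "a longer prompt"], padding=True, return_tensors="pt")
finally:
    tokenizer.padding_side = old_padding_side  # restore, as the method does later
```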
@@ -536,7 +539,9 @@ class HFLM(LM):
         return logits

-    def _encode_pair(self, context, continuation):
+    def _encode_pair(
+        self, context: str, continuation: str
+    ) -> Tuple[List[int], List[int]]:
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
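The body of _encode_pair moves any trailing whitespace from the context onto the continuation before tokenizing, so the space ends up attached to the continuation's first token. A pure-Python illustration of that step (strings chosen arbitrarily; the context trim happens in the elided part of the hunk):

```python
context = "The answer is "
continuation = "42"

# Same logic as the hunk above: shift trailing spaces onto the continuation.
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
    continuation = context[-n_spaces:] + continuation
    context = context[:-n_spaces]

assert context == "The answer is"
assert continuation == " 42"
```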
@@ -551,7 +556,7 @@ class HFLM(LM):
         continuation_enc = whole_enc[context_enc_len:]
         return context_enc, continuation_enc

-    def loglikelihood(self, requests):
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
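loglikelihood is now typed end to end: it consumes Instance objects whose .args hold a (context, continuation) pair and returns one (log-probability, is-greedy) tuple per request. A self-contained sketch of those shapes, using a stub in place of the real Instance class:

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeInstance:
    """Stand-in for lm_eval.api.instance.Instance; only .args is modeled."""

    args: Tuple[str, str]


requests: List[FakeInstance] = [FakeInstance(args=("2 + 2 =", " 4"))]

# Same unpacking as the loop in the hunk above.
for context, continuation in [req.args for req in requests]:
    print(repr(context), repr(continuation))

# Shape of the annotated return value (dummy numbers, not real model output).
results: List[Tuple[float, bool]] = [(-1.23, True)]
```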
@@ -566,7 +571,7 @@ class HFLM(LM):
         return self._loglikelihood_tokens(new_reqs)

-    def loglikelihood_rolling(self, requests):
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []

         adaptive_batch_size = None
@@ -640,8 +645,11 @@ class HFLM(LM):
         return self.batch_sizes[sched]

     def _loglikelihood_tokens(
-        self, requests, disable_tqdm: bool = False, override_bs=None
-    ):
+        self,
+        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        disable_tqdm: bool = False,
+        override_bs: int = None,
+    ) -> List[Tuple[float, bool]]:
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
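The new annotation documents the internal request format: each item pairs the raw (context, continuation) strings with their pre-computed token ids. A purely illustrative value matching that type; the token ids are made up:

```python
from typing import List, Tuple

# Alias for readability; the shape matches the annotation in the diff.
LoglikelihoodRequest = Tuple[Tuple[str, str], List[int], List[int]]

req: LoglikelihoodRequest = (
    ("The answer is", " 42"),  # raw (context, continuation) strings
    [464, 3280, 318],          # hypothetical context token ids
    [5433],                    # hypothetical continuation token ids
)
requests: List[LoglikelihoodRequest] = [req]
```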
@@ -820,7 +828,7 @@ class HFLM(LM):
         return re_ord.get_original(res)

-    def generate_until(self, requests):
+    def generate_until(self, requests: List[Instance]) -> List[str]:
         res = defaultdict(list)
         re_ords = {}