gaoqiong / lm-evaluation-harness

Commit 1a6b31a8
Authored Jun 22, 2023 by lintangsutawika
Parent: 1d7d3de5

resolved conflict
Changes: 1 changed file with 34 additions and 18 deletions

lm_eval/models/hf_causal.py (+34, -18)
@@ -11,12 +11,14 @@ from lm_eval.logger import eval_logger
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 from accelerate import Accelerator
 from typing import Optional, Union


 @register_model("hf-causal")
-class HFLM(LM):
+class HFCausalLM(LM):
     def __init__(
         self,
         device="cuda",
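Note on the @register_model("hf-causal") decorator above: the harness keeps a name-to-class registry so a model backend can be selected by its string name. The snippet below is a minimal, hypothetical sketch of that pattern only; the names MODEL_REGISTRY and get_model are illustrative assumptions, not the actual lm_eval.api.registry implementation.

# Minimal sketch of a string -> class model registry (assumed names, for illustration only).
MODEL_REGISTRY = {}

def register_model(name):
    """Decorator that records a model class under a lookup name."""
    def decorate(cls):
        MODEL_REGISTRY[name] = cls
        return cls
    return decorate

def get_model(name):
    """Resolve a registered name (e.g. "hf-causal") back to its class."""
    return MODEL_REGISTRY[name]

@register_model("hf-causal")
class HFCausalLM:
    pass

assert get_model("hf-causal") is HFCausalLM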
@@ -35,6 +37,7 @@ class HFLM(LM):
         assert isinstance(batch_size, int)

         gpus = torch.cuda.device_count()
         if gpus <= 1:
             if device:
                 if device not in ["cuda", "cpu"]:
@@ -66,7 +69,7 @@ class HFLM(LM):
         ).to(self.device)
         self.model.eval()

-        print(self.model.dtype)
+        eval_logger.info(self.model.dtype)

         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
             pretrained if tokenizer is None else tokenizer,
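This hunk swaps a bare print of the model dtype for eval_logger.info, routing the message through the harness logger rather than stdout. As a rough stand-in sketch (assuming a plain logging.getLogger setup; the real lm_eval.logger may configure handlers differently):

import logging

# Assumed stand-in for lm_eval.logger.eval_logger: a named logger with basic configuration.
logging.basicConfig(level=logging.INFO)
eval_logger = logging.getLogger("lm-eval")

eval_logger.info("torch.float16")  # e.g. reporting the loaded model's dtype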
@@ -90,6 +93,14 @@ class HFLM(LM):
                 )
                 self._rank = accelerator.local_process_index
                 self._world_size = accelerator.num_processes
+                # manually set model to use gpu, for case where many GPUs available but
+                # only seek to use one
+                self._device = (
+                    torch.device(f"cuda:{accelerator.local_process_index}")
+                    if torch.cuda.is_available()
+                    else torch.device("cpu")
+                )
+                self.model.to(self.device)
             else:
                 self.model = accelerator.prepare(self.model)
                 self._device = torch.device(f"cuda:{accelerator.local_process_index}")
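The added block pins the model to the GPU that belongs to the local Accelerate process (falling back to CPU when CUDA is unavailable), for the case where more GPUs are visible than processes were launched. A standalone sketch of the same device-selection pattern, assuming accelerate and torch are installed and using a small placeholder module in place of the pretrained model:

import torch
import torch.nn as nn
from accelerate import Accelerator

accelerator = Accelerator()

# Pick the GPU matching this process's index, or CPU if CUDA is unavailable.
device = (
    torch.device(f"cuda:{accelerator.local_process_index}")
    if torch.cuda.is_available()
    else torch.device("cpu")
)

model = nn.Linear(8, 8)  # placeholder for the pretrained causal LM
model.to(device)
print(accelerator.num_processes, device)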
@@ -157,27 +168,33 @@ class HFLM(LM):
         logits returned from the model
         """
         with torch.no_grad():
-            return self.model(inps)[0]
+            return self.model(inps).logits

-    def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs):
+    def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # we require users to pass do_sample=True explicitly
         # for non-greedy gen. This should be reevaluated when considering beam search.
         if "do_sample" not in generation_kwargs.keys():
             generation_kwargs["do_sample"] = False
+        # build stopping criteria
+        stopping_criteria = stop_sequences_criteria(
+            self.tokenizer, stop, 1, context.shape[0]
+        )
         if hasattr(self, "accelerator"):
             return self.accelerator.unwrap_model(self.model).generate(
                 context,
                 max_length=max_length,
-                pad_token_id=eos_token_id,
-                eos_token_id=eos_token_id,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.eot_token_id,
                 use_cache=True,
                 **generation_kwargs,
             )
         else:
             return self.model.generate(
                 context,
                 max_length=max_length,
-                pad_token_id=eos_token_id,
-                eos_token_id=eos_token_id,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.eot_token_id,
                 use_cache=True,
                 **generation_kwargs,
             )
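_model_generate now takes raw stop strings and turns them into a transformers stopping criterion via stop_sequences_criteria (imported above), instead of relying on a single eos_token_id. The sketch below shows the general shape of such a criterion using the public transformers.StoppingCriteria API; it is an illustration of the pattern under that assumption, not the harness's own stop_sequences_criteria/MultiTokenEOSCriteria implementation.

import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnSequences(StoppingCriteria):
    """Stop generation once any stop string appears in the newly generated text."""

    def __init__(self, tokenizer, stop_sequences, prompt_len):
        self.tokenizer = tokenizer
        self.stop_sequences = stop_sequences
        self.prompt_len = prompt_len  # number of prompt tokens to ignore when matching

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        # Decode only the continuation of the first sequence in the batch.
        text = self.tokenizer.decode(input_ids[0, self.prompt_len:])
        return any(s in text for s in self.stop_sequences)

# Hypothetical usage (assuming `tokenizer`, `model`, and a `context` tensor already exist):
# criteria = StoppingCriteriaList([StopOnSequences(tokenizer, ["\n\n"], context.shape[1])])
# model.generate(context, stopping_criteria=criteria, max_length=context.shape[1] + 64)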
@@ -197,9 +214,6 @@ class HFLM(LM):
         return self._loglikelihood_tokens(new_reqs)

     def loglikelihood_rolling(self, requests):
         # TODO: Implement caching once we've confirmed the perplexity implementation
         # TODO: automatic batch size detection for vectorization
         loglikelihoods = []

         for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
             rolling_token_windows = list(
@@ -368,6 +382,7 @@ class HFLM(LM):
         re_ord = utils.Reorderer([req.args for req in requests], _collate)

         for context, gen_kwargs in tqdm(re_ord.get_reordered()):
             until = None
             if isinstance(gen_kwargs, dict):
                 gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                 if "until" in gen_kwargs.keys():
@@ -389,12 +404,13 @@ class HFLM(LM):
             else:
                 max_gen_toks = self.max_gen_toks
-            try:
-                (primary_until,) = self.tok_encode(until[0])
-            except Exception:
-                # if our primary until would be multiple tokens long, we'll have errors.
-                # TODO: handling this better will let us stop generating earlier + often.
-                primary_until = self.eot_token_id
+            primary_until = until[0]
+            # try:
+            #     (primary_until,) = self.tok_encode(until[0])
+            # except Exception:
+            #     # if our primary until would be multiple tokens long, we'll have errors.
+            #     # TODO: handling this better will let us stop generating earlier + often.
+            #     primary_until = self.eot_token_id

             context_enc = torch.tensor(
                 [self.tok_encode(context)[max_gen_toks - self.max_length :]]
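The old code squeezed the first stop sequence into a single token id and fell back to eot_token_id when that failed; with string-based stopping criteria the raw string is kept instead. The toy example below illustrates why the old tuple-unpacking approach breaks for multi-token stop strings (tok_encode here is a made-up stand-in, not the harness tokenizer):

# Hypothetical tokenizer: many natural stop strings encode to more than one token.
def tok_encode(text):
    return [101, 102] if text == "\n\n" else [7]

try:
    (primary_until,) = tok_encode("\n\n")  # raises ValueError: too many values to unpack
except ValueError:
    primary_until = "\n\n"  # keep the string and let the stopping criteria match it

print(primary_until)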
@@ -403,7 +419,7 @@ class HFLM(LM):
             cont = self._model_generate(
                 context=context_enc,
                 max_length=context_enc.shape[1] + max_gen_toks,
-                eos_token_id=primary_until,
+                stop=primary_until,
                 **gen_kwargs,
             )