gaoqiong / lm-evaluation-harness · Commit bc7f52e6
Authored Jun 20, 2023 by haileyschoelkopf; committed Jun 22, 2023 by lintangsutawika
automatically unwrap model when needed
Parent: 1bd6229c

Showing 1 changed file with 22 additions and 32 deletions:

lm_eval/models/hf_merged.py (+22 −32)
@@ -74,9 +74,10 @@ class HFLM(LM):
         assert self.AUTO_MODEL_CLASS in [transformers.AutoModelForCausalLM, transformers.AutoModelForSeq2SeqLM]

-        self.model = self.AUTO_MODEL_CLASS.from_pretrained(
+        self._model = self.AUTO_MODEL_CLASS.from_pretrained(
             pretrained,
             revision=revision,
             low_cpu_mem_usage=low_cpu_mem_usage
         ).to(self.device)
+        # forever after, access self._model through self.model property
         self.model.eval()

         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
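For context on the hunk above: when a model goes through 🤗 Accelerate's prepare(), it can come back wrapped (for example in torch.nn.parallel.DistributedDataParallel), and the wrapper hides attributes such as config and generate behind a .module indirection. A minimal sketch of the situation this commit addresses, using only public Accelerate/transformers API; the checkpoint name is illustrative, not from the diff:

    # Sketch: why unwrapping is needed under Accelerate (illustrative).
    from accelerate import Accelerator
    from transformers import AutoModelForCausalLM

    accelerator = Accelerator()
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # prepare() may wrap the model (e.g. in DistributedDataParallel when
    # launched across multiple processes); the wrapper does not expose
    # `.config` or `.generate` directly.
    prepared = accelerator.prepare(model)

    # unwrap_model() peels any wrapper layers off and returns the
    # underlying module, so attributes like `config.n_ctx` resolve again.
    unwrapped = accelerator.unwrap_model(prepared)
    assert unwrapped.config is model.config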
@@ -125,28 +126,27 @@ class HFLM(LM):
         # return the associated transformers.AutoConfig for the given pretrained model.
         return self._config

+    @property
+    def model(self):
+        # returns the model, unwrapping it if using Accelerate
+        if hasattr(self, "accelerator"):
+            return self.accelerator.unwrap_model(self._model)
+        else:
+            return self._model
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
         return self.tokenizer.eos_token_id

-    # TODO: make model at self._model, have self.model property unwrap accelerator if needed under hood?
     @property
     def max_length(self):
         try:
-            if hasattr(self, "accelerator"):
-                return self.accelerator.unwrap_model(self.model).config.n_ctx
-            else:
-                return self.model.config.n_ctx
+            return self.model.config.n_ctx
         except AttributeError:
             # gptneoconfig doesn't have n_ctx apparently
-            if hasattr(self, "accelerator"):
-                return self.accelerator.unwrap_model(self.model).config.max_position_embeddings
-            else:
-                return self.model.config.max_position_embeddings
+            return self.model.config.max_position_embeddings

     @property
     def max_gen_toks(self):
         return 256
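The heart of the change is the new model property above: call sites no longer check for an accelerator themselves; they read self.model and always get the unwrapped module, while self._model keeps whatever (possibly wrapped) object was stored. A distilled, self-contained sketch of the same pattern, with stand-ins for the HF model and Accelerate (the stub class names are invented for illustration):

    # Distilled sketch of the unwrap-on-access property pattern (stub
    # classes, not the real HFLM / Accelerate objects).
    class Wrapper:
        """Stands in for an Accelerate wrapper such as DDP."""
        def __init__(self, module):
            self.module = module

    class FakeAccelerator:
        def unwrap_model(self, model):
            # The real unwrap_model likewise peels wrapper layers off.
            return model.module if isinstance(model, Wrapper) else model

    class LM:
        def __init__(self, model, accelerator=None):
            self._model = model  # store the (possibly wrapped) model privately
            if accelerator is not None:
                self.accelerator = accelerator

        @property
        def model(self):
            # returns the model, unwrapping it if using Accelerate
            if hasattr(self, "accelerator"):
                return self.accelerator.unwrap_model(self._model)
            return self._model

    inner = object()
    lm = LM(Wrapper(inner), accelerator=FakeAccelerator())
    assert lm.model is inner  # the property transparently unwraps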
@@ -236,24 +236,14 @@ class HFLM(LM):
         stopping_criteria = stop_sequences_criteria(
             self.tokenizer, stop, 1, context.shape[0]
         )
-        if hasattr(self, "accelerator"):
-            return self.accelerator.unwrap_model(self.model).generate(
-                context,
-                max_length=max_length,
-                stopping_criteria=stopping_criteria,
-                pad_token_id=self.eot_token_id,
-                use_cache=True,
-                **generation_kwargs,
-            )
-        else:
-            return self.model.generate(
-                context,
-                max_length=max_length,
-                stopping_criteria=stopping_criteria,
-                pad_token_id=self.eot_token_id,
-                use_cache=True,
-                **generation_kwargs,
-            )
+        return self.model.generate(
+            context,
+            max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            pad_token_id=self.eot_token_id,
+            use_cache=True,
+            **generation_kwargs,
+        )

     def _select_cont_toks(self, logits, contlen=None, inplen=None):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
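stop_sequences_criteria is a harness helper whose body is not shown in this hunk. For readers following along, a hedged stand-in built on the public transformers StoppingCriteria API might look like the following; the real helper may differ in details such as batch handling and how the extra arguments (stop, 1, context.shape[0]) are used:

    # Hedged stand-in for a stop-sequence criteria factory; the real
    # stop_sequences_criteria in the harness may differ.
    from transformers import StoppingCriteria, StoppingCriteriaList

    class StopSequenceCriteria(StoppingCriteria):
        def __init__(self, tokenizer, stop_sequences, initial_length):
            self.tokenizer = tokenizer
            self.stop_sequences = stop_sequences
            self.initial_length = initial_length  # prompt length in tokens

        def __call__(self, input_ids, scores, **kwargs):
            # Decode only the newly generated tokens and stop once any
            # stop string appears in them.
            text = self.tokenizer.decode(input_ids[0][self.initial_length:])
            return any(s in text for s in self.stop_sequences)

    def make_stopping_criteria(tokenizer, stop, initial_length):
        return StoppingCriteriaList(
            [StopSequenceCriteria(tokenizer, stop, initial_length)]
        )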
@@ -299,7 +289,7 @@ class HFLM(LM):
             )
         )

-        #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
+        #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
         rolling_token_windows = [(None,) + x for x in rolling_token_windows]

         pad_amnt = 0
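One note on the context line above: each element of rolling_token_windows is a (context_tokens, continuation_tokens) tuple, and the prepended None appears to fill the slot that normally carries the raw request strings when these tuples are fed onward (a synthetic rolling window has no original string pair). A toy illustration with invented token ids:

    # Toy illustration of the (None,) + x transformation (invented ids).
    rolling_token_windows = [((50256,), (15496, 995)), ((15496, 995), (318,))]
    rolling_token_windows = [(None,) + x for x in rolling_token_windows]
    # -> [(None, (50256,), (15496, 995)), (None, (15496, 995), (318,))]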