gaoqiong / lm-evaluation-harness · Commits

Commit e0498dd7 authored Jun 27, 2023 by haileyschoelkopf

change self.gpt2 -> self.model

parent e6960b9a

Showing 1 changed file with 9 additions and 9 deletions.

lm_eval/models/gpt2.py (+9 / -9)
```diff
@@ -39,8 +39,8 @@ class HFLM(BaseLM):
         # Initialize model
         if isinstance(pretrained, transformers.PreTrainedModel):
-            self.gpt2 = pretrained
-            self._device = self.gpt2.device
+            self.model = pretrained
+            self._device = self.model.device
 
             if tokenizer:
                 assert isinstance(
@@ -53,7 +53,7 @@ class HFLM(BaseLM):
                 self.tokenizer = tokenizer
             else:
                 # Get tokenizer
-                model_name = self.gpt2.name_or_path
+                model_name = self.model.name_or_path
                 self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                     model_name,
                     revision=revision,
@@ -81,7 +81,7 @@ class HFLM(BaseLM):
             revision = revision + ("/" + subfolder if subfolder is not None else "")
 
             # Initialize new model and tokenizer instances
-            self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(
+            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                 pretrained,
                 load_in_8bit=load_in_8bit,
                 low_cpu_mem_usage=low_cpu_mem_usage,
@@ -98,7 +98,7 @@ class HFLM(BaseLM):
         else:
             raise TypeError('Parameter pretrained should be of type str or transformers.PreTrainedModel')
 
-        self.gpt2.eval()
+        self.model.eval()
 
         self.vocab_size = self.tokenizer.vocab_size
@@ -134,8 +134,8 @@ class HFLM(BaseLM):
             return self._max_length
         seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
         for attr in seqlen_config_attrs:
-            if hasattr(self.gpt2.config, attr):
-                return getattr(self.gpt2.config, attr)
+            if hasattr(self.model.config, attr):
+                return getattr(self.model.config, attr)
         if hasattr(self.tokenizer, "model_max_length"):
             if self.tokenizer.model_max_length == 1000000000000000019884624838656:
                 return self._DEFAULT_MAX_LENGTH
```
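Aside: the max_length hunk above resolves the model's context window through a fallback chain: model config attributes first (different architectures expose the limit as n_positions, max_position_embeddings, or n_ctx), then the tokenizer's model_max_length, skipping the int(1e30) sentinel (1000000000000000019884624838656) that transformers stores when no real limit is configured. Below is a minimal standalone sketch of that same chain; the function name and the 2048 default are assumptions for this illustration, not part of the commit.

```python
from transformers import AutoConfig, AutoTokenizer

_DEFAULT_MAX_LENGTH = 2048  # assumed fallback, mirroring the harness default


def resolve_max_length(model_name: str) -> int:
    config = AutoConfig.from_pretrained(model_name)
    # Different architectures store the context window under different
    # config attribute names; try each in turn.
    for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
        if hasattr(config, attr):
            return getattr(config, attr)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if hasattr(tokenizer, "model_max_length"):
        # int(1e30) is the sentinel transformers uses when the tokenizer
        # has no genuine model_max_length configured.
        if tokenizer.model_max_length == 1000000000000000019884624838656:
            return _DEFAULT_MAX_LENGTH
        return tokenizer.model_max_length
    return _DEFAULT_MAX_LENGTH


print(resolve_max_length("gpt2"))  # -> 1024, via the config's n_positions
```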
```diff
@@ -172,14 +172,14 @@ class HFLM(BaseLM):
         logits returned from the model
         """
         with torch.no_grad():
-            return self.gpt2(inps)[0]
+            return self.model(inps)[0]
 
     def _model_generate(self, context, max_length, eos_token_id):
         generation_kwargs = {"do_sample": False, "max_length": max_length}
         if eos_token_id is not None:
             generation_kwargs['eos_token_id'] = eos_token_id
             generation_kwargs['pad_token_id'] = eos_token_id  # setting eos_token_id as pad token
-        return self.gpt2.generate(context, **generation_kwargs)
+        return self.model.generate(context, **generation_kwargs)
 
     # for backwards compatibility
```
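Note that the rename is source-breaking for third-party code that reached into self.gpt2 directly; the commit swaps the attribute outright rather than aliasing it. In the spirit of the "# for backwards compatibility" comment that closes the hunk, a deprecation shim like the following could have kept the old name resolving during a transition. This is a hypothetical sketch, not part of this commit or of HFLM.

```python
import warnings


class _RenamedAttrExample:
    """Illustration only: shows a deprecation shim for an attribute
    rename; HFLM itself does not ship one."""

    def __init__(self, model):
        self.model = model

    @property
    def gpt2(self):
        # Old name still resolves, but nudges callers toward .model.
        warnings.warn(
            "'.gpt2' is deprecated; use '.model' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.model


shim = _RenamedAttrExample(model="dummy-model")
assert shim.gpt2 is shim.model  # legacy access path keeps working
```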