Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
1b4242c1
"vscode:/vscode.git/clone" did not exist on "25bb71bf27b0a806a337114c95cce514881805eb"
Commit
1b4242c1
authored
Mar 26, 2021
by
Leo Gao
Browse files
More changes to make neo work
parent
747b851d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
5 deletions
+10
-5
lm_eval/models/gpt2.py
lm_eval/models/gpt2.py
+9
-4
main.py
main.py
+1
-1
No files found.
lm_eval/models/gpt2.py
View file @
1b4242c1
...
...
@@ -9,11 +9,16 @@ from tqdm import tqdm
class
GPT2LM
(
LM
):
MAX_GEN_TOKS
=
256
def
__init__
(
self
,
device
=
"cpu"
,
pretrained
=
'gpt2'
):
self
.
device
=
torch
.
device
(
device
)
def __init__(self, device=None, pretrained='gpt2'):
    """Load a causal LM and tokenizer onto the chosen device.

    Args:
        device: torch device string (e.g. "cpu", "cuda"). When None,
            falls back to CUDA if available, else CPU.
        pretrained: model identifier passed to
            ``AutoModelForCausalLM.from_pretrained``.
    """
    # Resolve the compute device: honor an explicit request, otherwise
    # auto-detect CUDA availability.
    if device:
        self.device = torch.device(device)
    else:
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load the model weights and move them to the target device;
    # eval() disables dropout for deterministic scoring.
    self.gpt2 = transformers.AutoModelForCausalLM.from_pretrained(pretrained).to(self.device)
    self.gpt2.eval()

    # pretrained tokenizer for neo is broken for now so just hardcoding this to gpt2
    self.tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    self.tokenizer.pad_token = "<|endoftext|>"

    # Context window size taken from the model config.
    # NOTE(review): assumes the config exposes `n_ctx` (true for GPT-2-style
    # configs) — confirm for other architectures.
    self.max_length = self.gpt2.config.n_ctx
...
...
@@ -22,7 +27,7 @@ class GPT2LM(LM):
@classmethod
def create_from_arg_string(cls, arg_string):
    """Construct an instance from a ``key=value,key=value`` argument string.

    Recognized keys: ``device`` (default None, i.e. auto-detect) and
    ``pretrained`` (default "gpt2"); unrecognized keys are ignored.
    """
    parsed = utils.simple_parse_args_string(arg_string)
    device = parsed.get("device", None)
    pretrained = parsed.get("pretrained", "gpt2")
    return cls(device=device, pretrained=pretrained)
def
loglikelihood
(
self
,
requests
):
# TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
...
...
main.py
View file @
1b4242c1
...
...
@@ -35,7 +35,7 @@ def main():
print
(
"WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
if
not
args
.
no_cache
:
lm
=
base
.
CachingLM
(
lm
,
'lm_cache/'
+
args
.
model
+
'_'
+
args
.
model_args
.
replace
(
'='
,
'-'
).
replace
(
','
,
'_'
)
+
'.db'
)
lm
=
base
.
CachingLM
(
lm
,
'lm_cache/'
+
args
.
model
+
'_'
+
args
.
model_args
.
replace
(
'='
,
'-'
).
replace
(
','
,
'_'
)
.
replace
(
'/'
,
'-'
)
+
'.db'
)
if
args
.
tasks
==
"all_tasks"
:
task_names
=
tasks
.
ALL_TASKS
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment