Repository: gaoqiong/lm-evaluation-harness

Commit fea936e6, authored Nov 05, 2021 by Leo Gao
Parent: 7f24a08b

    Merge TokenizedLM and TorchLM into BaseLM

The TokenizedLM class in lm_eval/base.py and the TorchLM class in
lm_eval/models/gpt2.py are merged into a single BaseLM class in
lm_eval/base.py; HFLM now subclasses BaseLM directly.

Showing 2 changed files with 158 additions and 163 deletions:

    lm_eval/base.py          +156   -18
    lm_eval/models/gpt2.py     +2  -145
lm_eval/base.py
...
@@ -4,16 +4,17 @@ from typing import Iterable
 import numpy as np
 import re
 from tqdm import tqdm
+import torch
 from lm_eval.metrics import mean, perplexity, weighted_perplexity, weighted_mean
 from lm_eval import utils
+from abc import abstractmethod


 class LM(abc.ABC):
     def __init__(self):
         self.cache_hook = CacheHook(None)

-    @abc.abstractmethod
+    @abstractmethod
     def loglikelihood(self, requests):
         """Compute log-likelihood of generating a continuation from a context.

         Downstream tasks should attempt to use loglikelihood instead of other
...
@@ -37,7 +38,7 @@ class LM(abc.ABC):
         """
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def loglikelihood_rolling(self, requests):
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
...
@@ -80,7 +81,7 @@ class LM(abc.ABC):
         pass

     # TODO: Add an optional max length
-    @abc.abstractmethod
+    @abstractmethod
     def greedy_until(self, requests):
         """Generate greedily until a stopping sequence
...
@@ -108,15 +109,26 @@ class LM(abc.ABC):
         self.cache_hook = cache_hook


-class TokenizedLM(LM):
+class BaseLM(LM):

-    @abc.abstractmethod
+    @abstractmethod
     def tok_encode(self, string: str):
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def tok_decode(self, tokens: Iterable[int]):
         pass

-    @abc.abstractmethod
-    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
-        pass
+    @abstractmethod
+    def _model_generate(self, context, max_length, eos_token_id):
+        pass
+
+    @abstractmethod
+    def _model_call(self, inps):
+        """
+        inps: a torch tensor of shape [batch, sequence]
+        the size of sequence may vary from call to call
+
+        returns: a torch tensor of shape [batch, sequence, vocab] with the
+        logits returned from the model
+        """
+        pass

     # subclass must implement properties vocab_size, eot_token_id, max_gen_toks.
     # TODO: enforce this somehow
...
@@ -162,6 +174,132 @@ class TokenizedLM(LM):

         return loglikelihoods

+    # subclass must implement properties batch_size, vocab_size, eot_token_id, max_gen_toks, device.
+    # TODO: enforce this somehow
+
+    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
+        res = []
+
+        def _collate(x):
+            # the negative sign on len(toks) sorts descending - this has a few advantages:
+            # - time estimates will always be over not underestimates, which is more useful for planning
+            # - to know the size of a batch when going through the list, you know the first one is always the batch
+            #   padded context length. this is useful to simplify the batching logic and more importantly to make
+            #   automatic adaptive batches much much easier to implement
+            # - any OOMs will happen right away rather than near the end
+            toks = x[1] + x[2]
+            return (-len(toks), tuple(toks))
+
+        # TODO: automatic (variable) batch size detection for vectorization
+        reord = utils.Reorderer(requests, _collate)
+        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
+            inps = []
+            contlens = []
+            inplens = []
+
+            padding_length = None
+
+            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
+            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
+            # again because vectorizing is annoying
+
+            for _, context_enc, continuation_enc in chunk:
+                # sanity check
+                assert len(context_enc) > 0
+                assert len(continuation_enc) > 0
+                assert len(continuation_enc) <= self.max_length
+
+                # how this all works:
+                #          CTX      CONT
+                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
+                # gpt2    \               \
+                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
+                #                                 [:, -len(continuation_enc):, :self.vocab_size] slice
+                # cont_toks      4 5 6 7 8 9
+
+                # when too long to fit in context, truncate from the left
+                inp = torch.tensor(
+                    (context_enc + continuation_enc)[-(self.max_length + 1):][:-1],
+                    dtype=torch.long
+                ).to(self.device)
+                inplen, = inp.shape
+
+                cont = continuation_enc
+
+                # since in _collate we make sure length is descending, the longest is always the first one.
+                padding_length = padding_length if padding_length is not None else inplen
+
+                # pad to length
+                inp = torch.cat([
+                    inp,  # [seq]
+                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device)  # [padding_length - seq]
+                ], dim=0)
+
+                inps.append(inp.unsqueeze(0))
+                contlens.append(cont)
+                inplens.append(inplen)
+
+            multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu()  # [batch, seq, vocab]
+
+            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
+                contlen = len(cont_toks)
+
+                logits = logits[inplen - contlen:inplen].unsqueeze(0)  # [1, seq, vocab]
+
+                greedy_tokens = logits.argmax(dim=-1)
+
+                # cont_toks :: [1, seq]
+                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)
+
+                max_equal = (greedy_tokens == cont_toks).all()
+
+                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
+
+                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
+
+                answer = (float(logits.sum()), bool(max_equal))
+
+                # partial caching
+                if cache_key is not None:
+                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
+
+                res.append(answer)
+
+        return reord.get_original(res)
+
+    def greedy_until(self, requests):
+        # TODO: implement fully general `until` that handles untils that are
+        #       multiple tokens or that span multiple tokens correctly
+        # TODO: extract to TokenizedLM?
+        res = []
+
+        def _collate(x):
+            toks = self.tok_encode(x[0])
+            return (len(toks), x[0])
+
+        reord = utils.Reorderer(requests, _collate)
+
+        for context, until in tqdm(reord.get_reordered()):
+            if isinstance(until, str):
+                until = [until]
+
+            primary_until, = self.tok_encode(until[0])
+
+            context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)
+
+            cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)
+
+            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])
+
+            for term in until:
+                s = s.split(term)[0]
+
+            # partial caching
+            self.cache_hook.add_partial("greedy_until", (context, until), s)
+
+            res.append(s)
+
+        return reord.get_original(res)
+

 class Task(abc.ABC):
     """A task represents an entire benchmark including its dataset, problems,
...
@@ -181,17 +319,17 @@ class Task(abc.ABC):
         """Downloads the task dataset if necessary"""
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def has_training_docs(self):
         """Whether the task has a training set"""
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def has_validation_docs(self):
         """Whether the task has a validation set"""
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def has_test_docs(self):
         """Whether the task has a test set"""
         pass
...
@@ -223,15 +361,15 @@ class Task(abc.ABC):
             return rnd.sample(self._training_docs, k)

-    @abc.abstractmethod
+    @abstractmethod
     def doc_to_text(self, doc):
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def doc_to_target(self, doc):
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
...
@@ -245,7 +383,7 @@ class Task(abc.ABC):
         """
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def process_results(self, doc, results):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
...
@@ -258,7 +396,7 @@ class Task(abc.ABC):
         """
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def aggregation(self):
         """
         :returns: {str: [metric_score] -> float}
...
@@ -267,7 +405,7 @@ class Task(abc.ABC):
         """
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def higher_is_better(self):
         """
         :returns: {str: bool}
...
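
One easily missed detail in the moved greedy_until is the prompt clipping,
self.tok_encode(context)[self.max_gen_toks - self.max_length:]. Since
max_gen_toks is smaller than max_length, the index is negative, so the slice
keeps only the tail of the prompt and reserves max_gen_toks positions of the
window for generation. A toy sketch with hypothetical sizes:

    # Hypothetical sizes; real values come from the model.
    max_length = 1024    # model context window
    max_gen_toks = 256   # generation budget
    context_tokens = list(range(2000))   # an over-long prompt

    # max_gen_toks - max_length == -768, so this keeps the *last* 768
    # prompt tokens, leaving max_gen_toks slots free for the continuation.
    clipped = context_tokens[max_gen_toks - max_length:]
    assert len(clipped) == max_length - max_gen_toks == 768
    assert clipped[0] == 2000 - 768   # earliest surviving prompt token

(A prompt shorter than 768 tokens survives the slice unchanged.)
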
lm_eval/models/gpt2.py
...
@@ -2,7 +2,7 @@ import transformers
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from lm_eval.base import LM, TokenizedLM
+from lm_eval.base import LM, BaseLM
 from lm_eval import utils
 from tqdm import tqdm
 import numpy as np
...
@@ -10,150 +10,7 @@ from abc import ABC, abstractmethod
 from typing import Iterable


-class TorchLM(TokenizedLM):
-    @abstractmethod
-    def _model_generate(self, context, max_length, eos_token_id):
-        pass
-
-    @abstractmethod
-    def _model_call(self, inps):
-        """
-        inps: a torch tensor of shape [batch, sequence]
-        the size of sequence may vary from call to call
-
-        returns: a torch tensor of shape [batch, sequence, vocab] with the
-        logits returned from the model
-        """
-        pass
-
-    # subclass must implement properties batch_size, vocab_size, eot_token_id, max_gen_toks, device.
-    # TODO: enforce this somehow
-
-    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
-        ...
-
-    def greedy_until(self, requests):
-        ...
-
-
-class HFLM(TorchLM):
+class HFLM(BaseLM):
     def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
         super().__init__()
...

(The elided method bodies removed from TorchLM are identical, line for line,
to the _loglikelihood_tokens and greedy_until implementations added to BaseLM
in lm_eval/base.py above; this commit is a pure move of that code.)
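
After this commit, everything a concrete backend must supply lives behind one
interface: the four abstract methods on BaseLM plus the properties named in
the base-class comment, with batching, caching, and decoding shared by all
torch backends. The sketch below is not part of the commit: it is a
hypothetical, minimal BaseLM subclass (bytes as token ids, a random-logit
stub in place of a real network) assuming base.py as of this commit; HFLM in
lm_eval/models/gpt2.py is the real reference implementation.

    import torch
    from lm_eval.base import BaseLM

    class ToyLM(BaseLM):
        """Hypothetical stub backend: bytes as tokens, random logits."""

        def __init__(self, vocab=256):
            super().__init__()
            self._vocab = vocab

        # tokenization hooks
        def tok_encode(self, string: str):
            return list(string.encode("utf-8"))

        def tok_decode(self, tokens):
            return bytes(tokens).decode("utf-8", errors="ignore")

        # model hooks
        def _model_call(self, inps):
            # inps: [batch, sequence] -> logits: [batch, sequence, vocab]
            batch, seq = inps.shape
            return torch.randn(batch, seq, self._vocab)

        def _model_generate(self, context, max_length, eos_token_id):
            # naive greedy loop over _model_call, stopping at eos_token_id
            out = context
            while out.shape[1] < max_length:
                nxt = self._model_call(out)[:, -1, :].argmax(dim=-1, keepdim=True)
                out = torch.cat([out, nxt], dim=1)
                if nxt.item() == eos_token_id:
                    break
            return out

        # properties the base-class comment requires of subclasses
        # (max_length is also read by the moved implementations)
        @property
        def vocab_size(self): return self._vocab
        @property
        def eot_token_id(self): return 0
        @property
        def max_gen_toks(self): return 32
        @property
        def max_length(self): return 512
        @property
        def batch_size(self): return 1
        @property
        def device(self): return "cpu"
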