gaoqiong / lm-evaluation-harness / Commits / fea936e6

Commit fea936e6, authored Nov 05, 2021 by Leo Gao

Merge TokenizedLM and TorchLM into BaseLM

parent 7f24a08b
Showing 2 changed files with 158 additions and 163 deletions:

lm_eval/base.py            +156  -18
lm_eval/models/gpt2.py       +2  -145
lm_eval/base.py (view file @ fea936e6)
...
...
@@ -4,16 +4,17 @@ from typing import Iterable
 import numpy as np
 import re
 from tqdm import tqdm
 import torch
 from lm_eval.metrics import mean, perplexity, weighted_perplexity, weighted_mean
 from lm_eval import utils
+from abc import abstractmethod


 class LM(abc.ABC):
     def __init__(self):
         self.cache_hook = CacheHook(None)

-    @abc.abstractmethod
+    @abstractmethod
     def loglikelihood(self, requests):
         """Compute log-likelihood of generating a continuation from a context.

         Downstream tasks should attempt to use loglikelihood instead of other
...
...
@@ -37,7 +38,7 @@ class LM(abc.ABC):
"""
pass
@
abc
.
abstractmethod
@
abstractmethod
def
loglikelihood_rolling
(
self
,
requests
):
"""Compute full log-likelihood of a string, with no truncation, for perplexity computation
- We will use the full max context length of the model.
...
...
@@ -80,7 +81,7 @@ class LM(abc.ABC):
         pass

     # TODO: Add an optional max length
-    @abc.abstractmethod
+    @abstractmethod
     def greedy_until(self, requests):
         """Generate greedily until a stopping sequence
...
...
@@ -108,15 +109,26 @@ class LM(abc.ABC):
         self.cache_hook = cache_hook


-class TokenizedLM(LM):
-    @abc.abstractmethod
+class BaseLM(LM):
+    @abstractmethod
     def tok_encode(self, string: str):
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def tok_decode(self, tokens: Iterable[int]):
         pass

-    @abc.abstractmethod
-    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+    @abstractmethod
+    def _model_generate(self, context, max_length, eos_token_id):
+        pass
+
+    @abstractmethod
+    def _model_call(self, inps):
+        """
+        inps: a torch tensor of shape [batch, sequence]
+        the size of sequence may vary from call to call
+
+        returns: a torch tensor of shape [batch, sequence, vocab] with the
+        logits retuned from the model
+        """
         pass

     # subclass must implement properties vocab_size, eot_token_id, max_gen_toks.
     # TODO: enforce this somehow
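To make the new contract concrete, here is a deliberately trivial BaseLM subclass. It is a sketch, not part of the commit: UniformLM, its character-level "tokenizer", and its constant-logit "model" are made-up stand-ins; only the method and property names (tok_encode, tok_decode, _model_call, _model_generate, vocab_size, eot_token_id, max_length, max_gen_toks, batch_size, device) come from the diff.

# Hypothetical BaseLM subclass, for illustration only.
import torch
from lm_eval.base import BaseLM

class UniformLM(BaseLM):
    """Toy model over a 100-id vocabulary that assigns the same logit everywhere."""

    # tokenizer plumbing required by BaseLM
    def tok_encode(self, string: str):
        return [ord(c) % 100 for c in string]          # toy "tokenizer"

    def tok_decode(self, tokens):
        return "".join(chr(t) for t in tokens)

    # model plumbing required by BaseLM
    def _model_call(self, inps):
        batch, seq = inps.shape
        return torch.zeros(batch, seq, self.vocab_size)  # [batch, seq, vocab] uniform logits

    def _model_generate(self, context, max_length, eos_token_id):
        # pad the context out to max_length with the "eos" id
        pad = torch.full((1, max_length - context.shape[1]), eos_token_id, dtype=torch.long)
        return torch.cat([context, pad], dim=1)

    # properties the shared request-handling code expects
    @property
    def vocab_size(self): return 100
    @property
    def eot_token_id(self): return 0
    @property
    def max_length(self): return 1024
    @property
    def max_gen_toks(self): return 16
    @property
    def batch_size(self): return 1
    @property
    def device(self): return "cpu"

With these pieces supplied, the batching, padding, and caching logic added to BaseLM below is shared by every backend rather than reimplemented per model.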
...
...
@@ -162,6 +174,132 @@ class TokenizedLM(LM):
         return loglikelihoods

+    # subclass must implement properties batch_size, vocab_size, eot_token_id, max_gen_toks, device.
+    # TODO: enforce this somehow
+
+    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
+        res = []
+
+        def _collate(x):
+            # the negative sign on len(toks) sorts descending - this has a few advantages:
+            # - time estimates will always be over not underestimates, which is more useful for planning
+            # - to know the size of a batch when going through the list, you know the first one is always the batch padded context length.
+            #   this is useful to simplify the batching logic and more importantly to make automatic adaptive batches much much easier to implement
+            # - any OOMs will happen right away rather than near the end
+            toks = x[1] + x[2]
+            return (-len(toks), tuple(toks))
+
+        # TODO: automatic (variable) batch size detection for vectorization
+        reord = utils.Reorderer(requests, _collate)
+        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
+            inps = []
+            contlens = []
+            inplens = []
+
+            padding_length = None
+
+            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
+            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
+            # again because vectorizing is annoying
+
+            for _, context_enc, continuation_enc in chunk:
+                # sanity check
+                assert len(context_enc) > 0
+                assert len(continuation_enc) > 0
+                assert len(continuation_enc) <= self.max_length
+
+                # how this all works:
+                #          CTX      CONT
+                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
+                # gpt2    \               \
+                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.vocab_size] slice
+                # cont_toks      4 5 6 7 8 9
+
+                # when too long to fit in context, truncate from the left
+                inp = torch.tensor(
+                    (context_enc + continuation_enc)[-(self.max_length + 1):][:-1],
+                    dtype=torch.long
+                ).to(self.device)
+                inplen, = inp.shape
+
+                cont = continuation_enc
+
+                # since in _collate we make sure length is descending, the longest is always the first one.
+                padding_length = padding_length if padding_length is not None else inplen
+
+                # pad to length
+                inp = torch.cat([
+                    inp,  # [seq]
+                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device)  # [padding_length - seq]
+                ], dim=0)
+
+                inps.append(inp.unsqueeze(0))
+                contlens.append(cont)
+                inplens.append(inplen)
+
+            multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu()  # [batch, seq, vocab]
+
+            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
+                contlen = len(cont_toks)
+
+                logits = logits[inplen - contlen:inplen].unsqueeze(0)  # [1, seq, vocab]
+
+                greedy_tokens = logits.argmax(dim=-1)
+
+                # cont_toks :: [1, seq]
+                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)
+
+                max_equal = (greedy_tokens == cont_toks).all()
+
+                #last_token_slice = logits[:, -1, :].squeeze(0).tolist()
+
+                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
+
+                answer = (float(logits.sum()), bool(max_equal))
+
+                # partial caching
+                if cache_key is not None:
+                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
+
+                res.append(answer)
+
+        return reord.get_original(res)
+
+    def greedy_until(self, requests):
+        # TODO: implement fully general `until` that handles untils that are
+        # multiple tokens or that span multiple tokens correctly
+
+        # TODO: extract to TokenizedLM?
+        res = []
+
+        def _collate(x):
+            toks = self.tok_encode(x[0])
+            return (len(toks), x[0])
+
+        reord = utils.Reorderer(requests, _collate)
+
+        for context, until in tqdm(reord.get_reordered()):
+            if isinstance(until, str):
+                until = [until]
+
+            primary_until, = self.tok_encode(until[0])
+
+            context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)
+
+            cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)
+
+            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])
+
+            for term in until:
+                s = s.split(term)[0]
+
+            # partial caching
+            self.cache_hook.add_partial("greedy_until", (context, until), s)
+
+            res.append(s)
+
+        return reord.get_original(res)
+
+
 class Task(abc.ABC):
     """A task represents an entire benchmark including its dataset, problems,
...
...
@@ -181,17 +319,17 @@ class Task(abc.ABC):
"""Downloads the task dataset if necessary"""
pass
@
abc
.
abstractmethod
@
abstractmethod
def
has_training_docs
(
self
):
"""Whether the task has a training set"""
pass
@
abc
.
abstractmethod
@
abstractmethod
def
has_validation_docs
(
self
):
"""Whether the task has a validation set"""
pass
@
abc
.
abstractmethod
@
abstractmethod
def
has_test_docs
(
self
):
"""Whether the task has a test set"""
pass
...
...
@@ -223,15 +361,15 @@ class Task(abc.ABC):
         return rnd.sample(self._training_docs, k)

-    @abc.abstractmethod
+    @abstractmethod
     def doc_to_text(self, doc):
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def doc_to_target(self, doc):
         pass

-    @abc.abstractmethod
+    @abstractmethod
     def construct_requests(self, doc, ctx):
         """ Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
...
...
@@ -245,7 +383,7 @@ class Task(abc.ABC):
"""
pass
@
abc
.
abstractmethod
@
abstractmethod
def
process_results
(
self
,
doc
,
results
):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
...
...
@@ -258,7 +396,7 @@ class Task(abc.ABC):
"""
pass
@
abc
.
abstractmethod
@
abstractmethod
def
aggregation
(
self
):
"""
:returns: {str: [metric_score] -> float}
...
...
@@ -267,7 +405,7 @@ class Task(abc.ABC):
"""
pass
@
abc
.
abstractmethod
@
abstractmethod
def
higher_is_better
(
self
):
"""
:returns: {str: bool}
...
...
lm_eval/models/gpt2.py (view file @ fea936e6)
...
...
@@ -2,7 +2,7 @@ import transformers
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from lm_eval.base import LM, TokenizedLM
+from lm_eval.base import LM, BaseLM
 from lm_eval import utils
 from tqdm import tqdm
 import numpy as np
...
...
@@ -10,150 +10,7 @@ from abc import ABC, abstractmethod
 from typing import Iterable


-class TorchLM(TokenizedLM):
-    @abstractmethod
-    def _model_generate(self, context, max_length, eos_token_id):
-        pass
-
-    @abstractmethod
-    def _model_call(self, inps):
-        """
-        inps: a torch tensor of shape [batch, sequence]
-        the size of sequence may vary from call to call
-
-        returns: a torch tensor of shape [batch, sequence, vocab] with the
-        logits retuned from the model
-        """
-        pass
-
-    # subclass must implement properties batch_size, vocab_size, eot_token_id, max_gen_toks, device.
-    # TODO: enforce this somehow
-
-    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
-        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
-        res = []
-
-        def _collate(x):
-            # the negative sign on len(toks) sorts descending - this has a few advantages:
-            # - time estimates will always be over not underestimates, which is more useful for planning
-            # - to know the size of a batch when going through the list, you know the first one is always the batch padded context length.
-            #   this is useful to simplify the batching logic and more importantly to make automatic adaptive batches much much easier to implement
-            # - any OOMs will happen right away rather than near the end
-            toks = x[1] + x[2]
-            return (-len(toks), tuple(toks))
-
-        # TODO: automatic (variable) batch size detection for vectorization
-        reord = utils.Reorderer(requests, _collate)
-        for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size):
-            inps = []
-            contlens = []
-            inplens = []
-
-            padding_length = None
-
-            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
-            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
-            # again because vectorizing is annoying
-
-            for _, context_enc, continuation_enc in chunk:
-                # sanity check
-                assert len(context_enc) > 0
-                assert len(continuation_enc) > 0
-                assert len(continuation_enc) <= self.max_length
-
-                # how this all works:
-                #          CTX      CONT
-                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
-                # gpt2    \               \
-                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the [:, -len(continuation_enc):, :self.vocab_size] slice
-                # cont_toks      4 5 6 7 8 9
-
-                # when too long to fit in context, truncate from the left
-                inp = torch.tensor(
-                    (context_enc + continuation_enc)[-(self.max_length + 1):][:-1],
-                    dtype=torch.long
-                ).to(self.device)
-                inplen, = inp.shape
-
-                cont = continuation_enc
-
-                # since in _collate we make sure length is descending, the longest is always the first one.
-                padding_length = padding_length if padding_length is not None else inplen
-
-                # pad to length
-                inp = torch.cat([
-                    inp,  # [seq]
-                    torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device)  # [padding_length - seq]
-                ], dim=0)
-
-                inps.append(inp.unsqueeze(0))
-                contlens.append(cont)
-                inplens.append(inplen)
-
-            multi_logits = F.log_softmax(self._model_call(torch.cat(inps, dim=0)), dim=-1).cpu()  # [batch, seq, vocab]
-
-            for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, contlens):
-                contlen = len(cont_toks)
-
-                logits = logits[inplen - contlen:inplen].unsqueeze(0)  # [1, seq, vocab]
-
-                greedy_tokens = logits.argmax(dim=-1)
-
-                # cont_toks :: [1, seq]
-                cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0)
-
-                max_equal = (greedy_tokens == cont_toks).all()
-
-                #last_token_slice = logits[:, -1, :].squeeze(0).tolist()
-
-                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
-
-                answer = (float(logits.sum()), bool(max_equal))
-
-                # partial caching
-                if cache_key is not None:
-                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
-
-                res.append(answer)
-
-        return reord.get_original(res)
-
-    def greedy_until(self, requests):
-        # TODO: implement fully general `until` that handles untils that are
-        # multiple tokens or that span multiple tokens correctly
-
-        # TODO: extract to TokenizedLM?
-        res = []
-
-        def _collate(x):
-            toks = self.tok_encode(x[0])
-            return (len(toks), x[0])
-
-        reord = utils.Reorderer(requests, _collate)
-
-        for context, until in tqdm(reord.get_reordered()):
-            if isinstance(until, str):
-                until = [until]
-
-            primary_until, = self.tok_encode(until[0])
-
-            context_enc = torch.tensor([self.tok_encode(context)[self.max_gen_toks - self.max_length:]]).to(self.device)
-
-            cont = self._model_generate(context_enc, context_enc.shape[1] + self.max_gen_toks, primary_until)
-
-            s = self.tok_decode(cont[0].tolist()[context_enc.shape[1]:])
-
-            for term in until:
-                s = s.split(term)[0]
-
-            # partial caching
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
-
-            res.append(s)
-
-        return reord.get_original(res)
-
-class HFLM(TorchLM):
+class HFLM(BaseLM):
     def __init__(self, device='cuda', pretrained='gpt2', revision='main', subfolder=None, tokenizer=None, batch_size=1):
         super().__init__()
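Callers are unaffected by the rebasing: HFLM keeps the same constructor, only its base class changes. A minimal usage sketch, assuming the 'gpt2' weights can be fetched and that a CUDA device is available (pass a different device string otherwise):

# Usage sketch (not from the repo); constructor arguments are those shown in the diff.
from lm_eval.models.gpt2 import HFLM

lm = HFLM(pretrained='gpt2', device='cuda', batch_size=1)
# loglikelihood / greedy_until now resolve through the shared code in BaseLM,
# which calls back into HFLM's _model_call / _model_generate and tokenizer helpers.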
...
...