gaoqiong / lm-evaluation-harness · Commits

Commit 4e1ef749 (unverified)
Authored Mar 01, 2024 by Lintang Sutawika; committed by GitHub on Mar 01, 2024
Parent: ae79b121

Update huggingface.py

Showing 1 changed file with 72 additions and 97 deletions:
lm_eval/models/huggingface.py (+72, −97)
```diff
@@ -24,7 +24,7 @@ from transformers.models.auto.modeling_auto import (
 from lm_eval import utils
 from lm_eval.api.instance import Instance
-from lm_eval.api.model import TemplateLM
+from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
 from lm_eval.models.utils import (
     Collator,
```
```diff
@@ -64,7 +64,7 @@ def _get_accelerate_args(
 @register_model("hf-auto", "hf", "huggingface")
-class HFLM(TemplateLM):
+class HFLM(LM):
     """
     An abstracted Huggingface model class. Enables usage with both models of
     `transformers.AutoModelForCausalLM` and `transformers.AutoModelForSeq2SeqLM` classes.
```
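The base-class swap is the heart of this commit: upstream `TemplateLM` owns the string-to-token plumbing (`loglikelihood`, `_encode_pair`) and only asks subclasses for `_loglikelihood_tokens`, whereas plain `LM` leaves all of that to the subclass. That is why the hunk at `@@ -790,6 +778,39 @@` below re-adds `_encode_pair` and `loglikelihood` to `HFLM` itself. A minimal sketch of the relationship, with illustrative bodies rather than the real `lm_eval.api.model` source:

```python
# Minimal sketch of the inheritance this commit changes (bodies are
# illustrative, not the real lm_eval.api.model source).
from abc import ABC, abstractmethod
from typing import List, Tuple


class LM(ABC):
    """Bare evaluation interface: subclasses handle requests end to end."""

    @abstractmethod
    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        ...


class TemplateLM(LM):
    """Adds shared string-to-token plumbing on top of LM."""

    @abstractmethod
    def _loglikelihood_tokens(self, requests, **kwargs):
        ...

    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
        # encodes (context, continuation) pairs, then defers to the
        # subclass's _loglikelihood_tokens() for the actual scoring
        raise NotImplementedError("sketch only")
```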
```diff
@@ -78,8 +78,9 @@ class HFLM(TemplateLM):
     def __init__(
         self,
         pretrained: Optional[Union[str, transformers.PreTrainedModel]] = "gpt2",
-        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
+        backend: Optional[
+            Literal["default", "causal", "seq2seq"]
+        ] = "default",
         # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
         revision: Optional[str] = "main",
         subfolder: Optional[str] = None,
         tokenizer: Optional[
```
```diff
@@ -90,7 +91,6 @@ class HFLM(TemplateLM):
             ]
         ] = None,
         truncation: Optional[bool] = False,
-        logits_cache: bool = True,
         max_length: Optional[int] = None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
```
```diff
@@ -98,7 +98,6 @@ class HFLM(TemplateLM):
         max_batch_size: Optional[int] = 64,
         trust_remote_code: Optional[bool] = False,
         use_fast_tokenizer: Optional[bool] = True,
-        add_bos_token: Optional[bool] = False,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
         parallelize: Optional[bool] = False,
```
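Net effect of the three signature hunks above: `logits_cache` and `add_bos_token` disappear from `HFLM.__init__`, so callers passing them as keywords now get a `TypeError`. A hedged usage sketch (model name and argument values are examples, not from the commit):

```python
from lm_eval.models.huggingface import HFLM

# keywords shown here are the ones visible in this diff's signature
lm = HFLM(
    pretrained="gpt2",  # HF hub id, or a preloaded transformers.PreTrainedModel
    backend="causal",   # or "seq2seq"; "default" lets the class auto-detect
    revision="main",
    device="cuda",
)

# Removed by this commit -- now raises TypeError: unexpected keyword argument:
# lm = HFLM(pretrained="gpt2", logits_cache=False, add_bos_token=True)
```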
```diff
@@ -240,7 +239,7 @@ class HFLM(TemplateLM):
         )
         self.truncation = truncation
-        self.logits_cache = logits_cache
         self.vocab_size = self.tokenizer.vocab_size
         # select (or create) a pad token to use
         if self.tokenizer.pad_token:
```
```diff
@@ -250,7 +249,7 @@ class HFLM(TemplateLM):
         elif self.tokenizer.eos_token:
             self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
         else:
-            if getattr(self.config, "model_type", None) == "qwen":
+            if self.config.model_type == "qwen":
                 # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
                 self.tokenizer.pad_token = "<|endoftext|>"
             elif (
```
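The `getattr` form being removed here tolerates config objects that lack a `model_type` attribute; the plain attribute access reinstated by the commit assumes it is always present (true for standard `transformers.PretrainedConfig` subclasses, less certain for `trust_remote_code` configs). A small illustration of the difference:

```python
class BareConfig:
    pass  # stand-in for an exotic config without model_type


cfg = BareConfig()

print(getattr(cfg, "model_type", None) == "qwen")  # False -- degrades gracefully

try:
    cfg.model_type == "qwen"  # the direct access this commit reinstates
except AttributeError as e:
    print(e)  # 'BareConfig' object has no attribute 'model_type'
```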
```diff
@@ -266,14 +265,6 @@ class HFLM(TemplateLM):
             else:
                 self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
 
-        # TODO: override this for Gemma
-        self.add_bos_token = add_bos_token
-        if getattr(self.config, "model_type", None) == "gemma":
-            self.add_bos_token = True
-            eval_logger.info(
-                f"Model type is '{self.config.model_type}', a BOS token will be used as Gemma underperforms without it."
-            )
-
         self._max_length = max_length
 
         self.batch_schedule = 1
```
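Together with the previous hunk, pad-token selection now follows a three-step fallback: keep an existing pad token, else reuse EOS as PAD, else fall back to model-specific handling (the truncated `elif (` above continues in lines this diff does not show). A condensed sketch of that chain, assuming a Hugging Face tokenizer and collapsing the unseen branches:

```python
def ensure_pad_token(tokenizer, model_type: str) -> None:
    """Condensed view of HFLM's pad-token fallback (illustrative sketch)."""
    if tokenizer.pad_token:
        return  # 1) already has one
    if tokenizer.eos_token:
        # 2) reuse EOS for padding; padded positions are masked out anyway
        tokenizer.pad_token_id = tokenizer.eos_token_id
        return
    if model_type == "qwen":
        # 3a) Qwen's remote-code tokenizer rejects add_special_tokens()
        tokenizer.pad_token = "<|endoftext|>"
    else:
        # 3b) last resort: register a fresh special token
        tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
```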
```diff
@@ -666,9 +657,8 @@ class HFLM(TemplateLM):
         """ """
         if add_special_tokens is None:
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                add_special_tokens = False or self.add_bos_token
+                add_special_tokens = False
             elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-                # TODO: investigate best practices for enc-dec models + special tokens
                 add_special_tokens = True
 
         encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
```
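Since `False or flag` evaluates to `flag`, the old line was simply `add_special_tokens = self.add_bos_token` written with its default made visible; with the `add_bos_token` attribute gone, the causal branch collapses to a constant `False`, i.e. no BOS token is prepended here anymore:

```python
# the identity behind the removed expression
flag = True
assert (False or flag) is flag

# old: add_special_tokens = False or self.add_bos_token  # follows the flag
# new: add_special_tokens = False                        # BOS never added here
```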
```diff
@@ -691,7 +681,7 @@ class HFLM(TemplateLM):
         self.tokenizer.padding_side = padding_side
 
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            add_special_tokens = False or self.add_bos_token
+            add_special_tokens = False
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             add_special_tokens = True
```
```diff
@@ -770,9 +760,7 @@ class HFLM(TemplateLM):
             **generation_kwargs,
         )
 
-    def _select_cont_toks(
-        self, logits: torch.Tensor, contlen: int = None, inplen: int = None
-    ) -> torch.Tensor:
+    def _select_cont_toks(self, logits, contlen=None, inplen=None):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
             assert (
                 contlen and inplen
```
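Only the type annotations change here; either way, for the causal branch `_select_cont_toks` pulls out the logit rows that score the continuation, i.e. positions `inplen - contlen` up to `inplen` of the per-sequence logits. A self-contained sketch of that slice (shapes are illustrative):

```python
import torch


def select_cont_toks(logits: torch.Tensor, contlen: int, inplen: int) -> torch.Tensor:
    """Slice continuation-token logits out of a causal LM's output (sketch).

    logits: [seq, vocab] for one sequence; inplen is the unpadded input
    length, contlen the continuation length.
    """
    assert contlen and inplen, "causal path needs both lengths"
    # logits at position i predict token i+1, so the continuation's
    # predictions live in the window ending at inplen
    return logits[inplen - contlen : inplen]


# toy check: 10 input tokens, last 3 of which are the continuation
out = select_cont_toks(torch.randn(10, 50257), contlen=3, inplen=10)
assert out.shape == (3, 50257)
```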
```diff
@@ -790,6 +778,39 @@ class HFLM(TemplateLM):
 
         return logits
 
+    def _encode_pair(
+        self, context: str, continuation: str
+    ) -> Tuple[List[int], List[int]]:
+        n_spaces = len(context) - len(context.rstrip())
+        if n_spaces > 0:
+            continuation = context[-n_spaces:] + continuation
+            context = context[:-n_spaces]
+
+        whole_enc = self.tok_encode(context + continuation, add_special_tokens=False)
+        context_enc = self.tok_encode(context, add_special_tokens=False)
+
+        # whole_enc = self.tok_encode(context + continuation)
+        # context_enc = self.tok_encode(context, add_special_tokens=False)
+
+        context_enc_len = len(context_enc)
+        continuation_enc = whole_enc[context_enc_len:]
+
+        return context_enc, continuation_enc
+
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+        new_reqs = []
+        for context, continuation in [req.args for req in requests]:
+            if context == "":
+                # end of text as context
+                context_enc, continuation_enc = (
+                    [self.eot_token_id],
+                    self.tok_encode(continuation),
+                )
+            else:
+                context_enc, continuation_enc = self._encode_pair(context, continuation)
+
+            new_reqs.append(((context, continuation), context_enc, continuation_enc))
+
+        return self._loglikelihood_tokens(new_reqs)
+
     def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
```
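These two methods are re-added because `LM` no longer inherits them from `TemplateLM`. The subtle part is `_encode_pair` shifting trailing whitespace from the context onto the continuation: BPE tokenizers fuse a leading space into the following word, so encoding the two strings at a naive boundary would tokenize differently from the joined string. A worked sketch of the problem and the fix, using GPT-2's tokenizer for illustration:

```python
# Why trailing context whitespace is moved onto the continuation
# (illustrative; behavior below is specific to GPT-2's BPE).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

context, continuation = "Question: 2+2=? Answer: ", "4"

# naive split: the trailing space becomes its own token in the context,
# so context ids + continuation ids != ids of the joined string
joined = tok.encode(context + continuation, add_special_tokens=False)
naive = tok.encode(context, add_special_tokens=False) + tok.encode(
    continuation, add_special_tokens=False
)
print(joined == naive)  # False for GPT-2: " 4" is one token, " " + "4" is two

# _encode_pair's fix: shift the space so the continuation becomes " 4"
n_spaces = len(context) - len(context.rstrip())
context, continuation = context[:-n_spaces], context[-n_spaces:] + continuation
whole = tok.encode(context + continuation, add_special_tokens=False)
ctx = tok.encode(context, add_special_tokens=False)
assert whole[: len(ctx)] == ctx  # continuation ids are whole[len(ctx):]
```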
```diff
@@ -830,7 +851,7 @@ class HFLM(TemplateLM):
             rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
 
             string_nll = self._loglikelihood_tokens(
-                requests=rolling_token_windows,
+                rolling_token_windows,
                 disable_tqdm=True,
                 override_bs=adaptive_batch_size,
             )
```
```diff
@@ -872,7 +893,7 @@ class HFLM(TemplateLM):
         # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
         res = []
 
-        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
+        def _collate(x):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
```
```diff
@@ -881,26 +902,10 @@ class HFLM(TemplateLM):
             # automatic adaptive batches much much easier to implement
             # - any OOMs will happen right away rather than near the end
-            toks = req[1] + req[2]
+            toks = x[1] + x[2]
             return -len(toks), tuple(toks)
 
-        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
-            """Defines the key to group and lookup one-token continuations"""
-            # Use with group_by="contexts" (optional)"
-            # allows for the creation of a lookup, so we can re-use logits in case of one-token continuations.
-            # speeds up some multiple-choice tasks proportionally to the number of choices.
-            # groups requests by context+continuation[:-1] and infer on one request/group.
-            return req[-2] + req[-1][:-1]
-
-        re_ord = Collator(
-            requests,
-            sort_fn=_collate,
-            group_by="contexts"
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
-            and self.logits_cache
-            else None,
-            group_fn=_lookup_one_token_cont,
-        )
+        re_ord = Collator(requests, sort_fn=_collate)
 
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
```
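The removed `_lookup_one_token_cont`/`group_by="contexts"` machinery deduplicated forward passes for multiple-choice tasks: requests that share a context and whose continuations differ only in the final token map to the same key `req[-2] + req[-1][:-1]`, so one model call can score every choice. A sketch of the keying idea only (not the `Collator` internals; token ids are made up):

```python
# Sketch of the one-token-continuation grouping key this commit removes.
# req = ((context_str, continuation_str), context_enc, continuation_enc)
from collections import defaultdict
from typing import Dict, List, Tuple

Req = Tuple[Tuple[str, str], List[int], List[int]]


def lookup_key(req: Req) -> Tuple[int, ...]:
    # context tokens plus all-but-last continuation token: requests that
    # collide here can be scored by the same forward pass
    return tuple(req[-2] + req[-1][:-1])


# four MMLU-style choices with single-token answers (ids are illustrative)
ctx = [11, 22, 33]
reqs: List[Req] = [
    (("Q ...", " A"), ctx, [101]),
    (("Q ...", " B"), ctx, [102]),
    (("Q ...", " C"), ctx, [103]),
    (("Q ...", " D"), ctx, [104]),
]

groups: Dict[Tuple[int, ...], List[Req]] = defaultdict(list)
for r in reqs:
    groups[lookup_key(r)].append(r)

assert len(groups) == 1  # one forward pass could serve all four choices
```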
```diff
@@ -921,11 +926,7 @@ class HFLM(TemplateLM):
         )
 
         chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
-        pbar = tqdm(
-            total=len(requests), disable=(disable_tqdm or (self.rank != 0)),
-            desc="Running loglikelihood requests",
-        )
+        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
         for chunk in chunks:
             inps = []
             cont_toks_list = []
```
```diff
@@ -1025,7 +1026,7 @@ class HFLM(TemplateLM):
                 self._model_call(batched_inps, **call_kwargs), dim=-1
             )  # [batch, padding_length (inp or cont), vocab]
 
-            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
+            for (cache_key, _, _), logits, inplen, cont_toks in zip(
                 chunk, multi_logits, inplens, cont_toks_list
             ):
                 # Slice to original seq length
```
```diff
@@ -1044,36 +1045,24 @@ class HFLM(TemplateLM):
                 # Check if per-token argmax is exactly equal to continuation
                 greedy_tokens = logits.argmax(dim=-1)
 
-                # check for one-token continuation cache hits.
-                # noop in case group_by != "contexts" or no cache hit and returns the
-                # original args. Otherwise, expands the logits batch dimension and yields each
-                # batch along with matching continuation tokens and prompt strings.
-                # logits -> [1, seq, vocab]
-                for request_str, cont_toks, logits in re_ord.get_cache(
-                    req_str=request_str,
-                    cxt_toks=ctx_tokens,
-                    cont_toks=cont_toks,
-                    logits=logits,
-                ):
-                    cont_toks = torch.tensor(
-                        cont_toks, dtype=torch.long, device=self.device
-                    ).unsqueeze(0)  # [1, seq]
-                    max_equal = (greedy_tokens == cont_toks).all()
-
-                    # Obtain log-probs at the corresponding continuation token indices
-                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
-                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
-                        -1
-                    )  # [1, seq]
-
-                    # Answer: (log prob, is-exact-match)
-                    answer = (float(logits.sum()), bool(max_equal))
-
-                    res.append(answer)
-
-                    self.cache_hook.add_partial("loglikelihood", request_str, answer)
-                    pbar.update(1)
+                cont_toks = torch.tensor(
+                    cont_toks, dtype=torch.long, device=self.device
+                ).unsqueeze(0)  # [1, seq]
+                max_equal = (greedy_tokens == cont_toks).all()
+
+                # Obtain log-probs at the corresponding continuation token indices
+                # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
+                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
+                    -1
+                )  # [1, seq]
+
+                # Answer: (log prob, is-exact-match)
+                answer = (float(logits.sum()), bool(max_equal))
+
+                res.append(answer)
+
+                self.cache_hook.add_partial("loglikelihood", cache_key, answer)
+                pbar.update(1)
 
         pbar.close()
```
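Both sides of this hunk score a continuation the same way; only the cache-hit wrapper goes away. The score is the sum of the log-probs of the reference continuation tokens, picked out with `torch.gather`, plus a greedy exact-match flag. A self-contained sketch with toy shapes:

```python
# How (logprob_sum, exact_match) is computed from logits (illustrative sketch).
import torch
import torch.nn.functional as F

vocab = 100
cont_toks = torch.tensor([[7, 42, 3]])                 # [1, seq] reference continuation
logits = F.log_softmax(torch.randn(1, 3, vocab), -1)   # [1, seq, vocab] log-probs

greedy_tokens = logits.argmax(dim=-1)                  # [1, seq]
max_equal = (greedy_tokens == cont_toks).all()         # exact-match flag

# pick each reference token's log-prob: [1, seq, vocab] -> [1, seq]
picked = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)

answer = (float(picked.sum()), bool(max_equal))
print(answer)  # e.g. (-13.2, False)
```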
```diff
@@ -1082,7 +1071,7 @@ class HFLM(TemplateLM):
     def generate_until(self, requests: List[Instance]) -> List[str]:
         res = []
 
-        def _collate(req: Tuple[str, dict]):
+        def _collate(x):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
             # - time estimates will always be over not underestimates, which is more useful for planning
```
```diff
@@ -1090,14 +1079,10 @@ class HFLM(TemplateLM):
             # padded context length. this is useful to simplify the batching logic and more importantly to make
             # automatic adaptive batches much much easier to implement
             # - any OOMs will happen right away rather than near the end
-            toks = self.tok_encode(req[0])
-            return -len(toks), req[0]
+            toks = self.tok_encode(x[0])
+            return -len(toks), x[0]
 
-        pbar = tqdm(
-            total=len(requests), disable=(self.rank != 0),
-            desc="Running generate_until requests",
-        )
+        pbar = tqdm(total=len(requests), disable=(self.rank != 0))
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
```
```diff
@@ -1122,13 +1107,7 @@ class HFLM(TemplateLM):
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
         # in the same batch.
-        re_ords = Collator(
-            [reg.args for reg in requests],
-            sort_fn=_collate,
-            group_by="gen_kwargs",
-            group_fn=lambda x: x[1],
-        )
+        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
+        re_ords = Collator([reg.args for reg in requests], _collate, grouping=True)
         chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
         for chunk in chunks:
             contexts, all_gen_kwargs = zip(*chunk)
```
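Both the old and the new `Collator` call implement the same policy described by the comments: requests are grouped by their `gen_kwargs`, so a greedy request and a `temperature=0.8` request never share a batch; only the API spelling changes. A sketch of the grouping idea (not the `Collator` internals; request contents are illustrative):

```python
# Grouping generation requests by their gen_kwargs (sketch of the idea).
from itertools import groupby

requests = [
    ("ctx a", {"until": ["\n"], "do_sample": False}),
    ("ctx b", {"until": ["\n"], "do_sample": False}),
    ("ctx c", {"until": ["\n"], "temperature": 0.8}),
]


def keyfn(req):
    # a sortable view of gen_kwargs; Collator's grouping plays this role
    return sorted(req[1].items())


for key, grp in groupby(sorted(requests, key=keyfn), key=keyfn):
    print(len(list(grp)), dict(key))
# -> 2 {'do_sample': False, 'until': ['\n']}
# -> 1 {'temperature': 0.8, 'until': ['\n']}
```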
```diff
@@ -1151,12 +1130,8 @@ class HFLM(TemplateLM):
                     raise ValueError(
                         f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                     )
-            # add EOS token to stop sequences
-            eos = self.tok_decode(self.eot_token_id)
             if not until:
-                until = [eos]
-            else:
-                until.append(eos)
+                until = [self.tok_decode(self.eot_token_id)]
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
```
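This last hunk is a behavioral change worth flagging: before the commit, the EOS string was always appended to the stop list, so generation for tasks with custom `until` sequences also halted at EOS; after it, EOS is used only as a fallback when the task provides no stop sequences at all. A side-by-side sketch of the two behaviors:

```python
# Stop-sequence assembly before vs. after this commit (illustrative).
def stops_before(until, eos="<|endoftext|>"):
    until = list(until)
    if not until:
        until = [eos]
    else:
        until.append(eos)  # old behavior: EOS always stops generation
    return until


def stops_after(until, eos="<|endoftext|>"):
    return list(until) if until else [eos]  # new: EOS only as a fallback


print(stops_before(["\n\n"]))  # ['\n\n', '<|endoftext|>']
print(stops_after(["\n\n"]))   # ['\n\n']
```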