gaoqiong / lm-evaluation-harness

Commit 9cf4a104, authored Jun 20, 2023 by haileyschoelkopf
Commit message: more pre-commit
Parent: 306cfada

Showing 7 changed files with 151 additions and 114 deletions:

  lm_eval/api/task.py           +4   -2
  lm_eval/evaluator.py          +2   -4
  lm_eval/models/hf_causal.py   +3   -3
  lm_eval/models/hf_merged.py   +67  -52
  lm_eval/models/seq2seq.py     +48  -30
  lm_eval/tasks/__init__.py     +1   -1
  lm_eval/utils.py              +26  -22
lm_eval/api/task.py

@@ -101,7 +101,7 @@ class TaskConfig(dict):
             assert (
                 self.output_type == "greedy_until"
             ), "passed `generation_kwargs`, but not using a generation request type!"
         elif self.output_type == "greedy_until":
             # ensure that we greedily generate in absence of explicit arguments otherwise
             self.generation_kwargs = {"do_sample": False, "temperature": 0.0}

@@ -905,7 +905,9 @@ class ConfigurableTask(Task):
             for key, result in zip(self._metric_fn_list.keys(), results):
                 _dict = self._metric_fn_list[key].compute(
-                    references=[gold], predictions=[result], **self._metric_fn_kwargs[key]
+                    references=[gold],
+                    predictions=[result],
+                    **self._metric_fn_kwargs[key],
                 )

                 result_dict = {**result_dict, **_dict}
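A note on the greedy fallback in the TaskConfig hunk above: when a task declares output_type "greedy_until" without explicit generation_kwargs, the config defaults to {"do_sample": False, "temperature": 0.0}. The sketch below is only an illustration of what those kwargs amount to when they eventually reach Hugging Face generate(); the model name, prompt, and max_new_tokens are placeholders, not part of this commit.

# Hedged sketch: the TaskConfig fallback corresponds to plain greedy decoding.
# "gpt2" and the prompt are illustrative placeholders.
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

generation_kwargs = {"do_sample": False, "temperature": 0.0}  # same default as above
inputs = tokenizer("The capital of France is", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=8, **generation_kwargs)
print(tokenizer.decode(out[0], skip_special_tokens=True))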
lm_eval/evaluator.py

@@ -183,10 +183,8 @@ def evaluate(
     # get lists of each type of request
     for task_name, task in task_dict.items():
         versions[task_name] = task.VERSION
-        configs[task_name] = dict(
-            task.dump_config()
-        )
+        configs[task_name] = dict(task.dump_config())
         # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
         # task_docs = list(task_doc_func())
         # rnd = random.Random()
lm_eval/models/hf_causal.py

@@ -35,7 +35,7 @@ class HFCausalLM(LM):
         assert isinstance(batch_size, int)

         gpus = torch.cuda.device_count()
         if gpus <= 1:
             if device:
                 if device not in ["cuda", "cpu"]:

@@ -63,7 +63,7 @@ class HFCausalLM(LM):
             pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage
         ).to(self.device)
         self.model.eval()

         self.tokenizer = transformers.AutoTokenizer.from_pretrained(
             pretrained if tokenizer is None else tokenizer,
             revision=revision,

@@ -104,7 +104,7 @@ class HFCausalLM(LM):
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes

     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
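The hf_causal.py hunks above touch only formatting, but they show the HFCausalLM load path: AutoModelForCausalLM.from_pretrained with a revision and low_cpu_mem_usage, moved to the target device and put in eval mode, then a tokenizer loaded from the same name unless one is passed explicitly. A standalone sketch of that pattern follows; the model name and device are placeholders, and low_cpu_mem_usage=True assumes accelerate is installed.

# Hedged sketch of the load pattern above; "gpt2" and "cpu" are placeholders.
import transformers

pretrained, tokenizer_name, revision, device = "gpt2", None, "main", "cpu"

model = transformers.AutoModelForCausalLM.from_pretrained(
    pretrained, revision=revision, low_cpu_mem_usage=True
).to(device)
model.eval()

tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained if tokenizer_name is None else tokenizer_name,
    revision=revision,
)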
lm_eval/models/hf_merged.py

@@ -27,6 +27,7 @@ class HFLM(LM):
     """

     AUTO_MODEL_CLASS = None

     def __init__(
         self,
         device="cuda",

@@ -44,7 +45,7 @@ class HFLM(LM):
         assert isinstance(batch_size, int)

         gpus = torch.cuda.device_count()
         if gpus <= 1:
             if device:
                 if device not in ["cuda", "cpu"]:

@@ -68,7 +69,7 @@ class HFLM(LM):
         # TODO: update this to be less of a hack once subfolder is fixed in HF
         revision = revision + ("/" + subfolder if subfolder is not None else "")

         # get config
         self._config = transformers.AutoConfig.from_pretrained(
             pretrained,
             revision=revision,

@@ -77,9 +78,12 @@ class HFLM(LM):
         if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
             self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
         else:
             self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
-        assert self.AUTO_MODEL_CLASS in [transformers.AutoModelForCausalLM, transformers.AutoModelForSeq2SeqLM]
+        assert self.AUTO_MODEL_CLASS in [
+            transformers.AutoModelForCausalLM,
+            transformers.AutoModelForSeq2SeqLM,
+        ]

         self._model = self.AUTO_MODEL_CLASS.from_pretrained(
             pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage

@@ -127,7 +131,7 @@ class HFLM(LM):
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes

     @property
     def config(self):
         # return the associated transformers.AutoConfig for the given pretrained model.

@@ -175,20 +179,18 @@ class HFLM(LM):
         return self._world_size

     def tok_encode(self, string: str, left_truncate_len=None):
-        """
-        """
+        """ """
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
             add_special_tokens = False
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             add_special_tokens = True

         encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

         # left-truncate the encoded context to be at most `left_truncate_len` tokens long
         if left_truncate_len:
             encoding = encoding[-left_truncate_len:]

         return encoding

     def tok_decode(self, tokens):

@@ -197,23 +199,9 @@ class HFLM(LM):
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             return self.tokenizer.decode(tokens, skip_special_tokens=True)

-    def _model_call(self, inps, attn_mask=None, labels=None):
-        """
-        inps: a torch tensor of shape [batch, sequence_ctx]
-        the size of sequence may vary from call to call
-        labels: a torch tensor of shape [batch, sequence_cont]
-        the size of sequence may vary from call to call
-        returns: a torch tensor of shape [batch, sequence, vocab] with the
-        logits returned from the model
-        """
-        with torch.no_grad():
-            return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits
-
     def _model_call(self, inps, attn_mask=None, labels=None):
         """
-        inps: torch.Tensor
+        :param inps: torch.Tensor
             A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
             [batch, sequence_ctx]. the size of sequence may vary from call to call
         :param attn_mask: torch.Tensor, optional

@@ -229,7 +217,9 @@ class HFLM(LM):
         with torch.no_grad():
             if attn_mask or labels:
                 assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
-                return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits
+                return self.model(
+                    input_ids=inps, attention_mask=attn_mask, labels=labels
+                ).logits
             else:
                 assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                 return self.model(inps).logits

@@ -254,16 +244,20 @@ class HFLM(LM):
     def _select_cont_toks(self, logits, contlen=None, inplen=None):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            assert (contlen and inplen), "Must pass input len and cont. len to select scored logits for causal LM"
+            assert (
+                contlen and inplen
+            ), "Must pass input len and cont. len to select scored logits for causal LM"
             # discard right-padding.
             # also discard the input/context tokens. we'll only score continuations.
             logits = logits[inplen - contlen : inplen]
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-            assert (contlen and not inplen), "Selecting scored logits for Seq2SeqLM requires only cont. len"
+            assert (
+                contlen and not inplen
+            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
             # only discard right-padding.
             # the logits input to this fn only contain decoder-side tokens.
             logits = logits[:contlen]

         return logits

     def loglikelihood(self, requests):

@@ -289,14 +283,14 @@ class HFLM(LM):
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
                         token_list=self.tok_encode(string),
                         prefix_token=self.eot_token_id,
                         max_seq_len=self.max_length,
                         context_len=1,
                     ),
                 )
             )

-            #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
             rolling_token_windows = [(None,) + x for x in rolling_token_windows]

             pad_amnt = 0

@@ -386,11 +380,11 @@ class HFLM(LM):
                 inp = torch.tensor(
                     (context_enc)[-self.max_length :],
                     dtype=torch.long,
-                    device=self.device
+                    device=self.device,
                 )
                 (inplen,) = inp.shape

                 cont = torch.tensor(
                     (continuation_enc)[-self.max_length :],
                     # TODO: left-shift these?
                     # TODO: our code assumes we never end up truncating conts for either model type
                     dtype=torch.long,

@@ -400,24 +394,43 @@ class HFLM(LM):
                     conts.append(cont)

-                    padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen
+                    padding_len_cont = (
+                        max(padding_len_cont, contlen)
+                        if padding_len_cont is not None
+                        else contlen
+                    )

-                padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen
+                padding_len_inp = (
+                    max(padding_len_inp, inplen)
+                    if padding_len_inp is not None
+                    else inplen
+                )

                 inps.append(inp)  # [1, inp_length]
                 cont_toks_list.append(continuation_enc)
                 inplens.append(inplen)

             # create encoder attn mask and batched conts, if seq2seq
             call_kwargs = {}
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                batched_inps = utils.pad_and_concat(padding_len_inp, inps, padding_side="right")  # [batch, padding_len_inp]
+                batched_inps = utils.pad_and_concat(
+                    padding_len_inp, inps, padding_side="right"
+                )  # [batch, padding_len_inp]
             elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                 # TODO: left-pad encoder inps and mask?
-                batched_inps = utils.pad_and_concat(padding_len_inp, inps)  # [batch, padding_len_inp]
-                batched_conts = utils.pad_and_concat(padding_len_cont, conts)  # [batch, padding_len_cont]
-                batched_encoder_mask = utils.pad_and_concat(padding_len_inp, encoder_attns)  # [batch, padding_len_inp]
-                call_kwargs = {"attn_mask": batched_encoder_mask, "labels": batched_conts}
+                batched_inps = utils.pad_and_concat(
+                    padding_len_inp, inps
+                )  # [batch, padding_len_inp]
+                batched_conts = utils.pad_and_concat(
+                    padding_len_cont, conts
+                )  # [batch, padding_len_cont]
+                batched_encoder_mask = utils.pad_and_concat(
+                    padding_len_inp, encoder_attns
+                )  # [batch, padding_len_inp]
+                call_kwargs = {
+                    "attn_mask": batched_encoder_mask,
+                    "labels": batched_conts,
+                }

             multi_logits = F.log_softmax(
                 self._model_call(batched_inps, **call_kwargs), dim=-1

@@ -429,13 +442,15 @@ class HFLM(LM):
                 # Slice to original seq length
                 contlen = len(cont_toks)

                 # take only logits in the continuation
                 # (discard context toks if decoder-only ; discard right-padding)
-                ctx_len = inplen if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None
+                ctx_len = (
+                    inplen if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None
+                )
                 logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                 logits = logits.unsqueeze(0)  # [1, seq, vocab]

                 # Check if per-token argmax is exactly equal to continuation
                 greedy_tokens = logits.argmax(dim=-1)

@@ -506,8 +521,8 @@ class HFLM(LM):
             ).to(self.device)

             cont = self._model_generate(
                 context=context_enc,
                 max_length=context_enc.shape[1] + max_gen_toks,
                 stop=primary_until,
                 **gen_kwargs,
             )

@@ -519,4 +534,4 @@ class HFLM(LM):
             res.append(s)

-        return re_ord.get_original(res)
\ No newline at end of file
+        return re_ord.get_original(res)
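The reformatted _select_cont_toks above is the piece of HFLM that decides which logits get scored: for a decoder-only model the continuation sits at [inplen - contlen : inplen], just before the right-padding, while for an encoder-decoder model the logits are decoder-side only and the first contlen positions are kept. A toy re-implementation of those two slices for illustration (standalone, not imported from the repo; all shapes are made up):

# Hedged toy illustration of the two slicing branches in _select_cont_toks.
import torch

seq_len, vocab = 10, 5
logits = torch.randn(seq_len, vocab)

# decoder-only (AutoModelForCausalLM): keep the window just before the right-padding
inplen, contlen = 8, 3
causal_scores = logits[inplen - contlen : inplen]  # shape [3, 5]

# encoder-decoder (AutoModelForSeq2SeqLM): logits are decoder-side only, keep the first contlen
seq2seq_scores = logits[:contlen]  # shape [3, 5]

print(causal_scores.shape, seq2seq_scores.shape)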
lm_eval/models/seq2seq.py

@@ -19,6 +19,7 @@ from accelerate import Accelerator
 @register_model("hf-seq2seq", "seq2seq")
 class Seq2SeqHFLM(LM):
     _DEFAULT_MAX_LENGTH: int = 2048

     def __init__(
         self,
         device="cuda",

@@ -111,7 +112,8 @@ class Seq2SeqHFLM(LM):
     @property
     def max_length(self):
-        return self._DEFAULT_MAX_LENGTH #TODO: Is this a good default?
+        return self._DEFAULT_MAX_LENGTH  # TODO: Is this a good default?

     @property
     def max_gen_toks(self):
         return 256

@@ -131,14 +133,14 @@ class Seq2SeqHFLM(LM):
     @property
     def world_size(self):
         return self._world_size

     def tok_encode(self, string: str):
         return self.tokenizer.encode(string, add_special_tokens=True)

     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens, skip_special_tokens=True)

     def _model_call(self, inps, attn_mask=None, labels=None):
         """
         inps: a torch tensor of shape [batch, sequence_ctx]
         the size of sequence may vary from call to call

@@ -150,8 +152,10 @@ class Seq2SeqHFLM(LM):
         logits returned from the model
         """
         with torch.no_grad():
-            return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits
+            return self.model(
+                input_ids=inps, attention_mask=attn_mask, labels=labels
+            ).logits

     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # we require users to pass do_sample=True explicitly
         # for non-greedy gen. This should be reevaluated when considering beam search.

@@ -176,8 +180,8 @@ class Seq2SeqHFLM(LM):
             stopping_criteria=stopping_criteria,
             pad_token_id=self.eot_token_id,
             **generation_kwargs,
         )

     def loglikelihood(self, requests):
         new_reqs = []
         for context, continuation in [req.args for req in requests]:

@@ -192,7 +196,7 @@ class Seq2SeqHFLM(LM):
             new_reqs.append(((context, continuation), context_enc, continuation_enc))

         return self._loglikelihood_tokens(new_reqs)

     def loglikelihood_rolling(self, requests):
         loglikelihoods = []

         for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):

@@ -201,14 +205,14 @@ class Seq2SeqHFLM(LM):
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
                         token_list=self.tok_encode(string),
                         prefix_token=self.eot_token_id,
                         max_seq_len=self.max_length,
                         context_len=1,
                     ),
                 )
             )

-            #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
             rolling_token_windows = [(None,) + x for x in rolling_token_windows]

             pad_amnt = 0

@@ -237,7 +241,7 @@ class Seq2SeqHFLM(LM):
             loglikelihoods.append(string_nll)

         return loglikelihoods

     def _loglikelihood_tokens(self, requests, disable_tqdm=False):
         res = []

@@ -251,7 +255,7 @@ class Seq2SeqHFLM(LM):
             toks = x[1] + x[2]
             return -len(toks), tuple(toks)

         re_ord = utils.Reorderer(requests, _collate)
         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),

@@ -261,7 +265,7 @@ class Seq2SeqHFLM(LM):
             conts = []
             encoder_attns = []

             cont_toks_list = []

             max_batch_length_inp = None
             max_batch_length_cont = None

@@ -283,33 +287,48 @@ class Seq2SeqHFLM(LM):
                 ).to(self.device)
                 (contlen,) = cont.shape

-                max_batch_length_inp = max(max_batch_length_inp, inplen) if max_batch_length_inp is not None else inplen
-                max_batch_length_cont = max(max_batch_length_cont, contlen) if max_batch_length_cont is not None else contlen
+                max_batch_length_inp = (
+                    max(max_batch_length_inp, inplen)
+                    if max_batch_length_inp is not None
+                    else inplen
+                )
+                max_batch_length_cont = (
+                    max(max_batch_length_cont, contlen)
+                    if max_batch_length_cont is not None
+                    else contlen
+                )

                 inps.append(inp)  # [1, inp_len]
                 conts.append(cont)  # [1, cont_len]
                 encoder_attns.append(torch.ones_like(inp))

                 cont_toks_list.append(continuation_enc)

-            batched_inps = utils.pad_and_concat(max_batch_length_inp, inps)  # [batch, padding_length]
-            batched_conts = utils.pad_and_concat(max_batch_length_cont, conts)  # [batch, padding_length]
-            batched_encoder_mask = utils.pad_and_concat(max_batch_length_inp, encoder_attns)
+            batched_inps = utils.pad_and_concat(
+                max_batch_length_inp, inps
+            )  # [batch, padding_length]
+            batched_conts = utils.pad_and_concat(
+                max_batch_length_cont, conts
+            )  # [batch, padding_length]
+            batched_encoder_mask = utils.pad_and_concat(
+                max_batch_length_inp, encoder_attns
+            )

             # need to make attention mask here too
             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts), dim=-1
+                self._model_call(
+                    batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts
+                ),
+                dim=-1,
             ).cpu()  # [batch, padding_length, vocab]

             for (cache_key, _, _), logits, cont_toks in zip(
                 chunk, multi_logits, cont_toks_list
             ):

                 # Slice to original seq length
                 contlen = len(cont_toks)
                 logits = logits[:contlen].unsqueeze(0)  # [1, seq, vocab]

                 # Check if per-token argmax is exactly equal to continuation
                 greedy_tokens = logits.argmax(dim=-1)

@@ -329,7 +348,7 @@ class Seq2SeqHFLM(LM):
                 res.append(answer)

         return re_ord.get_original(res)

     def greedy_until(self, requests):
         res = []

@@ -370,8 +389,8 @@ class Seq2SeqHFLM(LM):
             ).to(self.device)

             cont = self._model_generate(
                 context=context_enc,
                 max_length=context_enc.shape[1] + max_gen_toks,
                 stop=primary_until,
                 **gen_kwargs,
             )

@@ -383,4 +402,3 @@ class Seq2SeqHFLM(LM):
             res.append(s)

         return re_ord.get_original(res)
-
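For the batching step reformatted in _loglikelihood_tokens above: each encoder input gets an all-ones attention mask (torch.ones_like(inp)) before padding, and the padded continuations are passed to the model as labels so the decoder scores exactly those tokens. A toy sketch of that assembly follows; pad_right is a hypothetical stand-in for utils.pad_and_concat and the token ids are made up.

# Hedged toy sketch of the seq2seq batching above; pad_right stands in for
# utils.pad_and_concat and the token ids are placeholders.
import torch

def pad_right(max_len, tensors):
    # right-pad each 1-D tensor with zeros to max_len, then stack into a batch
    return torch.stack(
        [torch.cat([t, torch.zeros(max_len - t.shape[0], dtype=torch.long)]) for t in tensors]
    )

inps = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]       # encoder inputs
conts = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])]      # continuations to score
encoder_attns = [torch.ones_like(x) for x in inps]           # mask covers real tokens only

batched_inps = pad_right(3, inps)                  # [batch, padding_len_inp]
batched_conts = pad_right(3, conts)                # [batch, padding_len_cont], used as labels
batched_encoder_mask = pad_right(3, encoder_attns)

print(batched_inps.shape, batched_conts.shape, batched_encoder_mask.shape)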
lm_eval/tasks/__init__.py

@@ -22,7 +22,7 @@ def include_task_folder(task_dir):
     Calling this function
     """
     for root, subdirs, file_list in os.walk(task_dir):
-        if (len(file_list) > 0):
+        if len(file_list) > 0:
             for f in file_list:
                 if f.endswith(".yaml"):
                     yaml_path = os.path.join(root, f)
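The change above only drops redundant parentheses; the surrounding loop is the task registry's YAML discovery. A minimal standalone sketch of that walk, with the directory path as a placeholder:

# Hedged sketch of the YAML discovery loop in include_task_folder; the path is a placeholder.
import os

task_dir = "./lm_eval/tasks"
yaml_paths = []
for root, subdirs, file_list in os.walk(task_dir):
    if len(file_list) > 0:
        for f in file_list:
            if f.endswith(".yaml"):
                yaml_paths.append(os.path.join(root, f))
print(yaml_paths)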
lm_eval/utils.py

@@ -18,11 +18,12 @@ import torch
 from omegaconf import OmegaConf
 from jinja2 import BaseLoader, Environment, StrictUndefined
 from itertools import islice

+import torch
 import transformers

 from lm_eval.logger import eval_logger


 class ExitCodeError(Exception):
     pass

@@ -416,13 +417,16 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     """
     return islice(raw_iterator, rank, limit, world_size)


-def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="right"):
+def pad_and_concat(
+    max_length: int, tensors: List[torch.Tensor], padding_side="right"
+):
     """
     Method for padding a list of tensors given the maximum tensor
     length in the batch. Used for batching inputs and continuations in
     seq2seq models.
     """
-    assert padding_side == "left" or padding_side == "right", f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+    assert (
+        padding_side == "left" or padding_side == "right"
+    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

     for i, tensor in enumerate(tensors):
         tensor_len = tensor.shape[0]

@@ -430,33 +434,33 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
             if padding_side == "right":
                 # right-pad
                 tensors[i] = torch.cat(
                     [
                         tensor,  # [seq]
                         torch.zeros(
                             max_length - tensor_len,
                             dtype=torch.long,
                             device=tensor.device,
                         ),  # [padding_length - seq]
                     ],
                     dim=0,
                 ).unsqueeze(0)
             else:
                 # left-pad
                 tensors[i] = torch.cat(
                     [
                         torch.zeros(
                             max_length - tensor_len,
                             dtype=torch.long,
                             device=tensor.device,
                         ),  # [padding_length - seq]
                         tensor,  # [seq]
                     ],
                     dim=0,
                 ).unsqueeze(0)
         else:
             tensors[i] = tensor.unsqueeze(0)

     return torch.cat(tensors, dim=0)


 def clear_torch_cache():

@@ -464,7 +468,7 @@ def clear_torch_cache():
     torch.cuda.empty_cache()


 # Multi-token stopping criteria
 class MultiTokenEOSCriteria(transformers.StoppingCriteria):
     """Criteria to stop on the specified multi-token sequence."""

@@ -511,4 +515,4 @@ def stop_sequences_criteria(
                 for sequence in stop_sequences
             ],
         ]
-    )
\ No newline at end of file
+    )
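Assuming this branch of the repository is installed, the pad_and_concat helper reformatted above can be exercised directly; a small usage sketch of the right- vs left-padding behaviour:

# Hedged usage sketch; assumes lm_eval is installed so the helper defined above is importable.
import torch
from lm_eval.utils import pad_and_concat

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])

right = pad_and_concat(4, [a, b])                      # default padding_side="right"
left = pad_and_concat(4, [a, b], padding_side="left")

print(right)  # rows: [1, 2, 3, 0] and [4, 5, 0, 0]
print(left)   # rows: [0, 1, 2, 3] and [0, 0, 4, 5]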