gaoqiong / lm-evaluation-harness · Commits

Commit b48f5205
authored Jun 20, 2023 by haileyschoelkopf
committed Jun 22, 2023 by lintangsutawika

more pre-commit

parent 86b71954
Showing 6 changed files with 160 additions and 103 deletions (+160 -103)

lm_eval/api/task.py           +4   -2
lm_eval/evaluator.py          +2   -4
lm_eval/models/hf_merged.py   +67  -52
lm_eval/models/seq2seq.py     +48  -30
lm_eval/tasks/__init__.py     +1   -1
lm_eval/utils.py              +38  -14
lm_eval/api/task.py

@@ -102,7 +102,7 @@ class TaskConfig(dict):
            assert (
                self.output_type == "greedy_until"
            ), "passed `generation_kwargs`, but not using a generation request type!"
        elif self.output_type == "greedy_until":
            # ensure that we greedily generate in absence of explicit arguments otherwise
            self.generation_kwargs = {"do_sample": False, "temperature": 0.0}

@@ -883,7 +883,9 @@ class ConfigurableTask(Task):
            for key, result in zip(self._metric_fn_list.keys(), results):
                _dict = self._metric_fn_list[key].compute(
-                   references=[gold], predictions=[result], **self._metric_fn_kwargs[key]
+                   references=[gold],
+                   predictions=[result],
+                   **self._metric_fn_kwargs[key],
                )

                result_dict = {**result_dict, **_dict}
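For readers following the TaskConfig change above: a minimal standalone sketch of the defaulting rule shown in the first hunk. `resolve_generation_kwargs` is a hypothetical helper, not part of the commit; it only illustrates that a generation-type task with no explicit `generation_kwargs` falls back to greedy decoding.

    from typing import Optional


    def resolve_generation_kwargs(output_type: str, generation_kwargs: Optional[dict]) -> dict:
        """Sketch of the fallback rule above; not the library's actual API."""
        if generation_kwargs is not None:
            assert (
                output_type == "greedy_until"
            ), "passed `generation_kwargs`, but not using a generation request type!"
            return generation_kwargs
        if output_type == "greedy_until":
            # greedily generate in the absence of explicit arguments
            return {"do_sample": False, "temperature": 0.0}
        return {}


    print(resolve_generation_kwargs("greedy_until", None))
    # -> {'do_sample': False, 'temperature': 0.0}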
lm_eval/evaluator.py

@@ -183,10 +183,8 @@ def evaluate(
    # get lists of each type of request
    for task_name, task in task_dict.items():
        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

        # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
        # task_docs = list(task_doc_func())
        # rnd = random.Random()
lm_eval/models/hf_merged.py

@@ -27,6 +27,7 @@ class HFLM(LM):
    """

    AUTO_MODEL_CLASS = None

    def __init__(
        self,
        device="cuda",

@@ -44,7 +45,7 @@ class HFLM(LM):
            assert isinstance(batch_size, int)

        gpus = torch.cuda.device_count()
        if gpus <= 1:
            if device:
                if device not in ["cuda", "cpu"]:

@@ -68,7 +69,7 @@ class HFLM(LM):
        # TODO: update this to be less of a hack once subfolder is fixed in HF
        revision = revision + ("/" + subfolder if subfolder is not None else "")

        # get config
        self._config = transformers.AutoConfig.from_pretrained(
            pretrained,
            revision=revision,

@@ -77,9 +78,12 @@ class HFLM(LM):
        if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
        else:
            self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM

-       assert self.AUTO_MODEL_CLASS in [transformers.AutoModelForCausalLM, transformers.AutoModelForSeq2SeqLM]
+       assert self.AUTO_MODEL_CLASS in [
+           transformers.AutoModelForCausalLM,
+           transformers.AutoModelForSeq2SeqLM,
+       ]

        self._model = self.AUTO_MODEL_CLASS.from_pretrained(
            pretrained, revision=revision, low_cpu_mem_usage=low_cpu_mem_usage

@@ -127,7 +131,7 @@ class HFLM(LM):
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.

@@ -175,20 +179,18 @@ class HFLM(LM):
        return self._world_size

    def tok_encode(self, string: str, left_truncate_len=None):
        """ """
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            add_special_tokens = False
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            add_special_tokens = True

        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
        if left_truncate_len:
            encoding = encoding[-left_truncate_len:]

        return encoding

    def tok_decode(self, tokens):

@@ -197,23 +199,9 @@ class HFLM(LM):
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            return self.tokenizer.decode(tokens, skip_special_tokens=True)

-   def _model_call(self, inps, attn_mask=None, labels=None):
-       """
-       inps: a torch tensor of shape [batch, sequence_ctx]
-       the size of sequence may vary from call to call
-       labels: a torch tensor of shape [batch, sequence_cont]
-       the size of sequence may vary from call to call
-       returns: a torch tensor of shape [batch, sequence, vocab] with the
-       logits returned from the model
-       """
-       with torch.no_grad():
-           return self.model(
-               input_ids=inps, attention_mask=attn_mask, labels=labels
-           ).logits
-
    def _model_call(self, inps, attn_mask=None, labels=None):
        """
-       inps: torch.Tensor
+       :param inps: torch.Tensor
            A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
            [batch, sequence_ctx]. the size of sequence may vary from call to call
        :param attn_mask: torch.Tensor, optional

@@ -229,7 +217,9 @@ class HFLM(LM):
        with torch.no_grad():
            if attn_mask or labels:
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
                return self.model(
                    input_ids=inps, attention_mask=attn_mask, labels=labels
                ).logits
            else:
                assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                return self.model(inps).logits

@@ -254,16 +244,20 @@ class HFLM(LM):
    def _select_cont_toks(self, logits, contlen=None, inplen=None):
        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
            assert (
                contlen and inplen
            ), "Must pass input len and cont. len to select scored logits for causal LM"
            # discard right-padding.
            # also discard the input/context tokens. we'll only score continuations.
            logits = logits[inplen - contlen : inplen]
        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
            assert (
                contlen and not inplen
            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
            # only discard right-padding.
            # the logits input to this fn only contain decoder-side tokens.
            logits = logits[:contlen]

        return logits

    def loglikelihood(self, requests):

@@ -289,14 +283,14 @@ class HFLM(LM):
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

-           #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+           # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            pad_amnt = 0

@@ -386,11 +380,11 @@ class HFLM(LM):
                inp = torch.tensor(
                    (context_enc)[-self.max_length :],
                    dtype=torch.long,
-                   device=self.device
+                   device=self.device,
                )
                (inplen,) = inp.shape

                cont = torch.tensor(
                    (continuation_enc)[-self.max_length :],
                    # TODO: left-shift these?
                    # TODO: our code assumes we never end up truncating conts for either model type
                    dtype=torch.long,

@@ -400,24 +394,43 @@ class HFLM(LM):
                    conts.append(cont)

                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )

                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
                batched_inps = utils.pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
                batched_conts = utils.pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
                batched_encoder_mask = utils.pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            multi_logits = F.log_softmax(
                self._model_call(batched_inps, **call_kwargs), dim=-1

@@ -429,13 +442,15 @@ class HFLM(LM):
                # Slice to original seq length
                contlen = len(cont_toks)

                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
                ctx_len = (
                    inplen
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                    else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)

@@ -506,8 +521,8 @@ class HFLM(LM):
                ).to(self.device)

                cont = self._model_generate(
                    context=context_enc,
                    max_length=context_enc.shape[1] + max_gen_toks,
                    stop=primary_until,
                    **gen_kwargs,
                )

@@ -519,4 +534,4 @@ class HFLM(LM):
                res.append(s)

        return re_ord.get_original(res)
\ No newline at end of file
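The `_select_cont_toks` logic above is the heart of how the merged class scores continuations for both architectures: causal models return logits over context plus continuation (right-padded), while seq2seq models return decoder-side logits only. A standalone toy sketch, with made-up shapes and names rather than the HFLM class itself:

    import torch


    def select_cont_toks(logits: torch.Tensor, contlen: int, inplen=None) -> torch.Tensor:
        if inplen is not None:
            # causal LM: logits cover context + continuation (right-padded);
            # keep only the positions that score the continuation tokens
            return logits[inplen - contlen : inplen]
        # seq2seq LM: logits are decoder-side only, so just drop right-padding
        return logits[:contlen]


    vocab = 10
    causal_logits = torch.randn(12, vocab)   # 12 right-padded positions for a 7-token input
    seq2seq_logits = torch.randn(6, vocab)   # 6 right-padded decoder positions

    print(select_cont_toks(causal_logits, contlen=3, inplen=7).shape)  # torch.Size([3, 10])
    print(select_cont_toks(seq2seq_logits, contlen=4).shape)           # torch.Size([4, 10])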
lm_eval/models/seq2seq.py

@@ -19,6 +19,7 @@ from accelerate import Accelerator
@register_model("hf-seq2seq", "seq2seq")
class Seq2SeqHFLM(LM):
    _DEFAULT_MAX_LENGTH: int = 2048

    def __init__(
        self,
        device="cuda",

@@ -111,7 +112,8 @@ class Seq2SeqHFLM(LM):
    @property
    def max_length(self):
-       return self._DEFAULT_MAX_LENGTH #TODO: Is this a good default?
+       return self._DEFAULT_MAX_LENGTH  # TODO: Is this a good default?

    @property
    def max_gen_toks(self):
        return 256

@@ -131,14 +133,14 @@ class Seq2SeqHFLM(LM):
    @property
    def world_size(self):
        return self._world_size

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=True)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens, skip_special_tokens=True)

    def _model_call(self, inps, attn_mask=None, labels=None):
        """
        inps: a torch tensor of shape [batch, sequence_ctx]
        the size of sequence may vary from call to call

@@ -150,8 +152,10 @@ class Seq2SeqHFLM(LM):
            logits returned from the model
        """
        with torch.no_grad():
            return self.model(
                input_ids=inps, attention_mask=attn_mask, labels=labels
            ).logits

    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        # we require users to pass do_sample=True explicitly
        # for non-greedy gen. This should be reevaluated when considering beam search.

@@ -176,8 +180,8 @@ class Seq2SeqHFLM(LM):
            stopping_criteria=stopping_criteria,
            pad_token_id=self.eot_token_id,
            **generation_kwargs,
        )

    def loglikelihood(self, requests):
        new_reqs = []
        for context, continuation in [req.args for req in requests]:

@@ -192,7 +196,7 @@ class Seq2SeqHFLM(LM):
            new_reqs.append(((context, continuation), context_enc, continuation_enc))

        return self._loglikelihood_tokens(new_reqs)

    def loglikelihood_rolling(self, requests):
        loglikelihoods = []
        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):

@@ -201,14 +205,14 @@ class Seq2SeqHFLM(LM):
                    utils.make_disjoint_window,
                    utils.get_rolling_token_windows(
                        token_list=self.tok_encode(string),
                        prefix_token=self.eot_token_id,
                        max_seq_len=self.max_length,
                        context_len=1,
                    ),
                )
            )

-           #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
+           # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
            rolling_token_windows = [(None,) + x for x in rolling_token_windows]

            pad_amnt = 0

@@ -237,7 +241,7 @@ class Seq2SeqHFLM(LM):
            loglikelihoods.append(string_nll)

        return loglikelihoods

    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
        res = []

@@ -251,7 +255,7 @@ class Seq2SeqHFLM(LM):
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        re_ord = utils.Reorderer(requests, _collate)
        for chunk in utils.chunks(
            tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),

@@ -261,7 +265,7 @@ class Seq2SeqHFLM(LM):
            conts = []
            encoder_attns = []

            cont_toks_list = []
            max_batch_length_inp = None
            max_batch_length_cont = None

@@ -283,33 +287,48 @@ class Seq2SeqHFLM(LM):
                ).to(self.device)
                (contlen,) = cont.shape

                max_batch_length_inp = (
                    max(max_batch_length_inp, inplen)
                    if max_batch_length_inp is not None
                    else inplen
                )
                max_batch_length_cont = (
                    max(max_batch_length_cont, contlen)
                    if max_batch_length_cont is not None
                    else contlen
                )

                inps.append(inp)  # [1, inp_len]
                conts.append(cont)  # [1, cont_len]
                encoder_attns.append(torch.ones_like(inp))

                cont_toks_list.append(continuation_enc)

            batched_inps = utils.pad_and_concat(
                max_batch_length_inp, inps
            )  # [batch, padding_length]
            batched_conts = utils.pad_and_concat(
                max_batch_length_cont, conts
            )  # [batch, padding_length]
            batched_encoder_mask = utils.pad_and_concat(
                max_batch_length_inp, encoder_attns
            )

            # need to make attention mask here too
            multi_logits = F.log_softmax(
                self._model_call(
                    batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts
                ),
                dim=-1,
            ).cpu()  # [batch, padding_length, vocab]

            for (cache_key, _, _), logits, cont_toks in zip(
                chunk, multi_logits, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                logits = logits[:contlen].unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)

@@ -329,7 +348,7 @@ class Seq2SeqHFLM(LM):
                res.append(answer)

        return re_ord.get_original(res)

    def greedy_until(self, requests):
        res = []

@@ -370,8 +389,8 @@ class Seq2SeqHFLM(LM):
                ).to(self.device)

                cont = self._model_generate(
                    context=context_enc,
                    max_length=context_enc.shape[1] + max_gen_toks,
                    stop=primary_until,
                    **gen_kwargs,
                )

@@ -383,4 +402,3 @@ class Seq2SeqHFLM(LM):
                res.append(s)

        return re_ord.get_original(res)
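Behind the batching code above, the scoring pattern for an encoder-decoder model is: the encoder receives the context plus an attention mask, the continuation is passed as `labels`, and the returned logits are decoder-side and aligned with the label tokens. A minimal sketch of that pattern with plain transformers, assuming the `t5-small` checkpoint (the checkpoint choice and prompt are illustrative, not part of this commit):

    import torch
    import torch.nn.functional as F
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

    context = tokenizer("translate English to German: Hello.", return_tensors="pt")
    labels = tokenizer("Hallo.", return_tensors="pt").input_ids

    with torch.no_grad():
        logits = model(
            input_ids=context.input_ids,
            attention_mask=context.attention_mask,
            labels=labels,
        ).logits  # [batch, labels_len, vocab]

    # per-token log-likelihood of the continuation, mirroring _loglikelihood_tokens
    logprobs = F.log_softmax(logits, dim=-1)
    token_lls = torch.gather(logprobs, 2, labels.unsqueeze(-1)).squeeze(-1)
    print(token_lls.sum().item())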
lm_eval/tasks/__init__.py

@@ -22,7 +22,7 @@ def include_task_folder(task_dir):
    Calling this function
    """
    for root, subdirs, file_list in os.walk(task_dir):
-       if (len(file_list) > 0) :
+       if len(file_list) > 0:
            for f in file_list:
                if f.endswith(".yaml"):
                    yaml_path = os.path.join(root, f)
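The hunk above only touches style, but it sits inside the directory walk that discovers YAML task configs. A self-contained sketch of that discovery step, with a hypothetical `task_dir` argument:

    import os


    def find_task_yamls(task_dir: str):
        """Collect every .yaml file under the task folder, as include_task_folder walks it."""
        yaml_paths = []
        for root, subdirs, file_list in os.walk(task_dir):
            if len(file_list) > 0:
                for f in file_list:
                    if f.endswith(".yaml"):
                        yaml_paths.append(os.path.join(root, f))
        return yaml_paths


    print(find_task_yamls("lm_eval/tasks"))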
lm_eval/utils.py

@@ -19,6 +19,11 @@ import transformers
from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice

+<<<<<<< HEAD
+=======
+import transformers
+
+>>>>>>> more pre-commit
from lm_eval.logger import eval_logger

@@ -417,6 +422,7 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
    return islice(raw_iterator, rank, limit, world_size)

+<<<<<<< HEAD
def clear_torch_cache():
    gc.collect()
    torch.cuda.empty_cache()

@@ -437,8 +443,17 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
    Method for padding a list of tensors given the maximum tensor
    length in the batch. Used for batching inputs and continuations in
    seq2seq models.
+=======
+def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
+    """
+    Method for padding a list of tensors given the maximum tensor
+    length in the batch. Used for batching inputs and continuations in
+    seq2seq models.
+>>>>>>> more pre-commit
    """
-   assert padding_side == "left" or padding_side == "right", f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+   assert (
+       padding_side == "left" or padding_side == "right"
+   ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

    for i, tensor in enumerate(tensors):
        tensor_len = tensor.shape[0]

@@ -446,36 +461,45 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
            if padding_side == "right":
                # right-pad
                tensors[i] = torch.cat(
                    [
                        tensor,  # [seq]
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [padding_length - seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
            else:
                # left-pad
                tensors[i] = torch.cat(
                    [
                        torch.zeros(
                            max_length - tensor_len,
                            dtype=torch.long,
                            device=tensor.device,
                        ),  # [padding_length - seq]
                        tensor,
                        # [seq]
                    ],
                    dim=0,
                ).unsqueeze(0)
        else:
            tensors[i] = tensor.unsqueeze(0)

    return torch.cat(tensors, dim=0)


+<<<<<<< HEAD
# Multi-token stopping criteria
+=======
+def clear_torch_cache():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
+# Multi-token stopping criteria
+>>>>>>> more pre-commit
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""
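Since `pad_and_concat` does most of the batching work in the two model files above, here is a toy demonstration of its behavior: a standalone re-implementation sketch matching the padding logic shown in this hunk, applied to a small ragged batch.

    from typing import List

    import torch


    def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
        """Sketch: pad 1-D token tensors to max_length and stack into [batch, max_length]."""
        assert padding_side in ("left", "right")
        out = []
        for tensor in tensors:
            pad = torch.zeros(
                max_length - tensor.shape[0], dtype=torch.long, device=tensor.device
            )
            pieces = [tensor, pad] if padding_side == "right" else [pad, tensor]
            out.append(torch.cat(pieces, dim=0).unsqueeze(0))
        return torch.cat(out, dim=0)


    batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    print(pad_and_concat(3, batch, padding_side="right"))
    # tensor([[1, 2, 3],
    #         [4, 5, 0]])
    print(pad_and_concat(3, batch, padding_side="left"))
    # tensor([[1, 2, 3],
    #         [0, 4, 5]])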