gaoqiong / lm-evaluation-harness · Commits · 0b4f88dd
"host/online_compile/hip_utility/hipoc_kernel.cpp" did not exist on "1685048a6725e531b577510295d2d62664c15962"
Commit 0b4f88dd authored Jun 20, 2023 by haileyschoelkopf, committed Jun 22, 2023 by lintangsutawika

make seq2seq take correct args format

Parent: 1a6b31a8
Changes: 1
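With this change, a greedy_until request no longer carries a bare stop string: the second element of req.args is a gen_kwargs dict, from which "until" and "max_gen_toks" are popped before the remaining options are forwarded to generation, and _model_generate is called with keyword arguments. A hypothetical example of the new args format (the prompt text and values below are made up for illustration; the keys are the ones handled in the diff):

request_args = (
    "Question: 2 + 2 = ?\nAnswer:",   # context string
    {
        "until": ["\n"],              # stop sequence(s): str or list of str
        "max_gen_toks": 32,           # per-request generation budget
        "do_sample": False,           # any remaining keys pass through to generate()
    },
)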
Showing 1 changed file with 63 additions and 65 deletions

lm_eval/models/seq2seq.py  +63 -65
lm_eval/models/seq2seq.py @ 0b4f88dd
 import torch
 import transformers
+import copy
 from tqdm import tqdm
 import torch.nn.functional as F
@@ -10,8 +11,9 @@ from lm_eval.logger import eval_logger
 from lm_eval.api.registry import register_model
 from lm_eval.api.model import LM
+from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 from accelerate import Accelerator
 from typing import List


 @register_model("hf-seq2seq", "seq2seq")
@@ -83,6 +85,14 @@ class Seq2SeqHFLM(LM):
             print(warning)
             self._rank = accelerator.local_process_index
             self._world_size = accelerator.num_processes
+            # manually set model to use gpu, for case where many GPUs available but
+            # only seek to use one
+            self._device = (
+                torch.device(f"cuda:{accelerator.local_process_index}")
+                if torch.cuda.is_available()
+                else torch.device("cpu")
+            )
+            self.model.to(self.device)
         else:
             self.model = accelerator.prepare(self.model)
             self._device = torch.device(f"cuda:{accelerator.local_process_index}")
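The added branch pins the model to a single device by hand instead of going through accelerator.prepare, which is only wanted when several GPUs are visible but only one process is in use. A standalone sketch of that selection logic, assuming an Accelerator has been constructed and the checkpoint name is a placeholder:

from accelerate import Accelerator
from transformers import AutoModelForSeq2SeqLM
import torch

accelerator = Accelerator()
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # placeholder checkpoint

# single-process case: place the model on this process's GPU, or fall back to CPU
device = (
    torch.device(f"cuda:{accelerator.local_process_index}")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
model.to(device)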
@@ -142,18 +152,30 @@ class Seq2SeqHFLM(LM):
         with torch.no_grad():
             return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits

-    def _model_generate(self, context, max_length, stop):
+    def _model_generate(self, context, max_length, stop, **generation_kwargs):
+        # we require users to pass do_sample=True explicitly
+        # for non-greedy gen. This should be reevaluated when considering beam search.
+        if "do_sample" not in generation_kwargs.keys():
+            generation_kwargs["do_sample"] = False
+
+        # build stopping criteria
         stopping_criteria = stop_sequences_criteria(
             self.tokenizer, stop, 1, context.shape[0]
         )
+        if hasattr(self, "accelerator"):
+            return self.accelerator.unwrap_model(self.model).generate(
+                context,
+                max_new_tokens=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.eot_token_id,
+                **generation_kwargs,
+            )
+        else:
             return self.model.generate(
                 context,
                 max_new_tokens=max_length,
                 stopping_criteria=stopping_criteria,
-                do_sample=False,
+                pad_token_id=self.eot_token_id,
+                **generation_kwargs,
             )

     def loglikelihood(self, requests):
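With the reworked signature, the stop sequences and any extra generate() options travel by keyword, and sampling stays off unless do_sample=True is passed explicitly. A hedged usage sketch; `lm` stands for an already-constructed Seq2SeqHFLM instance and `context_enc` for a batch of encoded prompt tokens, both assumptions of this example:

# greedy decoding: _model_generate fills in do_sample=False when it is absent
out = lm._model_generate(
    context=context_enc,
    max_length=context_enc.shape[1] + 32,
    stop=["\n"],
)

# non-greedy decoding must be requested explicitly
out = lm._model_generate(
    context=context_enc,
    max_length=context_enc.shape[1] + 32,
    stop=["\n"],
    do_sample=True,
    temperature=0.7,
)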
@@ -173,7 +195,7 @@ class Seq2SeqHFLM(LM):
     def loglikelihood_rolling(self, requests):
         loglikelihoods = []

-        for (string,) in tqdm([req.args for req in requests]):
+        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
             rolling_token_windows = list(
                 map(
                     utils.make_disjoint_window,
@@ -317,9 +339,30 @@ class Seq2SeqHFLM(LM):
         re_ord = utils.Reorderer([req.args for req in requests], _collate)

-        for context, until in tqdm(re_ord.get_reordered()):
-            if isinstance(until, str):
-                until = [until]
+        for context, gen_kwargs in tqdm(re_ord.get_reordered()):
+            until = None
+            if isinstance(gen_kwargs, dict):
+                gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                print(gen_kwargs)
+                if "until" in gen_kwargs.keys():
+                    until = gen_kwargs.pop("until")
+                    if isinstance(until, str):
+                        until = [gen_kwargs]
+                    elif not isinstance(until, list):
+                        raise ValueError(
+                            f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
+                        )
+            else:
+                raise ValueError(
+                    f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}"
+                )
+            if not until:
+                until = [self.tok_decode(self.eot_token_id)]
+            if "max_gen_toks" in gen_kwargs.keys():
+                max_gen_toks = gen_kwargs.pop("max_gen_toks")
+            else:
+                max_gen_toks = self.max_gen_toks

             (primary_until) = until[0]

             context_enc = torch.tensor(
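The loop above splits each request's gen_kwargs into the stop sequences, the generation budget, and the leftover options that get forwarded to _model_generate. A simplified, standalone sketch of that normalization; eot_text and default_max_gen_toks stand in for self.tok_decode(self.eot_token_id) and self.max_gen_toks:

def split_gen_kwargs(gen_kwargs: dict, eot_text: str, default_max_gen_toks: int):
    if not isinstance(gen_kwargs, dict):
        raise ValueError(f"Expected `gen_kwargs` to be of type `dict` but got {gen_kwargs}")
    gen_kwargs = dict(gen_kwargs)  # copy so the caller's dict is not mutated
    until = gen_kwargs.pop("until", None)
    if isinstance(until, str):
        until = [until]  # a bare string becomes a one-element list of stop sequences
    elif until is not None and not isinstance(until, list):
        raise ValueError(
            f"Expected `gen_kwargs['until']` to be of type Union[str,list] but got {until}"
        )
    if not until:
        until = [eot_text]  # fall back to the EOT token as the only stop sequence
    max_gen_toks = gen_kwargs.pop("max_gen_toks", default_max_gen_toks)
    return until, max_gen_toks, gen_kwargs

# e.g. split_gen_kwargs({"until": "\n", "max_gen_toks": 64}, "</s>", 256)
# -> (["\n"], 64, {})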
@@ -327,62 +370,17 @@ class Seq2SeqHFLM(LM):
             ).to(self.device)

             cont = self._model_generate(
-                context_enc,
-                context_enc.shape[1] + self.max_gen_toks,
-                primary_until
+                context=context_enc,
+                max_length=context_enc.shape[1] + max_gen_toks,
+                stop=primary_until,
+                **gen_kwargs,
             )
             s = self.tok_decode(cont[0].tolist())
+            print(s)

             for term in until:
                 s = s.split(term)[0]
+            print(s)

             res.append(s)

         return re_ord.get_original(res)
-
-
-class MultiTokenEOSCriteria(transformers.StoppingCriteria):
-    """Criteria to stop on the specified multi-token sequence."""
-
-    def __init__(
-        self,
-        sequence: str,
-        tokenizer: transformers.PreTrainedTokenizer,
-        initial_decoder_input_length: int,
-        batch_size: int,
-    ):
-        self.initial_decoder_input_length = initial_decoder_input_length
-        self.done_tracker = [False] * batch_size
-        self.sequence = sequence
-        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
-        self.sequence_id_len = len(self.sequence_ids)
-        self.tokenizer = tokenizer
-
-    def __call__(self, input_ids, scores, **kwargs) -> bool:
-        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
-        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
-            :, -self.sequence_id_len :
-        ]
-        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
-        for i, done in enumerate(self.done_tracker):
-            if not done:
-                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
-        return False not in self.done_tracker
-
-
-def stop_sequences_criteria(
-    tokenizer: transformers.PreTrainedTokenizer,
-    stop_sequences: List[str],
-    initial_decoder_input_length: int,
-    batch_size: int,
-) -> transformers.StoppingCriteriaList:
-    return transformers.StoppingCriteriaList(
-        [
-            *[
-                MultiTokenEOSCriteria(
-                    sequence, tokenizer, initial_decoder_input_length, batch_size
-                )
-                for sequence in stop_sequences
-            ],
-        ]
-    )
\ No newline at end of file
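The class and helper removed here are the same stopping-criteria machinery that the new `from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria` line pulls in instead, so generation still halts once every sequence in the batch has produced one of the stop strings. A small usage sketch of how the helper plugs into Hugging Face generate(), with a placeholder checkpoint and assuming that import path:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from lm_eval.utils import stop_sequences_criteria  # assumed new home of the helper

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder checkpoint
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

batch = tokenizer(["translate English to German: Hello"], return_tensors="pt")
criteria = stop_sequences_criteria(
    tokenizer,
    ["\n"],                        # stop once a newline has been decoded
    1,                             # initial_decoder_input_length for a fresh decoder
    batch["input_ids"].shape[0],   # batch_size
)
out = model.generate(
    **batch,
    max_new_tokens=32,
    stopping_criteria=criteria,
)
print(tokenizer.batch_decode(out, skip_special_tokens=True))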