gaoqiong / lm-evaluation-harness · Commit b48f5205

more pre-commit

Authored Jun 20, 2023 by haileyschoelkopf; committed Jun 22, 2023 by lintangsutawika.
Parent: 86b71954

Showing 6 changed files with 160 additions and 103 deletions (+160 / -103).
lm_eval/api/task.py          +4   -2
lm_eval/evaluator.py         +2   -4
lm_eval/models/hf_merged.py  +67  -52
lm_eval/models/seq2seq.py    +48  -30
lm_eval/tasks/__init__.py    +1   -1
lm_eval/utils.py             +38  -14
lm_eval/api/task.py

@@ -102,7 +102,7 @@ class TaskConfig(dict):
             assert (
                 self.output_type == "greedy_until"
             ), "passed `generation_kwargs`, but not using a generation request type!"
-        elif self.output_type == "greedy_until":
+        elif self.output_type == "greedy_until":
             # ensure that we greedily generate in absence of explicit arguments otherwise
             self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
...
@@ -883,7 +883,9 @@ class ConfigurableTask(Task):
             for key, result in zip(self._metric_fn_list.keys(), results):
                 _dict = self._metric_fn_list[key].compute(
-                    references=[gold], predictions=[result], **self._metric_fn_kwargs[key]
+                    references=[gold],
+                    predictions=[result],
+                    **self._metric_fn_kwargs[key],
                 )

                 result_dict = {**result_dict, **_dict}
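A note on the first hunk above: when a task's `output_type` is `greedy_until` and no explicit `generation_kwargs` are given, `TaskConfig` falls back to `{"do_sample": False, "temperature": 0.0}`, i.e. plain greedy decoding. The sketch below is illustrative only and is not the harness's actual generation path; it shows what that default amounts to when forwarded to a Hugging Face model, where `do_sample=False` makes the temperature value a no-op.

# Illustrative sketch only -- not lm-evaluation-harness's generation code path.
# TaskConfig's fallback {"do_sample": False, "temperature": 0.0} is equivalent
# to greedy decoding; with do_sample=False the temperature setting has no effect.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The capital of France is", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=8, do_sample=False)  # greedy decoding
print(tok.decode(out[0], skip_special_tokens=True))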
lm_eval/evaluator.py

@@ -183,10 +183,8 @@ def evaluate(
     # get lists of each type of request
     for task_name, task in task_dict.items():
         versions[task_name] = task.VERSION
-        configs[task_name] = dict(
-            task.dump_config()
-        )
+        configs[task_name] = dict(task.dump_config())

         # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
         # task_docs = list(task_doc_func())
         # rnd = random.Random()
lm_eval/models/hf_merged.py

@@ -27,6 +27,7 @@ class HFLM(LM):
     """

     AUTO_MODEL_CLASS = None
+
     def __init__(
         self,
         device="cuda",
...
@@ -44,7 +45,7 @@ class HFLM(LM):
         assert isinstance(batch_size, int)

         gpus = torch.cuda.device_count()
         if gpus <= 1:
             if device:
                 if device not in ["cuda", "cpu"]:
...
@@ -68,7 +69,7 @@ class HFLM(LM):
             # TODO: update this to be less of a hack once subfolder is fixed in HF
             revision = revision + ("/" + subfolder if subfolder is not None else "")

-            # get config
+            # get config
             self._config = transformers.AutoConfig.from_pretrained(
                 pretrained,
                 revision=revision,
...
@@ -77,9 +78,12 @@ class HFLM(LM):
             if getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
                 self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
             else:
-                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
+                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
-            assert self.AUTO_MODEL_CLASS in [transformers.AutoModelForCausalLM, transformers.AutoModelForSeq2SeqLM]
+            assert self.AUTO_MODEL_CLASS in [
+                transformers.AutoModelForCausalLM,
+                transformers.AutoModelForSeq2SeqLM,
+            ]
             self._model = self.AUTO_MODEL_CLASS.from_pretrained(
                 pretrained,
                 revision=revision,
                 low_cpu_mem_usage=low_cpu_mem_usage
...
@@ -127,7 +131,7 @@ class HFLM(LM):
             self._rank = self.accelerator.local_process_index
             self._world_size = self.accelerator.num_processes

     @property
     def config(self):
         # return the associated transformers.AutoConfig for the given pretrained model.
...
@@ -175,20 +179,18 @@ class HFLM(LM):
         return self._world_size

     def tok_encode(self, string: str, left_truncate_len=None):
-        """
-        """
+        """ """
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
             add_special_tokens = False
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             add_special_tokens = True

         encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

         # left-truncate the encoded context to be at most `left_truncate_len` tokens long
         if left_truncate_len:
             encoding = encoding[-left_truncate_len:]

         return encoding

     def tok_decode(self, tokens):
...
@@ -197,23 +199,9 @@ class HFLM(LM):
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             return self.tokenizer.decode(tokens, skip_special_tokens=True)

-    def _model_call(self, inps, attn_mask=None, labels=None):
-        """
-        inps: a torch tensor of shape [batch, sequence_ctx]
-        the size of sequence may vary from call to call
-        labels: a torch tensor of shape [batch, sequence_cont]
-        the size of sequence may vary from call to call
-        returns: a torch tensor of shape [batch, sequence, vocab] with the
-        logits returned from the model
-        """
-        with torch.no_grad():
-            return self.model(
-                input_ids=inps, attention_mask=attn_mask, labels=labels
-            ).logits
-
     def _model_call(self, inps, attn_mask=None, labels=None):
         """
-        inps: torch.Tensor
+        :param inps: torch.Tensor
             A torch tensor of shape [batch, (sequence_ctx + sequence_cont)] or of shape
             [batch, sequence_ctx]. the size of sequence may vary from call to call
         :param attn_mask: torch.Tensor, optional
...
@@ -229,7 +217,9 @@ class HFLM(LM):
         with torch.no_grad():
             if attn_mask or labels:
                 assert self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM
-                return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits
+                return self.model(
+                    input_ids=inps, attention_mask=attn_mask, labels=labels
+                ).logits
             else:
                 assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                 return self.model(inps).logits
...
@@ -254,16 +244,20 @@ class HFLM(LM):
     def _select_cont_toks(self, logits, contlen=None, inplen=None):
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            assert (contlen and inplen), "Must pass input len and cont. len to select scored logits for causal LM"
+            assert (
+                contlen and inplen
+            ), "Must pass input len and cont. len to select scored logits for causal LM"
             # discard right-padding.
             # also discard the input/context tokens. we'll only score continuations.
             logits = logits[inplen - contlen : inplen]
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
-            assert (contlen and not inplen), "Selecting scored logits for Seq2SeqLM requires only cont. len"
+            assert (
+                contlen and not inplen
+            ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
             # only discard right-padding.
             # the logits input to this fn only contain decoder-side tokens.
-            logits = logits[:contlen]
+            logits = logits[:contlen]

         return logits

     def loglikelihood(self, requests):
...
@@ -289,14 +283,14 @@ class HFLM(LM):
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
                         token_list=self.tok_encode(string),
-                        prefix_token=self.eot_token_id,
+                        prefix_token=self.eot_token_id,
                         max_seq_len=self.max_length,
                         context_len=1,
                     ),
                 )
             )

-            #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
             rolling_token_windows = [(None,) + x for x in rolling_token_windows]

             pad_amnt = 0
...
@@ -386,11 +380,11 @@ class HFLM(LM):
             inp = torch.tensor(
                 (context_enc)[-self.max_length :],
                 dtype=torch.long,
-                device=self.device
+                device=self.device,
             )
             (inplen,) = inp.shape

             cont = torch.tensor(
-                (continuation_enc)[-self.max_length :],
+                (continuation_enc)[-self.max_length :],
                 # TODO: left-shift these?
                 # TODO: our code assumes we never end up truncating conts for either model type
                 dtype=torch.long,
...
@@ -400,24 +394,43 @@ class HFLM(LM):
                 conts.append(cont)

-            padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen
+            padding_len_cont = (
+                max(padding_len_cont, contlen)
+                if padding_len_cont is not None
+                else contlen
+            )

-            padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen
+            padding_len_inp = (
+                max(padding_len_inp, inplen) if padding_len_inp is not None else inplen
+            )

             inps.append(inp)  # [1, inp_length]
             cont_toks_list.append(continuation_enc)
             inplens.append(inplen)

         # create encoder attn mask and batched conts, if seq2seq
         call_kwargs = {}
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-            batched_inps = utils.pad_and_concat(padding_len_inp, inps, padding_side="right")  # [batch, padding_len_inp]
+            batched_inps = utils.pad_and_concat(
+                padding_len_inp, inps, padding_side="right"
+            )  # [batch, padding_len_inp]
         elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
             # TODO: left-pad encoder inps and mask?
-            batched_inps = utils.pad_and_concat(padding_len_inp, inps)  # [batch, padding_len_inp]
-            batched_conts = utils.pad_and_concat(padding_len_cont, conts)  # [batch, padding_len_cont]
-            batched_encoder_mask = utils.pad_and_concat(padding_len_inp, encoder_attns)  # [batch, padding_len_inp]
-            call_kwargs = {"attn_mask": batched_encoder_mask, "labels": batched_conts}
+            batched_inps = utils.pad_and_concat(
+                padding_len_inp, inps
+            )  # [batch, padding_len_inp]
+            batched_conts = utils.pad_and_concat(
+                padding_len_cont, conts
+            )  # [batch, padding_len_cont]
+            batched_encoder_mask = utils.pad_and_concat(
+                padding_len_inp, encoder_attns
+            )  # [batch, padding_len_inp]
+            call_kwargs = {
+                "attn_mask": batched_encoder_mask,
+                "labels": batched_conts,
+            }

         multi_logits = F.log_softmax(
             self._model_call(batched_inps, **call_kwargs), dim=-1
...
@@ -429,13 +442,15 @@ class HFLM(LM):
             # Slice to original seq length
             contlen = len(cont_toks)

-            # take only logits in the continuation
+            # take only logits in the continuation
             # (discard context toks if decoder-only ; discard right-padding)
-            ctx_len = inplen if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None
+            ctx_len = (
+                inplen
+                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+                else None
+            )
             logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
-            logits = logits.unsqueeze(0)  # [1, seq, vocab]
+            logits = logits.unsqueeze(0)  # [1, seq, vocab]

             # Check if per-token argmax is exactly equal to continuation
             greedy_tokens = logits.argmax(dim=-1)
...
@@ -506,8 +521,8 @@ class HFLM(LM):
             ).to(self.device)

             cont = self._model_generate(
-                context=context_enc,
-                max_length=context_enc.shape[1] + max_gen_toks,
+                context=context_enc,
+                max_length=context_enc.shape[1] + max_gen_toks,
                 stop=primary_until,
                 **gen_kwargs,
             )
...
@@ -519,4 +534,4 @@ class HFLM(LM):
             res.append(s)

-        return re_ord.get_original(res)
\ No newline at end of file
+        return re_ord.get_original(res)
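The `_select_cont_toks` hunk above captures the key difference between the two model classes: a causal LM scores the whole context-plus-continuation sequence, so the continuation logits are the slice `logits[inplen - contlen : inplen]`, while a seq2seq LM returns decoder-side logits only, so just the right-padding is dropped with `logits[:contlen]`. A minimal sketch of those two slices on dummy tensors follows; it deliberately ignores the harness's batching and token-shift details and involves no real model.

# Minimal sketch of the two slicing modes in _select_cont_toks, on dummy data.
# Assumes a single unbatched sequence; shapes only, no real model involved.
import torch

vocab_size = 10
inplen, contlen = 7, 3  # context+continuation length, continuation length

# Causal LM: logits cover the full (context + continuation) input.
causal_logits = torch.randn(inplen, vocab_size)
causal_cont = causal_logits[inplen - contlen : inplen]  # -> shape [3, vocab_size]

# Seq2seq LM: logits cover only the (right-padded) decoder-side continuation.
padded_decoder_len = 5
seq2seq_logits = torch.randn(padded_decoder_len, vocab_size)
seq2seq_cont = seq2seq_logits[:contlen]  # -> shape [3, vocab_size]

print(causal_cont.shape, seq2seq_cont.shape)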
lm_eval/models/seq2seq.py

@@ -19,6 +19,7 @@ from accelerate import Accelerator
 @register_model("hf-seq2seq", "seq2seq")
 class Seq2SeqHFLM(LM):
     _DEFAULT_MAX_LENGTH: int = 2048
+
     def __init__(
         self,
         device="cuda",
...
@@ -111,7 +112,8 @@ class Seq2SeqHFLM(LM):
     @property
     def max_length(self):
-        return self._DEFAULT_MAX_LENGTH #TODO: Is this a good default?
+        return self._DEFAULT_MAX_LENGTH  # TODO: Is this a good default?
+
     @property
     def max_gen_toks(self):
         return 256
...
@@ -131,14 +133,14 @@ class Seq2SeqHFLM(LM):
     @property
     def world_size(self):
         return self._world_size

     def tok_encode(self, string: str):
         return self.tokenizer.encode(string, add_special_tokens=True)

     def tok_decode(self, tokens):
         return self.tokenizer.decode(tokens, skip_special_tokens=True)

-    def _model_call(self, inps, attn_mask=None, labels=None):
+    def _model_call(self, inps, attn_mask=None, labels=None):
         """
         inps: a torch tensor of shape [batch, sequence_ctx]
         the size of sequence may vary from call to call
...
@@ -150,8 +152,10 @@ class Seq2SeqHFLM(LM):
         logits returned from the model
         """
         with torch.no_grad():
-            return self.model(input_ids=inps, attention_mask=attn_mask, labels=labels).logits
+            return self.model(
+                input_ids=inps, attention_mask=attn_mask, labels=labels
+            ).logits

     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # we require users to pass do_sample=True explicitly
         # for non-greedy gen. This should be reevaluated when considering beam search.
...
@@ -176,8 +180,8 @@ class Seq2SeqHFLM(LM):
             stopping_criteria=stopping_criteria,
             pad_token_id=self.eot_token_id,
             **generation_kwargs,
-        )
+        )

     def loglikelihood(self, requests):
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
...
@@ -192,7 +196,7 @@ class Seq2SeqHFLM(LM):
             new_reqs.append(((context, continuation), context_enc, continuation_enc))

         return self._loglikelihood_tokens(new_reqs)

     def loglikelihood_rolling(self, requests):
         loglikelihoods = []

         for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
...
@@ -201,14 +205,14 @@ class Seq2SeqHFLM(LM):
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
                         token_list=self.tok_encode(string),
-                        prefix_token=self.eot_token_id,
+                        prefix_token=self.eot_token_id,
                         max_seq_len=self.max_length,
                         context_len=1,
                     ),
                 )
             )

-            #TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder
             rolling_token_windows = [(None,) + x for x in rolling_token_windows]

             pad_amnt = 0
...
@@ -237,7 +241,7 @@ class Seq2SeqHFLM(LM):
             loglikelihoods.append(string_nll)

         return loglikelihoods

     def _loglikelihood_tokens(self, requests, disable_tqdm=False):
         res = []
...
@@ -251,7 +255,7 @@ class Seq2SeqHFLM(LM):
             toks = x[1] + x[2]
             return -len(toks), tuple(toks)

         re_ord = utils.Reorderer(requests, _collate)

         for chunk in utils.chunks(
             tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
...
@@ -261,7 +265,7 @@ class Seq2SeqHFLM(LM):
             conts = []
             encoder_attns = []

             cont_toks_list = []
             max_batch_length_inp = None
             max_batch_length_cont = None
...
@@ -283,33 +287,48 @@ class Seq2SeqHFLM(LM):
                 ).to(self.device)
                 (contlen,) = cont.shape

-                max_batch_length_inp = max(max_batch_length_inp, inplen) if max_batch_length_inp is not None else inplen
-                max_batch_length_cont = max(max_batch_length_cont, contlen) if max_batch_length_cont is not None else contlen
+                max_batch_length_inp = (
+                    max(max_batch_length_inp, inplen)
+                    if max_batch_length_inp is not None
+                    else inplen
+                )
+                max_batch_length_cont = (
+                    max(max_batch_length_cont, contlen)
+                    if max_batch_length_cont is not None
+                    else contlen
+                )

                 inps.append(inp)  # [1, inp_len]
-                conts.append(cont)  # [1, cont_len]
+                conts.append(cont)  # [1, cont_len]
                 encoder_attns.append(torch.ones_like(inp))

                 cont_toks_list.append(continuation_enc)

-            batched_inps = utils.pad_and_concat(max_batch_length_inp, inps)  # [batch, padding_length]
-            batched_conts = utils.pad_and_concat(max_batch_length_cont, conts)  # [batch, padding_length]
-            batched_encoder_mask = utils.pad_and_concat(max_batch_length_inp, encoder_attns)
+            batched_inps = utils.pad_and_concat(
+                max_batch_length_inp, inps
+            )  # [batch, padding_length]
+            batched_conts = utils.pad_and_concat(
+                max_batch_length_cont, conts
+            )  # [batch, padding_length]
+            batched_encoder_mask = utils.pad_and_concat(
+                max_batch_length_inp, encoder_attns
+            )  # need to make attention mask here too

             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts), dim=-1
+                self._model_call(
+                    batched_inps, attn_mask=batched_encoder_mask, labels=batched_conts
+                ),
+                dim=-1,
             ).cpu()  # [batch, padding_length, vocab]

             for (cache_key, _, _), logits, cont_toks in zip(chunk, multi_logits, cont_toks_list):
-                # Slice to original seq length
+                # Slice to original seq length
                 contlen = len(cont_toks)
-                logits = logits[:contlen].unsqueeze(0)  # [1, seq, vocab]
+                logits = logits[:contlen].unsqueeze(0)  # [1, seq, vocab]

                 # Check if per-token argmax is exactly equal to continuation
                 greedy_tokens = logits.argmax(dim=-1)
...
@@ -329,7 +348,7 @@ class Seq2SeqHFLM(LM):
                 res.append(answer)

         return re_ord.get_original(res)

     def greedy_until(self, requests):
         res = []
...
@@ -370,8 +389,8 @@ class Seq2SeqHFLM(LM):
             ).to(self.device)

             cont = self._model_generate(
-                context=context_enc,
-                max_length=context_enc.shape[1] + max_gen_toks,
+                context=context_enc,
+                max_length=context_enc.shape[1] + max_gen_toks,
                 stop=primary_until,
                 **gen_kwargs,
             )
...
@@ -383,4 +402,3 @@ class Seq2SeqHFLM(LM):
             res.append(s)

         return re_ord.get_original(res)
-
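To recap what the reformatted `_loglikelihood_tokens` body above does for encoder-decoder models: contexts are padded and stacked as encoder `input_ids`, continuations are padded and stacked as decoder `labels`, and an all-ones mask over each unpadded context becomes the encoder attention mask. The sketch below mirrors that batching with dummy token tensors and a hypothetical `pad_right` helper standing in for `utils.pad_and_concat`; it is not the harness's code.

# Sketch of the seq2seq batching pattern above, using dummy token tensors and a
# hypothetical pad_right helper in place of utils.pad_and_concat.
import torch


def pad_right(max_len, seqs):
    """Right-pad 1-D token tensors with zeros and stack them into [batch, max_len]."""
    return torch.stack(
        [torch.cat([s, s.new_zeros(max_len - s.shape[0])]) for s in seqs]
    )


contexts = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
continuations = [torch.tensor([1, 2]), torch.tensor([3, 4, 5])]

max_inp = max(c.shape[0] for c in contexts)
max_cont = max(c.shape[0] for c in continuations)

batched_inps = pad_right(max_inp, contexts)  # encoder input_ids
batched_conts = pad_right(max_cont, continuations)  # decoder labels
encoder_mask = pad_right(max_inp, [torch.ones_like(c) for c in contexts])

print(batched_inps.shape, batched_conts.shape, encoder_mask.shape)
# A real seq2seq call would then be:
# model(input_ids=batched_inps, attention_mask=encoder_mask, labels=batched_conts).logits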
lm_eval/tasks/__init__.py

@@ -22,7 +22,7 @@ def include_task_folder(task_dir):
     Calling this function
     """
     for root, subdirs, file_list in os.walk(task_dir):
-        if (len(file_list) > 0):
+        if len(file_list) > 0:
             for f in file_list:
                 if f.endswith(".yaml"):
                     yaml_path = os.path.join(root, f)
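The `include_task_folder` hunk above only touches the `if` condition; for orientation, the surrounding loop simply walks `task_dir` and picks up every `.yaml` task config. Below is a stand-alone sketch of that pattern only; the registration step the real function performs on each path is collapsed out of this diff, so it is omitted here.

# Stand-alone sketch of the directory walk in include_task_folder; it only
# collects paths, whereas the real function goes on to register each config.
import os


def collect_yaml_configs(task_dir):
    yaml_paths = []
    for root, _subdirs, file_list in os.walk(task_dir):
        if len(file_list) > 0:
            for f in file_list:
                if f.endswith(".yaml"):
                    yaml_paths.append(os.path.join(root, f))
    return yaml_paths


print(collect_yaml_configs("lm_eval/tasks"))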
lm_eval/utils.py

@@ -19,6 +19,11 @@ import transformers
 from omegaconf import OmegaConf
 from jinja2 import BaseLoader, Environment, StrictUndefined
 from itertools import islice
+<<<<<<< HEAD
+=======
+import transformers
+>>>>>>> more pre-commit

 from lm_eval.logger import eval_logger
...
@@ -417,6 +422,7 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     return islice(raw_iterator, rank, limit, world_size)

+<<<<<<< HEAD
 def clear_torch_cache():
     gc.collect()
     torch.cuda.empty_cache()
...
@@ -437,8 +443,17 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
     Method for padding a list of tensors given the maximum tensor
     length in the batch. Used for batching inputs and continuations in
     seq2seq models.
+=======
+def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
+    """
+    Method for padding a list of tensors given the maximum tensor
+    length in the batch. Used for batching inputs and continuations in
+    seq2seq models.
+>>>>>>> more pre-commit
     """
-    assert padding_side == "left" or padding_side == "right", f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+    assert (
+        padding_side == "left" or padding_side == "right"
+    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"

     for i, tensor in enumerate(tensors):
         tensor_len = tensor.shape[0]
...
@@ -446,36 +461,45 @@ def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="ri
         if padding_side == "right":
             # right-pad
             tensors[i] = torch.cat(
                 [
-                    tensor,  # [seq]
-                    torch.zeros(max_length - tensor_len, dtype=torch.long, device=tensor.device,),  # [padding_length - seq]
+                    tensor,  # [seq]
+                    torch.zeros(
+                        max_length - tensor_len,
+                        dtype=torch.long,
+                        device=tensor.device,
+                    ),  # [padding_length - seq]
                 ],
                 dim=0,
             ).unsqueeze(0)
         else:
             # left-pad
             tensors[i] = torch.cat(
                 [
                     torch.zeros(
-                        max_length - tensor_len,
+                        max_length - tensor_len,
                         dtype=torch.long,
                         device=tensor.device,
                     ),  # [padding_length - seq]
-                    tensor,  # [seq]
+                    tensor,
+                    # [seq]
                 ],
                 dim=0,
             ).unsqueeze(0)
     else:
         tensors[i] = tensor.unsqueeze(0)

-    return torch.cat(tensors, dim=0)
+    return torch.cat(tensors, dim=0)

+<<<<<<< HEAD
 # Multi-token stopping criteria
+=======
+def clear_torch_cache():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+# Multi-token stopping criteria
+>>>>>>> more pre-commit
 class MultiTokenEOSCriteria(transformers.StoppingCriteria):
     """Criteria to stop on the specified multi-token sequence."""
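For reference, the behavior `pad_and_concat` is meant to implement (per its docstring above): every 1-D tensor shorter than `max_length` is zero-padded on the requested side, then all tensors are concatenated into a `[batch, max_length]` tensor. The copy committed here appears to retain unresolved merge-conflict markers, so the snippet below is a self-contained restatement of the intended behavior rather than an import from `lm_eval.utils`.

# Self-contained restatement of pad_and_concat's intended behavior (the copy in
# this commit appears to retain unresolved conflict markers).
import torch


def pad_and_concat_sketch(max_length, tensors, padding_side="right"):
    assert padding_side in ("left", "right")
    rows = []
    for t in tensors:
        pad = torch.zeros(max_length - t.shape[0], dtype=torch.long, device=t.device)
        row = torch.cat([t, pad]) if padding_side == "right" else torch.cat([pad, t])
        rows.append(row.unsqueeze(0))
    return torch.cat(rows, dim=0)


a, b = torch.tensor([1, 2, 3]), torch.tensor([4, 5])
print(pad_and_concat_sketch(3, [a, b], padding_side="right"))
# tensor([[1, 2, 3],
#         [4, 5, 0]])
print(pad_and_concat_sketch(3, [a, b], padding_side="left"))
# tensor([[1, 2, 3],
#         [0, 4, 5]])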