gaoqiong / lm-evaluation-harness / Commits

Commit 1bd6229c
Authored Jun 20, 2023 by haileyschoelkopf
Committed by lintangsutawika on Jun 22, 2023

remove some old code, edge-case seq2seq case

Parent: 9f36ab18
Showing 1 changed file with 19 additions and 22 deletions.

lm_eval/models/hf_merged.py  (+19, -22)
lm_eval/models/hf_merged.py (view file @ 1bd6229c)
@@ -120,12 +120,16 @@ class HFLM(LM):
         self._rank = self.accelerator.local_process_index
         self._world_size = self.accelerator.num_processes
 
+    @property
+    def config(self):
+        # return the associated transformers.AutoConfig for the given pretrained model.
+        return self._config
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
         return self.tokenizer.eos_token_id
 
-    # TODO: add a self.config property
     # TODO: make model at self._model, have self.model property unwrap accelerator if needed under hood?
     @property
     def max_length(self):
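Note: the new `config` property exposes a cached `transformers.AutoConfig` for the loaded model. A minimal sketch of the pattern, assuming a standard Hugging Face model; the wrapper class and usage below are illustrative only, not the harness's code:

import transformers

class ConfigExample:
    def __init__(self, pretrained: str):
        # cache the AutoConfig once at init, as the diff's `self._config` suggests
        self._config = transformers.AutoConfig.from_pretrained(pretrained)

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

# e.g. ConfigExample("gpt2").config.max_position_embeddings evaluates to 1024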
@@ -378,7 +382,8 @@ class HFLM(LM):
                     inp = torch.tensor(
                         (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                         dtype=torch.long,
-                    ).to(self.device)
+                        device=self.device,
+                    )
                     (inplen,) = inp.shape
                 elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                     inp = torch.tensor(
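Note: passing `device=` to `torch.tensor` allocates the tensor directly on the target device instead of building it on CPU and then copying it over with `.to()`; the resulting values are identical. A small self-contained check (the device string here is an assumption):

import torch

ids = [101, 2023, 2003, 102]
device = "cuda:0" if torch.cuda.is_available() else "cpu"

a = torch.tensor(ids, dtype=torch.long).to(device)      # old pattern: CPU tensor, then copy
b = torch.tensor(ids, dtype=torch.long, device=device)  # new pattern: allocate on the target device
assert torch.equal(a, b)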
@@ -388,26 +393,18 @@ class HFLM(LM):
                     (inplen,) = inp.shape
 
                     cont = torch.tensor(
                         (continuation_enc)[-self.max_length :],
+                        # TODO: left-shift these?
+                        # TODO: our code assumes we never end up truncating conts for either model type
                         dtype=torch.long,
                     ).to(self.device)
                     (contlen,) = cont.shape
+                    conts.append(cont)
 
                     padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen
 
                 padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen
 
-                # # pad length from seq to padding_length
-                # inp = torch.cat(
-                #     [
-                #         inp,  # [seq]
-                #         torch.zeros(padding_length - inplen, dtype=torch.long).to(
-                #             inp.device
-                #         ),  # [padding_length - seq]
-                #     ],
-                #     dim=0,
-                # )
                 inps.append(inp)  # [1, inp_length]
                 cont_toks_list.append(continuation_enc)
                 inplens.append(inplen)
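Note: the seq2seq branch now also collects each `cont` tensor and keeps running maxima that later become the padded batch widths. A tiny illustration of that bookkeeping pattern, with made-up lengths:

lengths_inp = [12, 7, 19]   # example inplen per request
lengths_cont = [3, 5, 2]    # example contlen per request

padding_len_inp = None
padding_len_cont = None
for inplen, contlen in zip(lengths_inp, lengths_cont):
    padding_len_inp = max(padding_len_inp, inplen) if padding_len_inp is not None else inplen
    padding_len_cont = max(padding_len_cont, contlen) if padding_len_cont is not None else contlen

assert (padding_len_inp, padding_len_cont) == (19, 5)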
@@ -415,18 +412,17 @@ class HFLM(LM):
             # create encoder attn mask and batched conts, if seq2seq
             call_kwargs = {}
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
-                # batched_inps = torch.cat(inps, dim=0) # [batch, padding_length]
-                batched_inps = utils.pad_and_concat(padding_len_inp, inps, padding_side="right")
+                batched_inps = utils.pad_and_concat(padding_len_inp, inps, padding_side="right")  # [batch, padding_len_inp]
             elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                 # TODO: left-pad encoder inps and mask?
-                batched_inps = utils.pad_and_concat(padding_len_inp, inps)  # [batch, enc_padding_length]
-                batched_conts = utils.pad_and_concat(padding_len_cont, conts)  # [batch, padding_length]
-                batched_encoder_mask = utils.pad_and_concat(padding_len_inp, encoder_attns)  # size???
+                batched_inps = utils.pad_and_concat(padding_len_inp, inps)  # [batch, padding_len_inp]
+                batched_conts = utils.pad_and_concat(padding_len_cont, conts)  # [batch, padding_len_cont]
+                batched_encoder_mask = utils.pad_and_concat(padding_len_inp, encoder_attns)  # [batch, padding_len_inp]
                 call_kwargs = {"attn_mask": batched_encoder_mask, "labels": batched_conts}
 
             multi_logits = F.log_softmax(
                 self._model_call(batched_inps, **call_kwargs), dim=-1
-            ).cpu()  # [batch, padding_length, vocab]
+            ).cpu()  # [batch, padding_length (inp or cont), vocab]
 
             for (cache_key, _, _), logits, inplen, cont_toks in zip(
                 chunk, multi_logits, inplens, cont_toks_list
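Note: `utils.pad_and_concat` is the harness's own helper; as used here it pads each variable-length sequence out to the given length and stacks them into a `[batch, padding_len]` tensor. The sketch below is an assumed stand-in for the right-padding case only, not the real implementation or signature:

from typing import List
import torch
import torch.nn.functional as F

def pad_and_concat_right(max_len: int, tensors: List[torch.Tensor]) -> torch.Tensor:
    # right-pad each 1-D tensor with zeros up to max_len, then stack to [batch, max_len]
    return torch.stack([F.pad(t, (0, max_len - t.shape[0])) for t in tensors], dim=0)

batch = pad_and_concat_right(
    5, [torch.ones(3, dtype=torch.long), torch.ones(5, dtype=torch.long)]
)
assert batch.shape == (2, 5)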
@@ -436,7 +432,8 @@ class HFLM(LM):
                 contlen = len(cont_toks)
                 # take only logits in the continuation
                 # (discard context toks if decoder-only ; discard right-padding)
-                logits = self._select_cont_toks(logits, contlen=contlen, inplen=inplen)
+                ctx_len = inplen if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM else None
+                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                 logits = logits.unsqueeze(
                     0
                 )  # [1, seq, vocab]
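Note: `inplen` only matters for decoder-only models, whose logits span context plus continuation (plus right padding); for encoder-decoder models the decoder logits already align with the continuation, so `ctx_len` is passed as `None`. An assumed illustration of that slicing logic, not the harness's actual `_select_cont_toks`:

import torch

def select_cont_toks(logits: torch.Tensor, contlen: int, inplen=None) -> torch.Tensor:
    if inplen is not None:
        # causal: keep the contlen positions that end at the input length
        return logits[inplen - contlen : inplen]
    # seq2seq: keep the first contlen decoder positions
    return logits[:contlen]

logits = torch.randn(10, 32)                              # [padded_seq, vocab]
causal = select_cont_toks(logits, contlen=4, inplen=7)    # rows 3..6
seq2seq = select_cont_toks(logits, contlen=4)             # rows 0..3
assert causal.shape == seq2seq.shape == (4, 32)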