gaoqiong / lm-evaluation-harness · Commits

Commit 8ad386eb, authored Aug 07, 2023 by baberabb
added logliklihood_rolling and fixed greedy_until
parent 331340ad
Showing 1 changed file with 37 additions and 9 deletions: lm_eval/models/openai_completions.py (+37 −9)
@@ -134,7 +134,7 @@ class OpenaiCompletionsLM(LM):
         continuation_enc = whole_enc[context_enc_len:]
         return context_enc, continuation_enc

-    def loglikelihood(self, requests) -> List[List[float]]:
+    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
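This first change fixes the advertised return type: loglikelihood resolves each request to a (log-probability, is-greedy) pair rather than a list of floats. A minimal consumer sketch of that contract (the function name here is illustrative, not part of the harness):

from typing import List, Tuple

def summarize_loglikelihoods(results: List[Tuple[float, bool]]) -> float:
    # Each entry pairs the summed log-probability of a continuation with a
    # flag saying whether greedy decoding would have produced it exactly.
    total = 0.0
    for logprob, is_greedy in results:
        total += logprob
        print(f"logprob={logprob:.4f} greedy={is_greedy}")
    return total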
@@ -149,13 +149,15 @@ class OpenaiCompletionsLM(LM):
         return self._loglikelihood_tokens(new_reqs)

-    def _loglikelihood_tokens(self, requests, disable_tqdm=False) -> List[List[float]]:
+    def _loglikelihood_tokens(
+        self, requests, disable_tqdm=False
+    ) -> List[Tuple[float, bool]]:
         res = []

         def _collate(x):
             # this doesn't efficiently handle last-token differences yet, but those are kinda annoying because
             # it's not guaranteed that the 100 or so logprobs we get to see actually contain all the continuations
-            # we care about and so we need some kind of backup for when it isn't
+            # we care about, and so we need some kind of backup for when it isn't
             toks = x[1] + x[2]
             return -len(toks), tuple(toks)
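The _collate key is worth a second look: negating the token count makes Python's ascending sort run the longest requests first (so failures on long inputs surface early), and tuple(toks) breaks ties deterministically. A standalone sketch with made-up token lists, outside the diff:

# Illustrative only: how a _collate-style key orders (context, ctx_toks, cont_toks) requests.
reqs = [("ctx", [1, 2], [3]), ("ctx", [1], [2, 3, 4]), ("ctx", [5], [6])]

def _collate(x):
    toks = x[1] + x[2]
    return -len(toks), tuple(toks)

print(sorted(reqs, key=_collate))
# [('ctx', [1], [2, 3, 4]), ('ctx', [1, 2], [3]), ('ctx', [5], [6])]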
@@ -197,13 +199,13 @@ class OpenaiCompletionsLM(LM):
                 # partial caching
                 if cache_key is not None:
                     self.cache_hook.add_partial("loglikelihood", cache_key, answer)

         return re_ord.get_original(res)

     def greedy_until(self, requests) -> List[str]:
         if not requests:
             return []
         res = []
+        requests = [req.args for req in requests]

         def _collate(x):
             toks = self.tok_encode(x[0])
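The added requests = [req.args for req in requests] is half of the greedy_until fix: it unpacks the harness's request objects into plain (context, gen_kwargs) tuples before sorting and batching, matching what loglikelihood already does. A toy stand-in (FakeInstance is hypothetical; only the .args attribute matters here):

from dataclasses import dataclass

@dataclass
class FakeInstance:
    args: tuple  # (context, gen_kwargs)

reqs = [FakeInstance(("Q: 2+2=", {"until": ["\n"]}))]
pairs = [req.args for req in reqs]
context, gen_kwargs = pairs[0]
assert context.startswith("Q:") and gen_kwargs["until"] == ["\n"]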
@@ -253,7 +255,7 @@ class OpenaiCompletionsLM(LM):
             for resp, (context, args_) in zip(response.choices, chunk):
                 s = resp["text"]

-                until_ = args_.get(["until"], [])
+                until_ = args_.get("until", [])

                 for term in until_:
                     if len(term) > 0:
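The other half of the fix is this .get call: the old form passed the list ["until"] as the dictionary key, and since lists are unhashable, every lookup raised TypeError instead of returning the stop sequences. A short demonstration, outside the diff:

args_ = {"until": ["\n\n"]}
try:
    args_.get(["until"], [])  # old, broken form: list used as dict key
except TypeError as err:
    print(err)  # unhashable type: 'list'
print(args_.get("until", []))  # fixed form -> ['\n\n']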
@@ -265,7 +267,6 @@ class OpenaiCompletionsLM(LM):
                     )
                 )
                 res.append(s)

         return re_ord.get_original(res)
-
     def _model_call(self, inps):
@@ -276,6 +277,33 @@ class OpenaiCompletionsLM(LM):
         # Isn't used because we override greedy_until
         raise NotImplementedError()

-    def loglikelihood_rolling(self, requests):
-        # Isn't used because we override _loglikelihood_tokens
-        raise NotImplementedError()
+    def loglikelihood_rolling(self, requests) -> List[float]:
+        loglikelihoods = []
+
+        for (string,) in tqdm([req.args for req in requests]):
+            rolling_token_windows = list(
+                map(
+                    utils.make_disjoint_window,
+                    utils.get_rolling_token_windows(
+                        token_list=self.tok_encode(string),
+                        prefix_token=self.eot_token_id,
+                        max_seq_len=self.max_length,
+                        context_len=1,
+                    ),
+                )
+            )
+
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
+
+            string_nll = self._loglikelihood_tokens(
+                rolling_token_windows,
+                disable_tqdm=True,
+            )
+
+            # discard is_greedy
+            string_nll = [x[0] for x in string_nll]
+
+            string_nll = sum(string_nll)
+            loglikelihoods.append(string_nll)
+        return loglikelihoods
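The new loglikelihood_rolling scores an entire string by slicing its tokens into windows no longer than max_length (the first window prefixed with the EOT token, later ones conditioned on context_len=1 trailing token), scoring each window with _loglikelihood_tokens, discarding the is_greedy flags, and summing. A simplified sketch of the windowing, which reimplements the idea rather than the harness's utils.get_rolling_token_windows:

# Simplified, assumption-laden sketch of rolling-window scoring:
# split tokens into windows of at most max_seq_len, conditioning each
# window on a single preceding token (EOT for the first window).
def rolling_windows(tokens, prefix_token, max_seq_len):
    windows = []
    start = 0
    while start < len(tokens):
        pred = tokens[start : start + max_seq_len]
        ctx = [prefix_token] if start == 0 else [tokens[start - 1]]
        windows.append((ctx, pred))
        start += max_seq_len
    return windows

print(rolling_windows([10, 11, 12, 13, 14], prefix_token=0, max_seq_len=2))
# [([0], [10, 11]), ([11], [12, 13]), ([13], [14])]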