gaoqiong / lm-evaluation-harness · Commits

Commit 35bdecd3 (unverified), authored Oct 31, 2023 by Matt Hoffner, committed by GitHub on Oct 31, 2023

Merge pull request #1 from LorenzoMinto/master

Return score from continuation logprobs

Parents: b011af90, 9b876402
Showing 1 changed file with 32 additions and 12 deletions.

lm_eval/models/ggml.py (+32, -12) @ 35bdecd3
@@ -10,6 +10,27 @@ from lm_eval.base import BaseLM

logger = logging.getLogger(__name__)


def get_result(logprobs, context_lenght):
    is_greedy = True
    offsets = logprobs['text_offset']
    tokens = logprobs['tokens']
    tokens_logprobs = logprobs['token_logprobs']

    idx = 0
    while offsets[idx] < context_lenght:
        idx += 1
    continuation_logprobs = sum(tokens_logprobs[idx:-1])
    for i in range(idx, len(tokens)):
        token = tokens[i]
        top_tokens = logprobs["top_logprobs"][i]
        top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
        if top_token != token:
            is_greedy = False
            break

    return continuation_logprobs, is_greedy


class GGMLLM(BaseLM):
    def __init__(self, base_url, truncate=False):
        super().__init__()
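For orientation, here is a small sketch (not part of the commit) of what get_result consumes: an OpenAI-completions-style logprobs object, as returned when a prompt is echoed back with per-token scores. All tokens and numbers below are invented, and the [idx:-1] slice presumably excludes the single extra token generated on top of the echo (max_tokens=1):

# Invented payload for illustration; the shape follows the OpenAI
# completions "logprobs" object the server is assumed to return.
mock = {
    "text_offset": [0, 5, 8, 13],          # char offset of each echoed token
    "tokens": ["Paris", " is", " nice", "."],
    "token_logprobs": [None, -0.9, -0.4, -1.2],
    "top_logprobs": [
        {},
        {" is": -0.9, " was": -1.5},
        {" nice": -0.4, " big": -2.0},
        {".": -1.2, "!": -1.4},
    ],
}

# The context "Paris is" is 8 characters long, so the while loop stops
# at idx=2 and the continuation starts at " nice".
score, greedy = get_result(mock, len("Paris is"))
# score == -0.4  (sum of token_logprobs[2:-1]; the trailing "." is excluded)
# greedy == True (every token from idx on is the argmax of its top_logprobs)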
@@ -17,6 +38,7 @@ class GGMLLM(BaseLM):
        self.truncate = truncate
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained("gpt2")
        self.logpobs = 10
        self.temperature = 0.0
        self.max_length = 1024
        self.vocab_size = self.tokenizer.vocab_size
@@ -24,9 +46,11 @@ class GGMLLM(BaseLM):
        for _ in range(retries):
            try:
                prompt = context
                request = {'prompt': prompt, 'logprobs': self.logpobs, 'temperature': self.temperature}
                if continuation:
                    prompt += continuation
                    request = {'prompt': prompt, 'logprobs': self.logpobs}
                request.update({'prompt': prompt, 'max_tokens': 1, 'echo': True})
                if stop is not None:
                    request['stop'] = stop
                response = requests.post(f"{self.base_url}/v1/completions", json=request)
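The retry body above targets what looks like an OpenAI-compatible completions endpoint (llama-cpp-python's web server exposes one, for example; the exact backend is an assumption here). A standalone sketch of the same round trip with hypothetical values:

# Sketch of the request the code above builds; the URL and field values
# are hypothetical, and the server is assumed to speak the OpenAI
# completions protocol with logprobs support.
import requests

base_url = "http://localhost:8000"
payload = {
    "prompt": "Paris is nice",   # context + continuation
    "logprobs": 10,              # top-10 alternatives per token (self.logpobs)
    "temperature": 0.0,
    "max_tokens": 1,             # generate a single token...
    "echo": True,                # ...while echoing the prompt with its logprobs
}
response = requests.post(f"{base_url}/v1/completions", json=payload).json()
# response["choices"][0]["logprobs"] then carries the "tokens",
# "token_logprobs", "text_offset" and "top_logprobs" fields that
# get_result reads.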
@@ -38,7 +62,6 @@ class GGMLLM(BaseLM):
            else:
                raise Exception(f"Failed to get a valid response after {retries} retries.")

    def loglikelihood(self, requests):
        if not requests:
            return []
@@ -49,8 +72,7 @@ class GGMLLM(BaseLM):
            choice = response["choices"][0]
            logprobs = choice.get("logprobs")
            if logprobs and "token_logprobs" in logprobs and logprobs["token_logprobs"]:
-               logprob = logprobs["token_logprobs"][0]
-               is_greedy = choice["finish_reason"] == "length"
+               logprob, is_greedy = get_result(logprobs, len(context))
                res.append((logprob, is_greedy))
            else:
                logger.warning("Invalid logprobs data. Expected 'logprobs' to contain 'token_logprobs' list.")
@@ -58,7 +80,6 @@ class GGMLLM(BaseLM):
                logger.error(f"Invalid response for loglikelihood. Response: {response}")
                assert False
        return res

    def greedy_until(self, requests):
        if not requests:
@@ -89,16 +110,15 @@ class GGMLLM(BaseLM):
        for request in requests:
            logprobs = []
            for i in range(0, len(request), self.max_length):
                chunk = request[i:i + self.max_length]
                chunk_loglikelihood = self.loglikelihood([(chunk, request[i + 1:i + self.max_length + 1])])
                logprobs.extend(chunk_loglikelihood)
            avg_loglikelihood = sum([logprob for logprob, _ in logprobs]) / len(logprobs)
            results.append((avg_loglikelihood, True))
        return results

    def _model_call(self, inps):
        # Placeholder implementation
        raise NotImplementedError()
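The windowing above scores each max_length slice of the request against the same slice shifted one position to the right, i.e. next-token targets. A toy illustration with an invented max_length of 4:

# Toy values, not from the commit: show which (chunk, target) pairs the
# loop feeds to loglikelihood.
max_length = 4
request = "abcdefgh"
for i in range(0, len(request), max_length):
    chunk = request[i:i + max_length]
    target = request[i + 1:i + max_length + 1]
    print(chunk, target)
# abcd bcde
# efgh fgh   (the final target runs off the end and is one token short)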
@@ -112,7 +132,7 @@ class GGMLLM(BaseLM):
    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    @property
    def batch_size(self):
        # Placeholder implementation
@@ -128,10 +148,10 @@ class GGMLLM(BaseLM):
        # Placeholder implementation
        raise NotImplementedError()

    def max_length(self):
        return self.max_length

    @property
    def max_gen_toks(self):
        # Placeholder implementation
-        raise NotImplementedError()
\ No newline at end of file
+        raise NotImplementedError()