Unverified Commit 8051d954 authored by Antoni Baum, committed by GitHub

Add compatibility for vLLM's new Logprob object (#1549)



* Add compatibility for vLLM's new Logprob object

* Fix

* Update lm_eval/models/vllm_causallms.py

* fix format?

* trailing whitespace

---------
Co-authored-by: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
parent 9e6e2402
@@ -411,6 +411,26 @@ class VLLM(TemplateLM):
         # The first entry of prompt_logprobs is None because the model has no previous tokens to condition on.
         continuation_logprobs_dicts = outputs.prompt_logprobs
 
+        def coerce_logprob_to_num(logprob):
+            # vLLM changed the return type of logprobs from float
+            # to a Logprob object storing the float value + extra data
+            # (https://github.com/vllm-project/vllm/pull/3065).
+            # If we are dealing with vllm's Logprob object, return
+            # the logprob value stored as an attribute. Otherwise,
+            # return the object itself (which should be a float
+            # for older versions of vLLM).
+            return getattr(logprob, "logprob", logprob)
+
+        continuation_logprobs_dicts = [
+            {
+                token: coerce_logprob_to_num(logprob)
+                for token, logprob in logprob_dict.items()
+            }
+            if logprob_dict is not None
+            else None
+            for logprob_dict in continuation_logprobs_dicts
+        ]
+
         # Calculate continuation_logprobs
         # assume ctxlen always >= 1
         continuation_logprobs = sum(
...
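For context, here is a minimal, self-contained sketch of the shim this diff adds. The `Logprob` dataclass below is a local stand-in for vLLM's real object (introduced in vllm-project/vllm#3065; the actual class carries extra fields beyond the float), defined here only so the snippet runs without vLLM installed:

```python
from dataclasses import dataclass


# Stand-in for vLLM's Logprob object (a hypothetical minimal version;
# the real class also stores metadata such as the token's rank).
@dataclass
class Logprob:
    logprob: float


def coerce_logprob_to_num(logprob):
    # Newer vLLM wraps the float in a Logprob object; older versions
    # return the float directly. getattr falls back to the object
    # itself when no .logprob attribute exists.
    return getattr(logprob, "logprob", logprob)


# Both API shapes normalize to the same float:
assert coerce_logprob_to_num(-0.5) == -0.5           # old API: plain float
assert coerce_logprob_to_num(Logprob(-0.5)) == -0.5  # new API: Logprob object
```

Using `getattr` with a default keeps the shim version-agnostic: no vLLM version check or import of the new class is needed, and plain floats from older releases pass through unchanged.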