Commit 0635af13 authored by baberabb

bugfix

parent 103a10e3
@@ -104,6 +104,8 @@ class MyCustomLM(LM):
Using this decorator results in the class being added to an accounting of the usable LM types maintained internally to the library at `lm_eval.api.registry.MODEL_REGISTRY`. See `lm_eval.api.registry` for more detail on what sorts of registries and decorators exist in the library!
**Tip: be sure to import your model in `lm_eval/models/__init__.py`!**
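For illustration, registering a new model might look like the minimal sketch below. The class name and registry key are placeholders, and the `lm_eval.api.model` import path for `LM` is an assumption; only `register_model` and `MODEL_REGISTRY` are named by the guide itself.

```python
from lm_eval.api.model import LM  # assumed location of the LM base class
from lm_eval.api.registry import register_model


@register_model("my-custom-lm")  # now reachable via lm_eval.api.registry.MODEL_REGISTRY["my-custom-lm"]
class MyCustomLM(LM):
    # implement the core request-type methods described earlier in the guide
    ...
```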
## Testing
We also recommend that new model contributions be accompanied by short tests of their three core functionalities, at minimum. For an example of such tests, see https://github.com/EleutherAI/lm-evaluation-harness/blob/35bdecd379c0cefad6897e67db892f4a6026a128/tests/test_ggml.py.
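A bare-bones sketch of such a test is shown below. It only checks that a hypothetical model is registered and subclasses `LM`; real tests such as the linked file would also exercise the request-type methods against a small or mocked model. The `"my-custom-lm"` key is a placeholder.

```python
# Illustrative smoke test for a hypothetical "my-custom-lm" model; this is not
# the linked test_ggml.py.
import lm_eval.models  # noqa: F401  # importing the package runs the @register_model decorators
from lm_eval.api.model import LM  # assumed location of the LM base class
from lm_eval.api.registry import MODEL_REGISTRY


def test_model_is_registered():
    # the @register_model decorator should have added the class under its chosen name
    assert "my-custom-lm" in MODEL_REGISTRY


def test_registered_class_is_an_lm():
    assert issubclass(MODEL_REGISTRY["my-custom-lm"], LM)
```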
@@ -3,6 +3,7 @@ from . import openai_completions
from . import textsynth
from . import dummy
from . import anthropic_llms
from . import vllm_causallms
# TODO: implement __all__
@@ -7,7 +7,7 @@ import copy
from tqdm import tqdm
from lm_eval.api.registry import register_model
from lm_eval import utils
from vllm import LLM, SamplingParams
from vllm_causallms import LLM, SamplingParams
@register_model("vllm")
@@ -32,6 +32,7 @@ class VLLM(LM):
        self.model = LLM(
            model=pretrained,
            gpu_memory_utilization=0.2,
            revision=revision,
            dtype=dtype,
            tokenizer_mode=tokenizer_mode,
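For context, the vLLM objects wrapped above are used on their own roughly as in the sketch below. The model id and sampling settings are illustrative, not the harness defaults; `prompt_logprobs` is the field that the logprob-parsing code later in this file reads off each `RequestOutput`.

```python
# Standalone vLLM usage sketch (illustrative model id and settings, not the harness defaults).
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",   # any Hugging Face model id
    revision=None,               # optional HF repo revision, as wired up in the diff above
    dtype="auto",
    tokenizer_mode="auto",
    gpu_memory_utilization=0.9,
)

params = SamplingParams(
    temperature=0.0,
    max_tokens=32,
    prompt_logprobs=1,           # ask vLLM to also return logprobs for the prompt tokens
)

outputs = llm.generate(["The capital of France is"], params)
result = outputs[0]                # a RequestOutput
print(result.outputs[0].text)      # generated continuation
print(result.prompt_logprobs[:3])  # per-prompt-token {token_id: logprob} dicts; the first entry is None
```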
@@ -298,8 +299,8 @@ class VLLM(LM):
        :param tokens: list
            Tokens from response
        :param outputs: list
            Log probabilities tied to the tokens
        :param outputs: RequestOutput
            Contains the prompt_logprobs returned by vLLM
        :param ctxlen: int
            Length of context (so we can slice them away and only keep the predictions)
        :return:
@@ -309,22 +310,26 @@ class VLLM(LM):
            Whether argmax matches given continuation exactly
        """
        # Extract the logprobs for the continuation tokens
        continuation_logprobs_dicts = outputs.prompt_logprobs

        # Calculate continuation_logprobs
        continuation_logprobs = sum(
            logprob_dict[token]
            logprob_dict.get(token)  # Use .get to avoid a KeyError if a token is missing
            for token, logprob_dict in zip(
                tokens[ctxlen:], outputs.prompt_logprobs[ctxlen:]
                tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
            )
        )

        # Determine if is_greedy
        is_greedy = True
        for i in range(ctxlen, len(tokens)):
            token = tokens[i]
            top_tokens = outputs[i]
            top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x])
            if top_token != token:
                is_greedy = False
                break
        for token, logprob_dict in zip(tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]):
            # Get the token with the maximum log probability from the logprob_dict
            if logprob_dict:  # Ensure the logprob_dict is not None
                top_token = max(logprob_dict, key=logprob_dict.get)
                if top_token != token:
                    is_greedy = False
                    break

        return continuation_logprobs, is_greedy
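To make the parsing above concrete, here is a self-contained sketch that runs the same two checks (summing the continuation logprobs and testing the greedy match) over hand-written stand-ins for `prompt_logprobs`; every token id and logprob value below is made up.

```python
# Illustrative stand-in for RequestOutput.prompt_logprobs: one {token_id: logprob}
# dict per prompt token, with None for the very first token, as vLLM returns it.
prompt_logprobs = [
    None,
    {11: -0.1, 42: -2.3},
    {42: -0.25, 7: -1.9},   # position 2: token 42 is the most likely
    {7: -0.5, 99: -1.1},    # position 3: token 7 is the most likely
]
tokens = [5, 11, 42, 7]     # prompt token ids fed to the model
ctxlen = 2                  # first two tokens are context; the rest is the continuation

# Sum the logprobs of the continuation tokens (positions ctxlen..end).
continuation_logprobs = sum(
    logprob_dict.get(token)
    for token, logprob_dict in zip(tokens[ctxlen:], prompt_logprobs[ctxlen:])
)

# The continuation is "greedy" if every continuation token is also the argmax
# of its position's logprob dict.
is_greedy = True
for token, logprob_dict in zip(tokens[ctxlen:], prompt_logprobs[ctxlen:]):
    if logprob_dict and max(logprob_dict, key=logprob_dict.get) != token:
        is_greedy = False
        break

print(continuation_logprobs, is_greedy)  # -0.75 True (with the made-up numbers above)
```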