Commit 0635af13 authored by baberabb's avatar baberabb
Browse files

bugfix

parent 103a10e3
...@@ -104,6 +104,8 @@ class MyCustomLM(LM): ...@@ -104,6 +104,8 @@ class MyCustomLM(LM):
Using this decorator results in the class being added to an accounting of the usable LM types maintained internally to the library at `lm_eval.api.registry.MODEL_REGISTRY`. See `lm_eval.api.registry` for more detail on what sorts of registries and decorators exist in the library! Using this decorator results in the class being added to an accounting of the usable LM types maintained internally to the library at `lm_eval.api.registry.MODEL_REGISTRY`. See `lm_eval.api.registry` for more detail on what sorts of registries and decorators exist in the library!
**Tip: be sure to import your model in `lm_eval/models/__init__.py`!**
## Testing ## Testing
We also recommend that new model contributions be accompanied by short tests of their 3 core functionalities, at minimum. To see an example of such tests, look at https://github.com/EleutherAI/lm-evaluation-harness/blob/35bdecd379c0cefad6897e67db892f4a6026a128/tests/test_ggml.py . We also recommend that new model contributions be accompanied by short tests of their 3 core functionalities, at minimum. To see an example of such tests, look at https://github.com/EleutherAI/lm-evaluation-harness/blob/35bdecd379c0cefad6897e67db892f4a6026a128/tests/test_ggml.py .
......
...@@ -3,6 +3,7 @@ from . import openai_completions ...@@ -3,6 +3,7 @@ from . import openai_completions
from . import textsynth from . import textsynth
from . import dummy from . import dummy
from . import anthropic_llms from . import anthropic_llms
from . import vllm_causallms
# TODO: implement __all__ # TODO: implement __all__
...@@ -7,7 +7,7 @@ import copy ...@@ -7,7 +7,7 @@ import copy
from tqdm import tqdm from tqdm import tqdm
from lm_eval.api.registry import register_model from lm_eval.api.registry import register_model
from lm_eval import utils from lm_eval import utils
from vllm import LLM, SamplingParams from vllm_causallms import LLM, SamplingParams
@register_model("vllm") @register_model("vllm")
...@@ -32,6 +32,7 @@ class VLLM(LM): ...@@ -32,6 +32,7 @@ class VLLM(LM):
self.model = LLM( self.model = LLM(
model=pretrained, model=pretrained,
gpu_memory_utilization=0.2,
revision=revision, revision=revision,
dtype=dtype, dtype=dtype,
tokenizer_mode=tokenizer_mode, tokenizer_mode=tokenizer_mode,
...@@ -298,8 +299,8 @@ class VLLM(LM): ...@@ -298,8 +299,8 @@ class VLLM(LM):
:param tokens: list :param tokens: list
Tokens from response Tokens from response
:param outputs: list :param outputs: RequestOutput
Logprobabilities tied to tokens Contains prompt
:param ctxlen: int :param ctxlen: int
Length of context (so we can slice them away and only keep the predictions) Length of context (so we can slice them away and only keep the predictions)
:return: :return:
...@@ -309,22 +310,26 @@ class VLLM(LM): ...@@ -309,22 +310,26 @@ class VLLM(LM):
Whether argmax matches given continuation exactly Whether argmax matches given continuation exactly
""" """
# Extract the logprobs for the continuation tokens
continuation_logprobs_dicts = outputs.prompt_logprobs
# Calculate continuation_logprobs # Calculate continuation_logprobs
continuation_logprobs = sum( continuation_logprobs = sum(
logprob_dict[token] logprob_dict.get(token) # Use .get to avoid KeyError and default to 0
for token, logprob_dict in zip( for token, logprob_dict in zip(
tokens[ctxlen:], outputs.prompt_logprobs[ctxlen:] tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
) )
) )
# Determine if is_greedy # Determine if is_greedy
is_greedy = True is_greedy = True
for i in range(ctxlen, len(tokens)): for token, logprob_dict in zip(tokens[ctxlen:], continuation_logprobs_dicts):
token = tokens[i] # Get the token with the maximum log probability from the logprob_dict
top_tokens = outputs[i] if logprob_dict: # Ensure the logprob_dict is not None
top_token = max(top_tokens.keys(), key=lambda x: top_tokens[x]) top_token = max(logprob_dict, key=logprob_dict.get)
if top_token != token: if top_token != token:
is_greedy = False is_greedy = False
break break
return continuation_logprobs, is_greedy return continuation_logprobs, is_greedy
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment