Commit ce164bb1 authored by baberabb

fix imports

parent 667fc837
@@ -8,11 +8,7 @@ from tqdm import tqdm
 from lm_eval.api.registry import register_model
 from lm_eval import utils
-# TODO: Fix this once complete
-try:
-    from vllm import LLM, SamplingParams
-except ModuleNotFoundError:
-    pass
 # flake8: noqa
 @register_model("vllm")
@@ -34,6 +30,7 @@ class VLLM(LM):
         max_length: int = None,
     ):
         super().__init__()
+        from vllm import LLM, SamplingParams
         self.model = LLM(
             model=pretrained,
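
The change above moves the vllm import out of module scope and into `VLLM.__init__`, so the file imports cleanly when vllm is not installed and only fails when the backend is actually constructed. A minimal sketch of that lazy-import pattern; the `VLLMBackend` name and the error message are illustrative, not part of this commit:

```python
class VLLMBackend:
    """Illustrative stand-in for the harness's vllm wrapper (hypothetical name)."""

    def __init__(self, pretrained: str):
        # Import lazily so importing the module itself never requires vllm.
        try:
            from vllm import LLM, SamplingParams
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "vllm must be installed to use the vllm backend (`pip install vllm`)"
            ) from err

        self._sampling_params_cls = SamplingParams  # kept for later generate() calls
        self.model = LLM(model=pretrained)
```
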
@@ -68,9 +65,17 @@ class VLLM(LM):
     def max_gen_toks(self):
         return self._max_gen_toks
-    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=False):
+    def tok_encode(
+        self,
+        string: str,
+        left_truncate_len=None,
+        add_special_tokens=False,
+        truncation=False,
+    ):
         """ """
-        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
+        encoding = self.tokenizer.encode(
+            string, add_special_tokens=add_special_tokens, truncation=truncation
+        )
         # left-truncate the encoded context to be at most `left_truncate_len` tokens long
         if left_truncate_len:
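
The reworked `tok_encode` above forwards a new `truncation` flag to the tokenizer and still applies the optional left-truncation afterwards. A standalone sketch of that behaviour, assuming a Hugging Face tokenizer; the helper name and the `gpt2` checkpoint are illustrative:

```python
from transformers import AutoTokenizer


def tok_encode_sketch(
    tokenizer,
    string: str,
    left_truncate_len=None,
    add_special_tokens=False,
    truncation=False,
):
    # Encode without special tokens by default; `truncation=True` caps the
    # output at the tokenizer's model_max_length.
    encoding = tokenizer.encode(
        string, add_special_tokens=add_special_tokens, truncation=truncation
    )
    # Keep at most the last `left_truncate_len` tokens of the context.
    if left_truncate_len:
        encoding = encoding[-left_truncate_len:]
    return encoding


if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("gpt2")
    print(tok_encode_sketch(tok, "a longer context that gets cut", left_truncate_len=4))
```
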
@@ -109,14 +114,14 @@ class VLLM(LM):
         )
         return outputs
-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
                 # end of text as context
-                context_enc, continuation_enc = [
-                    self.eot_token_id
-                ], self.tokenizer.tok_encode(continuation)
+                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
+                    continuation
+                )
             else:
                 context_enc, continuation_enc = self.tokenizer(
                     [context, continuation],
@@ -129,7 +134,7 @@ class VLLM(LM):
         return self._loglikelihood_tokens(new_reqs)
-    def loglikelihood_rolling(self, requests) -> List[float]:
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
         for (string,) in tqdm([req.args for req in requests]):
@@ -256,7 +261,9 @@ class VLLM(LM):
         return grouper.get_original(res)
     def _loglikelihood_tokens(
-        self, requests, disable_tqdm: bool = False
+        self,
+        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        disable_tqdm: bool = False,
     ) -> List[Tuple[float, bool]]:
         res = []
@@ -271,7 +278,7 @@ class VLLM(LM):
             n=self.batch_size,
             fn=None,
         )
-        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
+        pbar = tqdm(total=len(requests), disable=disable_tqdm)
         for chunk in chunks:
             inps = []
             ctxlens = []
@@ -305,11 +312,11 @@ class VLLM(LM):
         return re_ord.get_original(res)
     @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int):
+    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
         """Process logprobs and tokens.
         :param tokens: list
-            Tokens from response
+            Tokens from context+continuations
         :param outputs: RequestOutput
             Contains prompt
         :param ctxlen: int
@@ -321,13 +328,13 @@ class VLLM(LM):
                 Whether argmax matches given continuation exactly
         """
         # Extract the logprobs for the continuation tokens
         # prompt_logprobs = [None, {}*len(context-1)]
         continuation_logprobs_dicts = outputs.prompt_logprobs
         # Calculate continuation_logprobs
         # assume ctxlen always > 1
         continuation_logprobs = sum(
-            logprob_dict.get(token)  # Use .get to avoid KeyError and default to 0
+            logprob_dict.get(token)
             for token, logprob_dict in zip(
                 tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
             )
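
The docstring above describes what `_parse_logprobs` returns from vLLM's `outputs.prompt_logprobs`: the summed log-probability of the continuation tokens and whether they match a greedy decode. A toy sketch of that logic, treating each position as a plain `token id -> float` dict the way the code above does; the token ids and logprob values are fabricated for illustration:

```python
from typing import Dict, List, Optional, Tuple


def parse_logprobs_sketch(
    tokens: List[int],
    prompt_logprobs: List[Optional[Dict[int, float]]],
    ctxlen: int,
) -> Tuple[float, bool]:
    # Sum the logprob assigned to each actual continuation token.
    continuation_logprobs = sum(
        logprob_dict.get(token)
        for token, logprob_dict in zip(tokens[ctxlen:], prompt_logprobs[ctxlen:])
    )
    # Greedy iff every continuation token is the argmax of its position's distribution.
    is_greedy = all(
        token == max(logprob_dict, key=logprob_dict.get)
        for token, logprob_dict in zip(tokens[ctxlen:], prompt_logprobs[ctxlen:])
    )
    return continuation_logprobs, is_greedy


# tokens = [ctx, ctx, cont, cont]; position 0 never has prompt logprobs (None).
print(
    parse_logprobs_sketch(
        tokens=[11, 22, 33, 44],
        prompt_logprobs=[None, {22: -0.1}, {33: -0.5, 7: -1.2}, {44: -0.2, 9: -2.0}],
        ctxlen=2,
    )
)  # -> (-0.7, True)
```
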