Commit ce164bb1 authored by baberabb

fix imports

parent 667fc837
@@ -8,11 +8,7 @@ from tqdm import tqdm
 from lm_eval.api.registry import register_model
 from lm_eval import utils
-# TODO: Fix this once complete
-try:
-    from vllm import LLM, SamplingParams
-except ModuleNotFoundError:
-    pass
+# flake8: noqa
 @register_model("vllm")
@@ -34,6 +30,7 @@ class VLLM(LM):
         max_length: int = None,
     ):
         super().__init__()
+        from vllm import LLM, SamplingParams
         self.model = LLM(
             model=pretrained,
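Side note on the hunk above: moving `from vllm import LLM, SamplingParams` into `__init__` is the usual deferred-import pattern for heavy optional dependencies, so `lm_eval` can still be imported on machines without vllm installed. A minimal sketch of that pattern under the same assumption; `TinyVLLMWrapper` and its error message are illustrative stand-ins, not the harness's actual class:

# Minimal sketch of the deferred-import pattern used in the hunk above.
# `TinyVLLMWrapper` is a hypothetical stand-in, not lm_eval's VLLM class.
class TinyVLLMWrapper:
    def __init__(self, pretrained: str):
        # Importing vllm here (not at module top level) means importing this
        # module never fails when vllm is absent; only instantiation does.
        try:
            from vllm import LLM, SamplingParams
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "vllm is required for this backend; install it with `pip install vllm`."
            ) from err
        self._sampling_params_cls = SamplingParams
        self.model = LLM(model=pretrained)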
@@ -68,9 +65,17 @@ class VLLM(LM):
     def max_gen_toks(self):
         return self._max_gen_toks
-    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=False):
+    def tok_encode(
+        self,
+        string: str,
+        left_truncate_len=None,
+        add_special_tokens=False,
+        truncation=False,
+    ):
         """ """
-        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
+        encoding = self.tokenizer.encode(
+            string, add_special_tokens=add_special_tokens, truncation=truncation
+        )
         # left-truncate the encoded context to be at most `left_truncate_len` tokens long
         if left_truncate_len:
@@ -109,14 +114,14 @@ class VLLM(LM):
         )
         return outputs
-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
                 # end of text as context
-                context_enc, continuation_enc = [
-                    self.eot_token_id
-                ], self.tokenizer.tok_encode(continuation)
+                context_enc, continuation_enc = [self.eot_token_id], self.tok_encode(
+                    continuation
+                )
             else:
                 context_enc, continuation_enc = self.tokenizer(
                     [context, continuation],
@@ -129,7 +134,7 @@ class VLLM(LM):
         return self._loglikelihood_tokens(new_reqs)
-    def loglikelihood_rolling(self, requests) -> List[float]:
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
         loglikelihoods = []
         for (string,) in tqdm([req.args for req in requests]):
@@ -256,7 +261,9 @@ class VLLM(LM):
         return grouper.get_original(res)
     def _loglikelihood_tokens(
-        self, requests, disable_tqdm: bool = False
+        self,
+        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
+        disable_tqdm: bool = False,
     ) -> List[Tuple[float, bool]]:
         res = []
@@ -271,7 +278,7 @@ class VLLM(LM):
             n=self.batch_size,
             fn=None,
         )
-        pbar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)))
+        pbar = tqdm(total=len(requests), disable=disable_tqdm)
         for chunk in chunks:
             inps = []
             ctxlens = []
@@ -305,11 +312,11 @@ class VLLM(LM):
         return re_ord.get_original(res)
     @staticmethod
-    def _parse_logprobs(tokens: List, outputs, ctxlen: int):
+    def _parse_logprobs(tokens: List, outputs, ctxlen: int) -> Tuple[float, bool]:
         """Process logprobs and tokens.
         :param tokens: list
-            Tokens from response
+            Tokens from context+continuations
         :param outputs: RequestOutput
             Contains prompt
         :param ctxlen: int
@@ -321,13 +328,13 @@ class VLLM(LM):
             Whether argmax matches given continuation exactly
         """
-        # Extract the logprobs for the continuation tokens
+        # prompt_logprobs = [None, {}*len(context-1)]
         continuation_logprobs_dicts = outputs.prompt_logprobs
         # Calculate continuation_logprobs
+        # assume ctxlen always > 1
         continuation_logprobs = sum(
-            logprob_dict.get(token)  # Use .get to avoid KeyError and default to 0
+            logprob_dict.get(token)
             for token, logprob_dict in zip(
                 tokens[ctxlen:], continuation_logprobs_dicts[ctxlen:]
             )
...
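As a reading aid for the `_parse_logprobs` hunk: vLLM's `RequestOutput.prompt_logprobs` is a list aligned with the prompt tokens (first entry `None`, then one token-id-to-logprob mapping per position), so the continuation score is the sum of each continuation token's logprob at its position, and greediness can be checked against the argmax of each mapping. A self-contained toy illustration with plain dicts standing in for vLLM's objects; the token ids and logprob values are invented for the example:

# Toy illustration of summing continuation logprobs from per-position dicts.
# `fake_prompt_logprobs` mimics the shape of vLLM's RequestOutput.prompt_logprobs;
# all numbers below are made up.
fake_tokens = [101, 7, 42, 13]          # 2 context tokens + 2 continuation tokens
fake_prompt_logprobs = [
    None,                               # no logprob for the very first prompt token
    {7: -1.2, 9: -0.4},
    {42: -0.3, 5: -2.1},                # position of the 1st continuation token
    {13: -0.9, 2: -0.1},                # position of the 2nd continuation token
]
ctxlen = 2

continuation_logprob = sum(
    logprob_dict.get(token)
    for token, logprob_dict in zip(fake_tokens[ctxlen:], fake_prompt_logprobs[ctxlen:])
)
is_greedy = all(
    token == max(logprob_dict, key=logprob_dict.get)
    for token, logprob_dict in zip(fake_tokens[ctxlen:], fake_prompt_logprobs[ctxlen:])
)
print(continuation_logprob, is_greedy)  # ~ -1.2, False (13 is not the argmax at its position)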