Commit cec27dad authored by mgoin

Fix implementation

parent 101b2884
-import torch
+from typing import List, Optional, Tuple, Union
+
+from tqdm import tqdm
+import random
 import deepsparse
-from typing import Optional, Union
+from lm_eval import utils
 from lm_eval.base import BaseLM


 class DeepSparseLM(BaseLM):
     _DEFAULT_MAX_LENGTH = 2048

     def __init__(
@@ -23,7 +26,7 @@ class DeepSparseLM(BaseLM):
         self.model = deepsparse.Pipeline.create(
             task="text-generation",
             model_path=pretrained,
-            sequence_length=max_length or _DEFAULT_MAX_LENGTH,
+            sequence_length=max_length or self._DEFAULT_MAX_LENGTH,
             trust_remote_code=trust_remote_code,
             batch_size=batch_size,
         )
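The one-character change above is the core of the fix: inside a method body, a bare name is resolved against the local, enclosing, and module scopes only, never the class namespace, so the original line raised `NameError` whenever `max_length` was falsy. A minimal standalone sketch of the scoping rule (the class and attribute names here are illustrative, not from the diff):

```python
class Example:
    _DEFAULT = 2048

    def broken(self, value=None):
        # Bare-name lookup skips the class namespace, so this raises
        # NameError whenever `value` is falsy and the fallback is evaluated.
        return value or _DEFAULT

    def fixed(self, value=None):
        # Attribute access through self falls back to the class attribute.
        return value or self._DEFAULT


assert Example().fixed() == 2048
```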
@@ -36,8 +39,11 @@ class DeepSparseLM(BaseLM):
         self._max_gen_toks = max_gen_toks

     @property
-    def eot_token_id(self):
-        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
+    def eot_token(self) -> str:
+        return self.tokenizer.eos_token
+
+    @property
+    def eot_token_id(self) -> int:
         return self.tokenizer.eos_token_id

     @property
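The split into `eot_token` and `eot_token_id` mirrors the two forms the underlying tokenizer exposes. Assuming `self.tokenizer` is a Hugging Face tokenizer (its construction is not shown in this diff), the two properties resolve like this; GPT-2 is used purely as an illustration:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
print(tok.eos_token)     # '<|endoftext|>'  -> what eot_token returns
print(tok.eos_token_id)  # 50256            -> what eot_token_id returns
```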
@@ -125,6 +131,8 @@ class DeepSparseLM(BaseLM):
                 do_sample=False,
             )

+            responses = responses if type(responses) is list else [responses]
+
             for response in responses:
                 response = response.generations[0].text
                 # Ensure the generated responses do not contain the stop sequences.
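The inserted guard normalizes the pipeline output so one loop handles both shapes, presumably because a single-prompt call can return a bare result object rather than a list (the diff itself does not show which shapes `self.model` can return). The pattern in isolation, with `as_list` being an illustrative name; `isinstance` is the more idiomatic spelling of the `type(...) is list` check used in the commit:

```python
def as_list(result):
    # Coerce a scalar result into a one-element list so downstream
    # code can always iterate.
    return result if isinstance(result, list) else [result]


assert as_list(["a", "b"]) == ["a", "b"]
assert as_list("a") == ["a"]
```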
@@ -136,3 +144,19 @@ class DeepSparseLM(BaseLM):
         return reorder.get_original(results)
+
+    def loglikelihood(self, requests):
+        raise NotImplementedError()
+
+    def loglikelihood_rolling(self, requests):
+        raise NotImplementedError()
+
+    def _loglikelihood_tokens(self, requests, disable_tqdm=False):
+        raise NotImplementedError("No support for logits.")
+
+    def _model_call(self, inps):
+        # Isn't used because we override _loglikelihood_tokens
+        raise NotImplementedError()
+
+    def _model_generate(self, context, max_length, eos_token_id):
+        # Isn't used because we override greedy_until
+        raise NotImplementedError()
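With every loglikelihood path stubbed out, the adapter supports only generation-style evaluation through `greedy_until`. A hypothetical usage sketch: the constructor keywords are taken from the fields visible in this diff, the model path and prompt are placeholders, and the request format follows the old `lm_eval` `BaseLM` convention of `(context, until)` pairs:

```python
# Hypothetical: model path and prompt are placeholders, not from the commit.
lm = DeepSparseLM(
    pretrained="path/to/sparse-model",  # placeholder model_path
    batch_size=1,
)

# Generation works via the overridden greedy_until...
completions = lm.greedy_until([("Q: What is 2 + 2?\nA:", ["\n"])])

# ...but any perplexity/loglikelihood task now fails fast.
try:
    lm.loglikelihood([("context", "continuation")])
except NotImplementedError:
    print("loglikelihood is intentionally unsupported")
```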