Commit cec27dad authored by mgoin's avatar mgoin
Browse files

Fix implementation

parent 101b2884
import torch
from typing import List, Optional, Tuple, Union
from tqdm import tqdm
import random
import deepsparse
from typing import Optional, Union
from lm_eval import utils
from lm_eval.base import BaseLM
class DeepSparseLM(BaseLM):
_DEFAULT_MAX_LENGTH = 2048
def __init__(
......@@ -23,7 +26,7 @@ class DeepSparseLM(BaseLM):
self.model = deepsparse.Pipeline.create(
task="text-generation",
model_path=pretrained,
sequence_length=max_length or _DEFAULT_MAX_LENGTH,
sequence_length=max_length or self._DEFAULT_MAX_LENGTH,
trust_remote_code=trust_remote_code,
batch_size=batch_size,
)
......@@ -36,8 +39,11 @@ class DeepSparseLM(BaseLM):
self._max_gen_toks = max_gen_toks
@property
def eot_token_id(self):
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
def eot_token(self) -> str:
return self.tokenizer.eos_token
@property
def eot_token_id(self) -> int:
return self.tokenizer.eos_token_id
@property
......@@ -125,6 +131,8 @@ class DeepSparseLM(BaseLM):
do_sample=False,
)
responses = responses if type(responses) is list else [responses]
for response in responses:
response = response.generations[0].text
# Ensure the generated responses do not contain the stop sequences.
......@@ -136,3 +144,19 @@ class DeepSparseLM(BaseLM):
return reorder.get_original(results)
def loglikelihood(self, requests):
raise NotImplementedError()
def loglikelihood_rolling(self, requests):
raise NotImplementedError()
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
raise NotImplementedError("No support for logits.")
def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
raise NotImplementedError()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment