Unverified Commit 101b2884 authored by Michael Goin, committed by GitHub

DeepSparseLM Implementation

parent 2c18e367
@@ -3,6 +3,7 @@ from . import gpt3
from . import anthropic_llms
from . import huggingface
from . import textsynth
from . import deepsparse
from . import dummy

MODEL_REGISTRY = {
@@ -14,6 +15,7 @@ MODEL_REGISTRY = {
    "gpt3": gpt3.GPT3LM,
    "anthropic": anthropic_llms.AnthropicLM,
    "textsynth": textsynth.TextSynthLM,
    "deepsparse": deepsparse.DeepSparseLM,
    "dummy": dummy.DummyLM,
}
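With the registry entry above, the new backend can be selected by name like any other model. Below is a minimal sketch of driving it through the evaluator, assuming the harness's existing simple_evaluate API; the model path and task name are placeholders, not tested values:

# Illustrative only: "pretrained" should point to a real DeepSparse deployment or SparseZoo stub.
from lm_eval import evaluator

results = evaluator.simple_evaluate(
    model="deepsparse",                           # resolved through MODEL_REGISTRY above
    model_args="pretrained=/path/to/deployment",
    tasks=["lambada_openai"],
)
print(results["results"])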
lm_eval/models/deepsparse.py (new file)
import torch
import deepsparse
from typing import List, Optional, Tuple, Union

from tqdm import tqdm

from lm_eval import utils
from lm_eval.base import BaseLM


class DeepSparseLM(BaseLM):
    _DEFAULT_MAX_LENGTH = 2048

    def __init__(
        self,
        pretrained: str,
        tokenizer: Optional[str] = None,
        batch_size: Optional[Union[int, str]] = 1,
        max_gen_toks: Optional[int] = 256,
        max_length: Optional[int] = None,
        trust_remote_code: Optional[bool] = False,
    ):
        super().__init__()

        # Initialize new model and tokenizer instances
        self.model = deepsparse.Pipeline.create(
            task="text-generation",
            model_path=pretrained,
            sequence_length=max_length or self._DEFAULT_MAX_LENGTH,
            trust_remote_code=trust_remote_code,
            batch_size=batch_size,
        )
        self.tokenizer = tokenizer if tokenizer else self.model.tokenizer

        self.vocab_size = self.tokenizer.vocab_size
        self._batch_size = int(batch_size)
        self._max_length = max_length
        self._max_gen_toks = max_gen_toks

    @property
    def eot_token(self) -> str:
        # string form of the end-of-text token; greedy_until appends it to the stop sequences
        return self.tokenizer.eos_token

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        if self._max_length:  # if max length manually set, return it
            return self._max_length
        # seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
        # for attr in seqlen_config_attrs:
        #     if hasattr(self.model.config, attr):
        #         return getattr(self.model.config, attr)
        # if hasattr(self.tokenizer, "model_max_length"):
        #     if self.tokenizer.model_max_length == 1000000000000000019884624838656:
        #         return self._DEFAULT_MAX_LENGTH
        #     return self.tokenizer.model_max_length
        return self._DEFAULT_MAX_LENGTH

    @property
    def max_gen_toks(self):
        return self._max_gen_toks

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        # DeepSparse runs inference on CPU
        return "cpu"

    def tok_encode(self, string: str):
        return self.tokenizer.encode(string, add_special_tokens=False)

    def tok_decode(self, tokens):
        return self.tokenizer.decode(tokens)

    def greedy_until(
        self, requests: List[Tuple[str, Union[List[str], str]]]
    ) -> List[str]:
        def _collate(x):
            tokens = self.tok_encode(x[0])
            return len(tokens), x[0]

        results = []
        reorder = utils.Reorderer(requests, _collate)

        # adaptive_batch_size = None
        # if self.batch_size == "auto":
        #     # using rolling window with maximum context
        #     print("Passed argument batch_size = auto. Detecting largest batch size")
        #     batch_size = self._detect_batch_size()
        #     print(f"Determined Largest batch size: {batch_size}")
        #     adaptive_batch_size = batch_size

        for chunk in utils.chunks(
            tqdm(reorder.get_reordered(), disable=False),
            self.batch_size,
        ):
            context = [c[0] for c in chunk]
            request_args = chunk[0][1]
            stop = request_args.get("until", None)
            stop_sequences = stop if isinstance(stop, list) else [stop]
            max_generation_length = request_args.get("max_length", None)

            assert (
                isinstance(max_generation_length, int) or max_generation_length is None
            )
            assert isinstance(stop_sequences, list) or stop_sequences is None

            # TODO: Find a better way to handle stop sequences for 0-shot.
            if stop_sequences is None:
                until = [self.eot_token]
            else:
                until = stop_sequences + [self.eot_token]

            if max_generation_length is None:
                max_tokens = self.max_gen_toks
            else:
                max_tokens = max_generation_length

            # token_context = self.tok_encode_batch(context)

            responses = self.model(
                sequences=context,
                max_new_tokens=max_tokens,
                stop=until,
                do_sample=False,
            )

            for response in responses:
                response = response.generations[0].text
                # Ensure the generated responses do not contain the stop sequences.
                for term in until:
                    response = response.split(term)[0]
                # partial caching
                self.cache_hook.add_partial("greedy_until", (context, until), response)
                results.append(response)

        return reorder.get_original(results)
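For a quick sanity check of the generation path, the class can also be exercised directly. A rough sketch, assuming a valid model path (the stub, prompt, and stop string below are made up) and that each request carries its arguments as a dict, which is how greedy_until reads them:

# Illustrative only: replace the stub with a real local deployment or SparseZoo path.
lm = DeepSparseLM(pretrained="zoo:example/sparse-llm-stub", max_gen_toks=64)

# greedy_until expects (context, request_args) pairs and pulls "until" / "max_length"
# out of request_args with .get().
outputs = lm.greedy_until([("Q: What does DeepSparse run on?\nA:", {"until": ["\n"]})])
print(outputs[0])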