Commit 2c8e66d7 authored by mgoin

Update with Damian's implementation for logits

parent 8941c067
from typing import List, Optional, Tuple, Union
from tqdm import tqdm
import random
import numpy
import torch
import deepsparse
@@ -9,6 +11,7 @@ from lm_eval.base import BaseLM
class DeepSparseLM(BaseLM):
# Default max sequence length setting for when no `max_length` is provided
_DEFAULT_MAX_LENGTH = 2048
def __init__(
@@ -20,14 +23,20 @@ class DeepSparseLM(BaseLM):
max_length: Optional[int] = None,
trust_remote_code: Optional[bool] = False,
):
"""
Wrapper around the DeepSparse pipeline to make it compatible with the
lm-evaluation-harness.
"""
super().__init__()
self._batch_size = int(batch_size)
self._max_length = max_length or self._DEFAULT_MAX_LENGTH
self._max_gen_toks = max_gen_toks
# Initialize new model and tokenizer instances
self.model = deepsparse.Pipeline.create(
task="text-generation",
self.model = deepsparse.TextGeneration(
model_path=pretrained,
sequence_length=max_length or self._DEFAULT_MAX_LENGTH,
prompt_sequence_length=16,
sequence_length=self._max_length,
trust_remote_code=trust_remote_code,
batch_size=batch_size,
)
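# Illustrative sketch (not part of this commit): the deepsparse.TextGeneration
# pipeline created above can also be exercised directly, outside the harness
# wrapper. The model stub path below is hypothetical; substitute a real
# SparseZoo stub or a local deployment directory.
import deepsparse

pipeline = deepsparse.TextGeneration(
    model_path="zoo:example/text-generation-stub",  # hypothetical stub path
    sequence_length=2048,
)
output = pipeline(prompt="Once upon a time", max_new_tokens=16)
print(output.generations[0].text)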
@@ -35,9 +44,42 @@
self.vocab_size = self.tokenizer.vocab_size
self._batch_size = int(batch_size)
self._max_length = max_length
self._max_gen_toks = max_gen_toks
def _model_call(self, inps) -> torch.Tensor:
"""
Override the _model_call method to use the DeepSparse pipeline for
logits generation.
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
# Decode the token ids back into strings
prompt = self.model.tokenizer.batch_decode(inps.numpy())
# Run the model to map the prompt to logits
out = self.model(
prompt=prompt,
max_new_tokens=0,
include_prompt_logits=True,
output_scores=True,
)
logits_numpy = numpy.stack([generation.score for generation in out.generations])
return torch.from_numpy(logits_numpy)
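# Illustrative sketch (not part of this commit): how the stacking above turns the
# per-prompt `generation.score` arrays (shape [sequence, vocab]) into the
# [batch, sequence, vocab] logits tensor the harness expects. `FakeGeneration`
# is only a stand-in for the objects in `out.generations`.
from dataclasses import dataclass

import numpy
import torch

@dataclass
class FakeGeneration:
    score: numpy.ndarray  # prompt logits, shape [sequence, vocab]

batch, sequence, vocab = 2, 5, 8
fake_generations = [
    FakeGeneration(score=numpy.random.rand(sequence, vocab)) for _ in range(batch)
]
stacked = torch.from_numpy(numpy.stack([g.score for g in fake_generations]))
assert stacked.shape == (batch, sequence, vocab)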
def _model_generate(self, context, max_length, eos_token_id):
# Decode the prompt token ids back into strings
prompt = self.tokenizer.batch_decode(context.numpy())
# Run generation
out = self.model(
prompt=prompt, max_new_tokens=max_length, force_max_tokens=True
)
# Return tokens for prompt + generated text
return numpy.array(
[self.tokenizer(prompt[0] + out.generations[0].text)["input_ids"]]
)
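# Illustrative usage sketch (not part of this commit): running the wrapper end to
# end through lm-evaluation-harness. The model stub path and the task choice are
# assumptions; the constructor arguments mirror the parameters referenced in
# __init__ above (`pretrained`, `batch_size`). `no_cache=True` skips the harness'
# sqlite result cache, which expects a string model identifier.
from lm_eval import evaluator

lm = DeepSparseLM(
    pretrained="zoo:example/text-generation-stub",  # hypothetical stub path
    batch_size=1,
)
results = evaluator.simple_evaluate(model=lm, tasks=["hellaswag"], no_cache=True)
print(results["results"])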
@property
def eot_token(self) -> str:
@@ -49,17 +91,7 @@ class DeepSparseLM(BaseLM):
@property
def max_length(self):
if self._max_length: # if max length manually set, return it
return self._max_length
# seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
# for attr in seqlen_config_attrs:
# if hasattr(self.model.config, attr):
# return getattr(self.model.config, attr)
# if hasattr(self.tokenizer, "model_max_length"):
# if self.tokenizer.model_max_length == 1000000000000000019884624838656:
# return self._DEFAULT_MAX_LENGTH
# return self.tokenizer.model_max_length
return self._DEFAULT_MAX_LENGTH
return self._max_length
@property
def max_gen_toks(self):
@@ -71,93 +103,11 @@ class DeepSparseLM(BaseLM):
@property
def device(self):
return "cpu"
pass
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
return self.tokenizer.encode(string)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
def greedy_until(
self, requests: List[Tuple[str, Union[List[str], str]]]
) -> List[str]:
def _collate(x):
tokens = self.tok_encode(x[0])
return len(tokens), x[0]
results = []
reorder = utils.Reorderer(requests, _collate)
# adaptive_batch_size = None
# if self.batch_size == "auto":
# # using rolling window with maximum context
# print("Passed argument batch_size = auto. Detecting largest batch size")
# batch_size = self._detect_batch_size()
# print(f"Determined Largest batch size: {batch_size}")
# adaptive_batch_size = batch_size
for chunk in utils.chunks(
tqdm(reorder.get_reordered(), disable=False),
self.batch_size,
):
context = [c[0] for c in chunk]
request_args = chunk[0][1]
stop = request_args.get("until", None)
stop_sequences = stop if isinstance(stop, list) else [stop]
max_generation_length = request_args.get("max_length", None)
assert (
isinstance(max_generation_length, int) or max_generation_length is None
)
assert isinstance(stop_sequences, list) or stop_sequences is None
# TODO: Find a better way to handle stop sequences for 0-shot.
if stop_sequences is None:
until = [self.eot_token]
else:
until = stop_sequences + [self.eot_token]
if max_generation_length is None:
max_tokens = self.max_gen_toks
else:
max_tokens = max_generation_length
# token_context = self.tok_encode_batch(context)
responses = self.model(
sequences=context,
max_new_tokens=max_tokens,
stop=until,
do_sample=False,
)
responses = responses if type(responses) is list else [responses]
for response in responses:
response = response.generations[0].text
# Ensure the generated responses do not contain the stop sequences.
for term in until:
response = response.split(term)[0]
# partial caching
self.cache_hook.add_partial("greedy_until", (context, until), response)
results.append(response)
return reorder.get_original(results)
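# Illustrative sketch (not part of this commit): the stop-sequence trimming loop
# above truncates each generation at the first occurrence of any stop term.
response = "Paris is the capital of France.\n\nQ: What is the capital of Spain?"
until = ["\n\n", "<|endoftext|>"]
for term in until:
    response = response.split(term)[0]
assert response == "Paris is the capital of France."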
def loglikelihood(self, requests):
raise NotImplementedError()
def loglikelihood_rolling(self, requests):
raise NotImplementedError()
def _loglikelihood_tokens(self, requests, disable_tqdm=False):
raise NotImplementedError("No support for logits.")
def _model_call(self, inps):
# Isn't used because we override _loglikelihood_tokens
raise NotImplementedError()
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
raise NotImplementedError()