Unverified Commit 52e1d456 authored by Hailey Schoelkopf, committed by GitHub

Merge pull request #1014 from mgoin/deepsparselm

Add DeepSparseLM
parents 33a215c7 bf783d60
@@ -81,6 +81,21 @@ To evaluate models that are loaded via `AutoSeq2SeqLM` in Huggingface, you inste
> **Warning**: Choosing the wrong model may produce erroneous outputs without raising an error.
### Neural Magic `deepsparse`
Models from [SparseZoo](https://sparsezoo.neuralmagic.com/) can be evaluated directly in lm-evaluation-harness using [DeepSparse](https://github.com/neuralmagic/deepsparse):
```bash
pip install deepsparse-nightly[llm]
python main.py --model deepsparse --model_args pretrained=zoo:mpt-7b-gsm8k_mpt_pretrain-pruned70_quantized --tasks gsm8k
```
Compatible models hosted on the Hugging Face Hub can be used as well:
```bash
python main.py --model deepsparse --model_args pretrained=hf:mgoin/TinyLlama-1.1B-Chat-v0.3-ds --tasks hellaswag
python main.py --model deepsparse --model_args pretrained=hf:neuralmagic/mpt-7b-gsm8k-pruned60-quant-ds --tasks gsm8k
```
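The same evaluation can also be run from Python. A minimal sketch, assuming the harness's `lm_eval.evaluator.simple_evaluate` entry point and the SparseZoo stub shown above:
```python
from lm_eval import evaluator

# Sketch: evaluate a DeepSparse SparseZoo model on GSM8K; the arguments
# mirror the CLI flags used above.
results = evaluator.simple_evaluate(
    model="deepsparse",
    model_args="pretrained=zoo:mpt-7b-gsm8k_mpt_pretrain-pruned70_quantized",
    tasks=["gsm8k"],
)
print(results["results"])
```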
### OpenVINO models converted via HuggingFace Optimum
```bash
python main.py --model optimum-causal --model_args pretrained=<model_path_or_name> --tasks lambada_openai
......
@@ -3,6 +3,7 @@ from . import gpt3
from . import anthropic_llms
from . import huggingface
from . import textsynth
from . import deepsparse
from . import dummy
from . import gguf
@@ -15,6 +16,7 @@ MODEL_REGISTRY = {
"gpt3": gpt3.GPT3LM,
"anthropic": anthropic_llms.AnthropicLM,
"textsynth": textsynth.TextSynthLM,
"deepsparse": deepsparse.DeepSparseLM,
"dummy": dummy.DummyLM,
"gguf": gguf.GGUFLM,
"optimum-causal": gpt2.OPTIMUMLM,
......
from typing import List, Optional, Tuple, Union
from tqdm import tqdm
import random
import numpy
import torch
from lm_eval import utils
from lm_eval.base import BaseLM
class DeepSparseLM(BaseLM):
# Default max sequence length setting for when no `max_length` is provided
_DEFAULT_MAX_LENGTH = 2048
def __init__(
self,
pretrained: str,
tokenizer: Optional[str] = None,
batch_size: Optional[Union[int, str]] = 1,
max_gen_toks: Optional[int] = 256,
max_length: Optional[int] = None,
trust_remote_code: Optional[bool] = False,
):
"""
Wrapper around the DeepSparse pipeline to make it compatible with the
        lm-evaluation-harness.
"""
super().__init__()
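        # Import lazily so `deepsparse` is only required when this backend is used.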
import deepsparse
self._batch_size = int(batch_size)
self._max_length = max_length or self._DEFAULT_MAX_LENGTH
self._max_gen_toks = max_gen_toks
# Initialize new model and tokenizer instances
self.model = deepsparse.TextGeneration(
model_path=pretrained,
sequence_length=self._max_length,
trust_remote_code=trust_remote_code,
batch_size=batch_size,
)
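        # Fall back to the pipeline's bundled tokenizer when none is provided.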
self.tokenizer = tokenizer if tokenizer else self.model.tokenizer
self.vocab_size = self.tokenizer.vocab_size
def _model_call(self, inps) -> torch.Tensor:
"""
Override the _model_call method to use the DeepSparse pipeline for
logits generation.
inps: a torch tensor of shape [batch, sequence]
the size of sequence may vary from call to call
returns: a torch tensor of shape [batch, sequence, vocab] with the
logits returned from the model
"""
        # Decode the token ids back into strings
prompt = self.model.tokenizer.batch_decode(inps.numpy())
# Run the model to map the prompt to logits
out = self.model(
prompt=prompt,
max_new_tokens=0,
include_prompt_logits=True,
output_scores=True,
)
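        # Stack the per-sequence prompt logits into a [batch, sequence, vocab] tensor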
logits_numpy = numpy.stack([generation.score for generation in out.generations])
return torch.from_numpy(logits_numpy)
def greedy_until(
self, requests: List[Tuple[str, Union[List[str], str]]]
) -> List[str]:
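        # Sort requests by tokenized context length so similarly sized prompts are batched together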
def _collate(x):
tokens = self.tok_encode(x[0])
return len(tokens), x[0]
results = []
reorder = utils.Reorderer(requests, _collate)
for chunk in utils.chunks(
tqdm(reorder.get_reordered(), disable=False),
self.batch_size,
):
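            # All requests in a chunk share generation arguments; take them from the first request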
context = [c[0] for c in chunk]
request_args = chunk[0][1]
stop = request_args.get("until", None)
stop_sequences = stop if isinstance(stop, list) else [stop]
max_generation_length = request_args.get("max_length", None)
assert (
isinstance(max_generation_length, int) or max_generation_length is None
)
assert isinstance(stop_sequences, list) or stop_sequences is None
# TODO: Find a better way to handle stop sequences for 0-shot.
if stop_sequences is None:
until = [self.eot_token]
else:
until = stop_sequences + [self.eot_token]
if max_generation_length is None:
max_tokens = self.max_gen_toks
else:
max_tokens = max_generation_length
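            # Greedy decoding through the DeepSparse pipeline, stopping on any `until` sequence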
responses = self.model(
sequences=context,
max_new_tokens=max_tokens,
stop=until,
do_sample=False,
)
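            # The pipeline may return a single object rather than a list for a single sequence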
            responses = responses if isinstance(responses, list) else [responses]
            for response, request_context in zip(responses, context):
                generated_text = response.generations[0].text
                # Ensure the generated responses do not contain the stop sequences.
                for term in until:
                    generated_text = generated_text.split(term)[0]
                # Cache each result under its own request context so cache lookups match.
                self.cache_hook.add_partial(
                    "greedy_until", (request_context, until), generated_text
                )
                results.append(generated_text)
return reorder.get_original(results)
def _model_generate(self, context, max_length, eos_token_id):
# Isn't used because we override greedy_until
raise NotImplementedError()
@property
def eot_token(self) -> str:
return self.tokenizer.eos_token
@property
def eot_token_id(self) -> int:
return self.tokenizer.eos_token_id
@property
def max_length(self):
return self._max_length
@property
def max_gen_toks(self):
return self._max_gen_toks
@property
def batch_size(self):
return self._batch_size
@property
def device(self):
pass
def tok_encode(self, string: str):
return self.tokenizer.encode(string, add_special_tokens=False)
def tok_decode(self, tokens):
return self.tokenizer.decode(tokens)
@@ -50,6 +50,7 @@ setuptools.setup(
"sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
"auto-gptq": ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"],
"anthropic": ["anthropic"],
"deepsparse": ["deepsparse-nightly[llm]"],
"openvino": ["openvino", "nncf", "onnx", "optimum-intel @ git+https://github.com/huggingface/optimum-intel.git"],
},
)