Unverified Commit d21107eb authored by YooSungHyun, committed by GitHub

Support GPTQ on Polyglot branch (#1094)

* Update: support gptq

* Update: support GPTQ

* Update: support GPTQ
parent a77f4be9
@@ -7,7 +7,7 @@ This project provides a unified framework to test generative language models on
Features:
- 200+ tasks implemented. See the [task-table](./docs/task_table.md) for a complete list.
- Support for models loaded via [transformers](https://github.com/huggingface/transformers/) (including quantization via [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ)), [GPT-NeoX](https://github.com/EleutherAI/gpt-neox), and [Megatron-DeepSpeed](https://github.com/microsoft/Megatron-DeepSpeed/), with a flexible tokenization-agnostic interface.
- Support for commercial APIs including [OpenAI](https://openai.com), [goose.ai](https://goose.ai), and [TextSynth](https://textsynth.com/).
- Support for evaluation on adapters (e.g. LoRA) supported in [HuggingFace's PEFT library](https://github.com/huggingface/peft).
- Evaluating with publicly available prompts ensures reproducibility and comparability between papers.
@@ -29,6 +29,12 @@ To install additional multilingual tokenization and text segmentation packages,
```bash
pip install -e ".[multilingual]"
```

To support loading GPTQ quantized models, install the package with the `auto-gptq` extra:

```bash
pip install -e ".[auto-gptq]"
```

## Basic Usage

> **Note**: When reporting results from eval harness, please include the task versions (shown in `results["versions"]`) for reproducibility. This allows bug fixes to tasks while also ensuring that previously reported scores are reproducible. See the [Task Versioning](#task-versioning) section for more info.
@@ -111,6 +117,23 @@ python main.py \
    --device cuda:0
```

GPTQ quantized models can be loaded by specifying their file names in `,quantized=NAME` (or `,quantized=True` for default names) in the `model_args` argument:

```bash
python main.py \
    --model hf-causal-experimental \
    --model_args pretrained=model-name-or-path,quantized=model.safetensors,gptq_use_triton=True \
    --tasks hellaswag
```

To use multiple GPUs with `device_map="auto"`, make sure to set `use_accelerate=True` in the `model_args` argument:

```bash
python main.py \
    --model hf-causal-experimental \
    --model_args pretrained=model-name-or-path,use_accelerate=True \
    --tasks hellaswag
```

We support wildcards in task names, for example you can run all of the machine-translated lambada tasks via `--task lambada_openai_mt_*`.
...
@@ -3,6 +3,7 @@ import torch
import torch.nn.functional as F
import transformers
import peft
from pathlib import Path
from typing import List, Mapping, NewType, Optional, Tuple, Union
from tqdm import tqdm
@@ -26,10 +27,7 @@ def _get_accelerate_args(
    """Returns the kwargs needed to apply `accelerate` in `AutoModel.from_pretrained`."""
    max_memory = {}
    if max_memory_per_gpu is not None:
        max_memory_per_gpu_map = {device_idx: max_memory_per_gpu for device_idx in range(torch.cuda.device_count())}
        max_memory.update(max_memory_per_gpu_map)
    if max_cpu_memory is not None:
        max_memory["cpu"] = max_cpu_memory
@@ -42,9 +40,7 @@ def _get_accelerate_args(
    return args


def _get_dtype(dtype: Union[str, torch.dtype], config: Optional[transformers.AutoConfig] = None) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible."""
    if dtype is None and config is not None:
        _torch_dtype = config.torch_dtype
@@ -69,6 +65,7 @@ class HuggingFaceAutoLM(BaseLM):
    def __init__(
        self,
        pretrained: str,
        quantized: Optional[Union[bool, str]] = None,
        tokenizer: Optional[str] = None,
        subfolder: Optional[str] = None,
        revision: Optional[str] = "main",
@@ -86,6 +83,7 @@ class HuggingFaceAutoLM(BaseLM):
        peft: str = None,
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        gptq_use_triton: Optional[bool] = False,
    ):
        """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation.
        Args:
@@ -93,6 +91,9 @@ class HuggingFaceAutoLM(BaseLM):
                The HuggingFace Hub model ID name or the path to a pre-trained
                model to load. This is effectively the `pretrained_model_name_or_path`
                argument of `from_pretrained` in the HuggingFace `transformers` API.
            quantized (str or True, optional, defaults to None):
                File name of a GPTQ quantized model to load. Set to `True` to use the
                default name of the quantized model.
            add_special_tokens (bool, optional, defaults to True):
                Whether to add special tokens to the input sequences. If `None`, the
                default value will be set to `True` for seq2seq models (e.g. T5) and
@@ -139,16 +140,15 @@ class HuggingFaceAutoLM(BaseLM):
                https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.from_pretrained.load_in_8bit
            trust_remote_code (bool, optional, defaults to False):
                If True, will trust the remote code when loading the model.
            gptq_use_triton (bool, optional, defaults to False):
                Use Triton for GPTQ inference.
        """
        super().__init__()

        assert isinstance(pretrained, str)
        assert isinstance(device, str)
        assert isinstance(batch_size, (int, str))

        if add_special_tokens is not None and self.AUTO_MODEL_CLASS is transformers.AutoModelForCausalLM:
            # TODO: Support evaluating causal models with special tokens. Currently,
            # this is not possible because the `_loglikelihood_tokens()` method for
            # causal LMs makes a no-special-tokens assumption given that contexts
@@ -192,10 +192,12 @@ class HuggingFaceAutoLM(BaseLM):
            model_kwargs["load_in_8bit"] = load_in_8bit
        self.model = self._create_auto_model(
            pretrained=pretrained,
            quantized=quantized,
            trust_remote_code=trust_remote_code,
            revision=revision,
            subfolder=subfolder,
            torch_dtype=_get_dtype(dtype, self._config),
            gptq_use_triton=gptq_use_triton,
            **model_kwargs,
        )
        # note: peft_path can be different than pretrained model path
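As a hedged illustration of how these constructor arguments are reached from the CLI: the `model_args` string shown in the README is parsed into keyword arguments for this class. A direct construction would look roughly like the sketch below; the module path and checkpoint names are assumptions, not taken from this diff.

```python
from lm_eval.models.huggingface import AutoCausalLM  # assumed module path

# Mirrors `--model hf-causal-experimental
#          --model_args pretrained=model-name-or-path,quantized=model.safetensors,gptq_use_triton=True`
lm = AutoCausalLM(
    pretrained="model-name-or-path",  # hypothetical GPTQ checkpoint directory
    quantized="model.safetensors",    # or True to use AutoGPTQ's default basename
    gptq_use_triton=True,
)
```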
@@ -224,6 +226,7 @@ class HuggingFaceAutoLM(BaseLM):
        self,
        *,
        pretrained: str,
        quantized: Optional[Union[bool, str]] = None,
        revision: str,
        subfolder: str,
        device_map: Optional[Union[str, _DeviceMapping]] = None,
@@ -232,18 +235,33 @@ class HuggingFaceAutoLM(BaseLM):
        load_in_8bit: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        torch_dtype: Optional[Union[str, torch.dtype]] = None,
        gptq_use_triton: Optional[bool] = False,
    ) -> transformers.AutoModel:
        """Returns a pre-trained pytorch model from a pre-trained model configuration."""
        if quantized is None:
            model = self.AUTO_MODEL_CLASS.from_pretrained(
                pretrained,
                revision=revision + ("/" + subfolder if subfolder is not None else ""),
                device_map=device_map,
                max_memory=max_memory,
                offload_folder=offload_folder,
                load_in_8bit=load_in_8bit,
                trust_remote_code=trust_remote_code,
                torch_dtype=torch_dtype,
            )
        else:
            from auto_gptq import AutoGPTQForCausalLM

            model = AutoGPTQForCausalLM.from_quantized(
                pretrained,
                model_basename=None if quantized == True else Path(quantized).stem,
                device_map=device_map,
                max_memory=max_memory,
                trust_remote_code=trust_remote_code,
                use_safetensors=True if quantized == True else quantized.endswith(".safetensors"),
                use_triton=gptq_use_triton,
                warmup_triton=gptq_use_triton,
            )
        return model
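A minimal standalone sketch of the quantized branch above, assuming a local GPTQ checkpoint and the `auto-gptq` extra installed; the file and directory names are placeholders:

```python
from pathlib import Path

from auto_gptq import AutoGPTQForCausalLM

quantized = "model.safetensors"  # or True to fall back to AutoGPTQ's default basename
model = AutoGPTQForCausalLM.from_quantized(
    "model-name-or-path",  # hypothetical checkpoint directory
    model_basename=None if quantized is True else Path(quantized).stem,
    use_safetensors=True if quantized is True else quantized.endswith(".safetensors"),
    use_triton=False,  # set True to mirror gptq_use_triton=True (requires the triton extra)
)
```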
    def _create_auto_model_peft(
@@ -369,9 +387,7 @@ class HuggingFaceAutoLM(BaseLM):
    def tok_decode(self, tokens: torch.LongTensor) -> List[str]:
        return self.tokenizer.batch_decode(tokens, skip_special_tokens=True)

    def greedy_until(self, requests: List[Tuple[str, Union[List[str], str]]]) -> List[str]:
        def _collate(x):
            tokens = self.tok_encode(x[0])
            return len(tokens), x[0]
@@ -384,13 +400,9 @@ class HuggingFaceAutoLM(BaseLM):
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")

            @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
            def forward_batch(batch_size):
                test_batch = torch.ones((batch_size, self.max_length), device=self.device).long()
                for _ in range(5):
                    _ = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
                return batch_size
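The decorator above is the auto batch-size pattern from `accelerate`; a self-contained sketch of how it behaves (assumed import path, dummy workload instead of a model call):

```python
import torch
from accelerate import find_executable_batch_size  # assumed import; not shown in this diff

@find_executable_batch_size(starting_batch_size=512)
def forward_batch(batch_size):
    # A CUDA out-of-memory error raised here makes accelerate retry with batch_size // 2.
    test_batch = torch.ones((batch_size, 2048), device="cuda").long()
    _ = test_batch.float().softmax(dim=-1)
    return batch_size

largest_batch_size = forward_batch()  # called with no arguments; the decorator injects batch_size
```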
@@ -409,9 +421,7 @@ class HuggingFaceAutoLM(BaseLM):
            stop_sequences = stop if isinstance(stop, list) else [stop]
            max_generation_length = request_args.get("max_length", None)

            assert isinstance(max_generation_length, int) or max_generation_length is None
            assert isinstance(stop_sequences, list) or stop_sequences is None

            # TODO: Find a better way to handle stop sequences for 0-shot.
@@ -470,9 +480,7 @@ class AutoCausalLM(HuggingFaceAutoLM):
        tokenizer.padding_side = "left"
        return tokenizer

    def _model_call(self, inputs: TokenSequence, labels: Optional[TokenSequence] = None) -> TokenSequence:
        return self.model(inputs)["logits"]

    def _model_generate(
def _model_generate( def _model_generate(
...@@ -484,15 +492,11 @@ class AutoCausalLM(HuggingFaceAutoLM): ...@@ -484,15 +492,11 @@ class AutoCausalLM(HuggingFaceAutoLM):
# Ensure that the context does not encroach into the `space` # Ensure that the context does not encroach into the `space`
# for the generation. # for the generation.
input_ids = inputs["input_ids"][:, self.max_gen_toks - self.max_length :] input_ids = inputs["input_ids"][:, self.max_gen_toks - self.max_length :]
attention_mask = inputs["attention_mask"][ attention_mask = inputs["attention_mask"][:, self.max_gen_toks - self.max_length :]
:, self.max_gen_toks - self.max_length :
]
input_ids = input_ids.to(self.device) input_ids = input_ids.to(self.device)
attention_mask = attention_mask.to(self.device) attention_mask = attention_mask.to(self.device)
stopping_criteria = stop_sequences_criteria( stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, input_ids.shape[1], input_ids.shape[0])
self.tokenizer, stop, input_ids.shape[1], input_ids.shape[0]
)
generations = self.model.generate( generations = self.model.generate(
input_ids=input_ids, input_ids=input_ids,
@@ -527,17 +531,13 @@ class AutoSeq2SeqLM(HuggingFaceAutoLM):
            return self._max_length
        return self._DEFAULT_MAX_LENGTH

    def loglikelihood(self, requests: List[Tuple[str, str]]) -> List[Tuple[float, bool]]:
        new_requests = []
        for chunk in utils.chunks(requests, self.batch_size):
            context, continuation = zip(*chunk)

            # Fill empty contexts with the EOT token.
            context = [f"{self.eot_token}" if len(text) == 0 else text for text in context]
            context_enc = self.tok_encode_batch(context)
            for key in context_enc:
                context_enc[key] = context_enc[key][:, -self.max_length :]
@@ -550,9 +550,7 @@ class AutoSeq2SeqLM(HuggingFaceAutoLM):
            for key in continuation_enc:
                continuation_enc[key] = continuation_enc[key][:, -self.max_length :]

            new_requests.append(((context, continuation), context_enc, continuation_enc))
        return self._loglikelihood_tokens(new_requests)
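For clarity, a hedged sketch of the request format `loglikelihood` consumes; the example strings are made up:

```python
# Each request is a (context, continuation) pair; the method returns one
# (summed log-probability, is_greedy) tuple per request.
requests = [
    ("Question: What is the capital of France?\nAnswer:", " Paris"),
    ("", " An empty context is replaced by the EOT token, as shown above."),
]
# results = lm.loglikelihood(requests)  # -> [(log_prob, is_greedy), ...]
```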
    def loglikelihood_rolling(self, requests: List[Tuple[str, str]]) -> List[float]:
@@ -592,12 +590,8 @@ class AutoSeq2SeqLM(HuggingFaceAutoLM):
            )

            # TODO: Extract out this call so it only gets called once and also
            # somehow figure out partial caching for.
            rolling_token_windows_request = [((contexts, conts), contexts_enc, conts_enc)]
            string_nll = self._loglikelihood_tokens(rolling_token_windows_request, disable_tqdm=True)
            string_nll = [x[0] for x in string_nll]  # discard is_greedy
            string_nll = sum(string_nll)
            loglikelihoods.append(string_nll)
@@ -609,9 +603,7 @@ class AutoSeq2SeqLM(HuggingFaceAutoLM):
        disable_tqdm: Optional[bool] = False,
    ) -> List[Tuple[float, bool]]:
        results = []
        for chunk in tqdm(requests, total=math.ceil(len(requests)), disable=disable_tqdm):
            cache_keys, inputs_tokens, targets_tokens = chunk
            inputs_tokens = inputs_tokens.to(self.device)
            targets_tokens = targets_tokens.to(self.device)
@@ -630,18 +622,14 @@ class AutoSeq2SeqLM(HuggingFaceAutoLM):
                target_tokens = target_tokens[:length]
                greedy_tokens = log_softmax.argmax(dim=-1)
                max_equal = (greedy_tokens == target_tokens).all()
                target_logits = torch.gather(log_softmax, 1, target_tokens.unsqueeze(-1)).squeeze(-1)
                answer = (float(target_logits.sum()), bool(max_equal))
                results.append(answer)
                if cache_key is not None:
                    self.cache_hook.add_partial("loglikelihood", cache_key, answer)
        return results

    def _model_call(self, inputs: TokenSequence, labels: Optional[TokenSequence] = None) -> TokenSequence:
        return self.model(**inputs, labels=labels["input_ids"])
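The seq2seq `_model_call` above forwards the tokenized encoder inputs and the labels' `input_ids` to the underlying model; a standalone sketch of that call pattern with a small public checkpoint (an assumption, not taken from this diff):

```python
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("t5-small")
model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer(["translate English to German: Hello"], return_tensors="pt", padding=True)
labels = tokenizer(["Hallo"], return_tensors="pt", padding=True)

outputs = model(**inputs, labels=labels["input_ids"])
print(outputs.logits.shape)  # (batch_size, target_length, vocab_size)
```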
    def _model_generate(
@@ -663,9 +651,7 @@ class AutoSeq2SeqLM(HuggingFaceAutoLM):
        # initial_decoder_input_length = len(one_tok_gen) - 1
        # Assume that there will always only be one token in the decoder inputs, assumption holds for existing HF models
        stopping_criteria = stop_sequences_criteria(self.tokenizer, stop, 1, input_ids.shape[0])

        generations = self.model.generate(
            input_ids=input_ids,
@@ -696,9 +682,7 @@ class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][:, -self.sequence_id_len :]
        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
@@ -717,9 +701,7 @@ def stop_sequences_criteria(
    return transformers.StoppingCriteriaList(
        [
            *[
                MultiTokenEOSCriteria(sequence, tokenizer, initial_decoder_input_length, batch_size)
                for sequence in stop_sequences
            ],
        ]
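A hedged usage sketch of `stop_sequences_criteria` outside the harness; the import path and model choice are assumptions:

```python
import transformers
from lm_eval.models.huggingface import stop_sequences_criteria  # assumed module path

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Q: What is 2 + 2?\nA:", return_tensors="pt")
# For a causal LM, the "decoder input" is the whole prompt, so its length is input_ids.shape[1].
stopping_criteria = stop_sequences_criteria(
    tokenizer, ["\n"], inputs["input_ids"].shape[1], inputs["input_ids"].shape[0]
)
output = model.generate(**inputs, max_new_tokens=20, stopping_criteria=stopping_criteria)
print(tokenizer.decode(output[0]))
```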
...
@@ -44,5 +44,6 @@ setuptools.setup(
        "dev": ["black", "flake8", "pre-commit", "pytest", "pytest-cov"],
        "multilingual": ["nagisa>=0.2.7", "jieba>=0.42.1", "evaluate>=0.4.0"],
        "sentencepiece": ["sentencepiece>=0.1.98", "protobuf>=4.22.1"],
        "auto-gptq": ["auto-gptq[triton] @ git+https://github.com/PanQiWei/AutoGPTQ"],
    },
)