Unverified Commit e0eda4d3 authored by Hailey Schoelkopf, committed by GitHub

Refactor `hf` modeling code (#1096)

* modularize HFLM code

* pass through extra kwargs to AutoModel.from_pretrained call

* remove explicit model_kwargs

* rename gptq -> autogptq

* fix tokenizer pad token errors

* ensure model always respects device_map and autogptq's selected devices

* add a _get_config helper fn
parent 5133c9c4
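
For orientation before the diff, here is a short usage sketch of what the refactored constructor accepts after this change. The `lm_eval.models.huggingface` import path is where `HFLM` lives in the repo, but these exact call patterns are illustrative assumptions, not code from this commit.

import transformers

from lm_eval.models.huggingface import HFLM

# Let HFLM load the model itself: extra kwargs such as `low_cpu_mem_usage`
# (no longer an explicit argument after this refactor) are forwarded to
# AutoModel*.from_pretrained, and `backend` can force decoder-only ("causal")
# or encoder-decoder ("seq2seq") handling.
lm = HFLM(
    pretrained="gpt2",
    backend="causal",
    low_cpu_mem_usage=True,
)

# Or hand HFLM an already-initialized model (and optionally a tokenizer object);
# per the warnings in the diff, `parallelize=True` and `accelerate launch`
# should not be used in this mode.
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
lm_from_object = HFLM(pretrained=model, tokenizer=tokenizer)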
@@ -23,7 +23,7 @@ from lm_eval.api.registry import register_model
 from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 from accelerate import Accelerator, find_executable_batch_size, DistributedType
-from typing import List, Optional, Union, Tuple
+from typing import List, Optional, Union, Tuple, Literal
 eval_logger = utils.eval_logger
@@ -67,195 +67,182 @@ class HFLM(LM):
     def __init__(
         self,
-        pretrained: Optional[str] = "gpt2",
+        pretrained: Optional[Union[str, transformers.PreTrainedModel]] = "gpt2",
+        backend: Optional[
+            Literal["default", "causal", "seq2seq"]
+        ] = "default",  # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
         revision: Optional[str] = "main",
         subfolder: Optional[str] = None,
-        tokenizer: Optional[str] = None,
+        tokenizer: Optional[
+            Union[
+                str,
+                transformers.PreTrainedTokenizer,
+                transformers.PreTrainedTokenizerFast,
+            ]
+        ] = None,
         truncation: Optional[bool] = False,
         max_length: Optional[int] = None,
         device: Optional[str] = "cuda",
         dtype: Optional[Union[str, torch.dtype]] = "auto",
         batch_size: Optional[Union[int, str]] = 1,
         max_batch_size: Optional[int] = 64,
-        low_cpu_mem_usage: Optional[bool] = True,
         trust_remote_code: Optional[bool] = False,
         use_fast_tokenizer: Optional[bool] = True,
-        cache_dir: Optional[Union[str, os.PathLike]] = None,
         # arguments used for splitting a model across GPUs naively.
         # only used if `parallelize=True`.
         parallelize: Optional[bool] = False,
         device_map_option: Optional[str] = "auto",
         max_memory_per_gpu: Optional[Union[int, str]] = None,
         max_cpu_memory: Optional[Union[int, str]] = None,
-        offload_folder: Optional[str] = "./offload",
+        offload_folder: Optional[Union[str, os.PathLike]] = "./offload",
         # PEFT and quantization options
         peft: Optional[str] = None,
-        load_in_8bit: Optional[bool] = False,
-        load_in_4bit: Optional[bool] = False,
-        bnb_4bit_quant_type: Optional[str] = None,
-        bnb_4bit_compute_dtype: Optional[Union[str, torch.dtype]] = None,
-        gptq: Optional[Union[bool, str]] = False,
-        gptq_use_triton: Optional[bool] = False,
+        autogptq: Optional[Union[bool, str]] = False,
+        **kwargs,
     ) -> None:
         super().__init__()
-        assert isinstance(device, str)
-        assert isinstance(pretrained, str)
-        assert isinstance(batch_size, (int, str))
-        gpus = torch.cuda.device_count()
-        accelerator = Accelerator()
-        if not (parallelize or accelerator.num_processes > 1):
-            # use user-passed device
-            device_list = set(
-                ["cuda", "cpu"]
-                + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
-                + ["mps", "mps:0"]
-            )
-            if device:
-                if device not in device_list:
-                    device = int(device)
-                self._device = torch.device(device)
-                eval_logger.info(f"Using device '{device}'")
-                if device in ("mps", "mps:0") and version.parse(
-                    torch.__version__
-                ) < version.parse("2.1"):
-                    raise RuntimeError(
-                        f"mps requires torch >= 2.1. You have {torch.__version__}"
-                    )
-            else:
-                eval_logger.info("Device not specified")
-                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
-                self._device = (
-                    torch.device("cuda")
-                    if torch.cuda.is_available()
-                    else torch.device("cpu")
-                )
-        else:
-            if device != "cuda":
-                eval_logger.info(
-                    f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
-                )
-            # TODO: include in warning that `load_in_8bit` etc. affect this too
-            self._device = device
-        model_kwargs = {}
-        if parallelize:
-            model_kwargs = _get_accelerate_args(
-                device_map_option,
-                max_memory_per_gpu,
-                max_cpu_memory,
-                offload_folder,
-            )
-        # TODO: update this to be less of a hack once subfolder is fixed in HF
-        revision = revision + ("/" + subfolder if subfolder is not None else "")
-        self._config = transformers.AutoConfig.from_pretrained(
-            pretrained,
-            revision=revision,
-            trust_remote_code=trust_remote_code,
-        )
-        if (
-            getattr(self._config, "model_type")
-            in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
-        ):
-            # first check if model type is listed under seq2seq models, since some
-            # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
-            # these special cases should be treated as seq2seq models.
-            self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
-        elif getattr(self._config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
-            self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
-        else:
-            if not trust_remote_code:
-                eval_logger.warning(
-                    "HF model type is neither marked as CausalLM or Seq2SeqLM. \
-                    This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
-                )
-            # if model type is neither in HF transformers causal or seq2seq model registries
-            # then we default to AutoModelForCausalLM
-            self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
-        assert self.AUTO_MODEL_CLASS in [
-            transformers.AutoModelForCausalLM,
-            transformers.AutoModelForSeq2SeqLM,
-        ]
-        if not gptq:
-            if load_in_4bit:
-                assert (
-                    transformers.__version__ >= "4.30.0"
-                ), "load_in_4bit requires transformers >= 4.30.0"
-            if transformers.__version__ >= "4.30.0":
-                model_kwargs["load_in_4bit"] = load_in_4bit
-                if load_in_4bit:
-                    if bnb_4bit_quant_type:
-                        model_kwargs["bnb_4bit_quant_type"] = bnb_4bit_quant_type
-                    if bnb_4bit_compute_dtype:
-                        model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype(
-                            bnb_4bit_compute_dtype
-                        )
-            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
-                pretrained,
-                revision=revision,
-                torch_dtype=utils.get_dtype(dtype),
-                low_cpu_mem_usage=low_cpu_mem_usage,
-                trust_remote_code=trust_remote_code,
-                load_in_8bit=load_in_8bit,
-                **model_kwargs,
-            )
-        else:
-            try:
-                from auto_gptq import AutoGPTQForCausalLM
-            except ModuleNotFoundError:
-                raise Exception(
-                    "Tried to load auto_gptq, but auto-gptq is not installed ",
-                    "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
-                )
-            self._model = AutoGPTQForCausalLM.from_quantized(
-                pretrained,
-                model_basename=None if gptq is True else Path(gptq).stem,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-                trust_remote_code=trust_remote_code,
-                use_safetensors=True if gptq is True else gptq.endswith(".safetensors"),
-                use_triton=gptq_use_triton,
-                warmup_triton=gptq_use_triton,
-                **model_kwargs,
-            )
-        if peft:
-            if load_in_4bit:
-                assert PEFT_VERSION >= "0.4.0", "load_in_4bit requires peft >= 0.4.0"
-            self._model = PeftModel.from_pretrained(
-                self._model, peft, revision=revision
-            )
-        # forever after, access self._model through self.model property
+        # optionally: take in an already-initialized transformers.PreTrainedModel
+        if not isinstance(pretrained, str):
+            eval_logger.warning(
+                "`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way."
+            )
+            assert (
+                not parallelize
+            ), "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
+            self._model = pretrained
+            self._device = self._model.device
+            self._config = self._model.config
+            if tokenizer:
+                assert isinstance(
+                    tokenizer, transformers.PreTrainedTokenizer
+                ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
+                self.tokenizer = tokenizer
+            else:
+                # Get tokenizer
+                model_name = self._model.name_or_path
+                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+                    model_name,
+                    revision=revision,
+                    trust_remote_code=trust_remote_code,
+                    use_fast=use_fast_tokenizer,
+                )
+        else:
+            assert isinstance(device, str)
+            assert isinstance(pretrained, str)
+            assert isinstance(batch_size, (int, str))
+            gpus = torch.cuda.device_count()
+            accelerator = Accelerator()
+            if not (parallelize or accelerator.num_processes > 1):
+                # use user-passed device
+                device_list = set(
+                    ["cuda", "cpu"]
+                    + [f"cuda:{i}" for i in range(torch.cuda.device_count())]
+                    + ["mps", "mps:0"]
+                )
+                if device:
+                    if device not in device_list:
+                        device = int(device)
+                    self._device = torch.device(device)
+                    eval_logger.info(f"Using device '{device}'")
+                    if device in ("mps", "mps:0") and version.parse(
+                        torch.__version__
+                    ) < version.parse("2.1"):
+                        raise RuntimeError(
+                            f"mps requires torch >= 2.1. You have {torch.__version__}"
+                        )
+                else:
+                    eval_logger.info("Device not specified")
+                    eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
+                    self._device = (
+                        torch.device("cuda")
+                        if torch.cuda.is_available()
+                        else torch.device("cpu")
+                    )
+            else:
+                if device != "cuda":
+                    eval_logger.info(
+                        f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
+                    )
+                # TODO: include in warning that `load_in_8bit` etc. affect this too
+                self._device = device
+            # TODO: update this to be less of a hack once subfolder is fixed in HF
+            revision = revision + ("/" + subfolder if subfolder is not None else "")
+            self._get_config(
+                pretrained,
+                revision=revision,
+                trust_remote_code=trust_remote_code,
+            )
+        # determine which of 'causal' and 'seq2seq' backends to use
+        self._get_backend(
+            config=self.config, backend=backend, trust_remote_code=trust_remote_code
+        )
+        # if we passed `pretrained` as a string, initialize our model now
+        if isinstance(pretrained, str):
+            self._create_model(
+                pretrained=pretrained,
+                revision=revision,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                parallelize=parallelize,
+                device_map_option=device_map_option,
+                max_memory_per_gpu=max_memory_per_gpu,
+                max_cpu_memory=max_cpu_memory,
+                offload_folder=offload_folder,
+                peft=peft,
+                autogptq=autogptq,
+                **kwargs,
+            )
+        # access self._model through self.model property outside this method
         self.model.eval()
         self.model.tie_weights()
-        if gpus <= 1 and not parallelize:
-            # place model onto device, if not using HF Accelerate in any form
-            try:
-                self.model.to(self.device)
-            except ValueError:
-                eval_logger.info(
-                    "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
-                )
-        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
-            pretrained if tokenizer is None else tokenizer,
-            revision=revision,
-            trust_remote_code=trust_remote_code,
-            use_fast=use_fast_tokenizer,
-        )
+        if gpus >= 1 and isinstance(pretrained, str):
+            if not (parallelize or autogptq or ("device_map" in kwargs)):
+                # place model onto device requested manually,
+                # if not using HF Accelerate or device_map
+                # or any other option that preloads model onto device
+                try:
+                    self.model.to(self.device)
+                except ValueError:
+                    eval_logger.info(
+                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
+                    )
+        self._create_tokenizer(
+            pretrained,
+            tokenizer,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+            use_fast_tokenizer=use_fast_tokenizer,
+        )
         self.truncation = truncation
         self.vocab_size = self.tokenizer.vocab_size
-        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        # select (or create) a pad token to use
+        if self.tokenizer.pad_token:
+            pass
+        elif self.tokenizer.unk_token:
+            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
+        elif self.tokenizer.eos_token:
+            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        else:
+            if "Qwen" in pretrained:
+                # Qwen's trust_remote_code tokenizer does not allow for adding special tokens
+                self.tokenizer.pad_token = "<|endoftext|>"
+            else:
+                self.tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
         self._max_length = max_length
@@ -270,57 +257,55 @@ class HFLM(LM):
         else:
             self.batch_size_per_gpu = int(batch_size)
-        # multigpu data-parallel support when launched with accelerate
-        if gpus > 1:
-            if parallelize:
-                if accelerator.num_processes > 1:
-                    raise RuntimeError(
-                        "Attempted to use both a HF Accelerate `device_map` and to launch via `accelerate launch`. If this is the case, please either remove `parallelize=True` from --model_args or launch outside of the Accelerate launcher."
-                    )
-                else:
-                    pass
-            elif gpus > accelerator.num_processes:
-                # TODO: make sure there's still never an edge case where we unintentionally default to CPU
-                eval_logger.warning(
-                    "WARNING: The number of total system GPUs does not match the number of spawned processes. "
-                    "If you would like to use data parallelism, please launch the script "
-                    "with 'accelerate launch *script*'. "
-                    f"Current run will proceed with {accelerator.num_processes} devices."
-                )
-                self._rank = accelerator.local_process_index
-                self._world_size = accelerator.num_processes
-                # manually set model to use gpu, for case where many GPUs available but
-                # only seek to use one
-                self._device = (
-                    torch.device(f"cuda:{accelerator.local_process_index}")
-                    if torch.cuda.is_available()
-                    else torch.device("cpu")
-                )
-                try:
-                    self.model.to(self.device)
-                except ValueError:
-                    eval_logger.info(
-                        "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes`. If the desired GPU is being used, this message is safe to ignore."
-                    )
-            else:
-                assert accelerator.distributed_type in [
-                    DistributedType.FSDP,
-                    DistributedType.MULTI_GPU,
-                ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
-                if accelerator.distributed_type == DistributedType.FSDP:
-                    self._model = accelerator.prepare(self.model)
-                else:
-                    self._model = accelerator.prepare_model(
-                        self.model, evaluation_mode=True
-                    )
-                self._device = torch.device(f"cuda:{accelerator.local_process_index}")
-                self.accelerator = accelerator
-                if self.accelerator.is_local_main_process:
-                    eval_logger.info(f"Using {gpus} devices with data parallelism")
-                self._rank = self.accelerator.local_process_index
-                self._world_size = self.accelerator.num_processes
+        if isinstance(pretrained, str):
+            # multigpu data-parallel support when launched with accelerate
+            if gpus > 1:
+                if parallelize:
+                    if accelerator.num_processes > 1:
+                        raise RuntimeError(
+                            "Attempted to use both a HF Accelerate `device_map` and to launch via `accelerate launch`. If this is the case, please either remove `parallelize=True` from --model_args or launch outside of the Accelerate launcher."
+                        )
+                    else:
+                        pass
+                elif accelerator.num_processes == 1:
+                    # if we aren't launching via accelerate, ditch
+                    self._rank = 0
+                    self._world_size = 1
+                else:
+                    if gpus > accelerator.num_processes:
+                        eval_logger.warning(
+                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
+                            "If you would like to use data parallelism, please launch the script "
+                            "with 'accelerate launch *script*'. "
+                            f"Current run will proceed with {accelerator.num_processes} devices."
+                        )
+                    assert accelerator.distributed_type in [
+                        DistributedType.FSDP,
+                        DistributedType.MULTI_GPU,
+                    ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
+                    if accelerator.distributed_type == DistributedType.FSDP:
+                        self._model = accelerator.prepare(self.model)
+                    else:
+                        self._model = accelerator.prepare_model(
+                            self.model, evaluation_mode=True
+                        )
+                    self._device = torch.device(
+                        f"cuda:{accelerator.local_process_index}"
+                    )
+                    self.accelerator = accelerator
+                    if self.accelerator.is_local_main_process:
+                        eval_logger.info(f"Using {gpus} devices with data parallelism")
+                    self._rank = self.accelerator.local_process_index
+                    self._world_size = self.accelerator.num_processes
+        else:
+            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
+            eval_logger.warning(
+                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
+            )
+            self._rank = 0
+            self._world_size = 1
     @property
     def config(self):
@@ -374,6 +359,208 @@ class HFLM(LM):
     def world_size(self):
         return self._world_size
 
+    def _get_backend(
+        self,
+        config: transformers.AutoConfig,
+        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
+        trust_remote_code: Optional[bool] = False,
+    ) -> None:
+        """
+        Helper method during initialization.
+        Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder))
+        model type to be used.
+        """
+        assert backend in ["default", "causal", "seq2seq"]
+        if backend != "default":
+            # if we've settled on non-default backend, use that manually
+            if backend == "causal":
+                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
+            elif backend == "seq2seq":
+                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
+            eval_logger.info(
+                f"Overrode HF model backend type, and using type '{backend}'"
+            )
+        else:
+            # determine and use the default HF backend for this model, based on its config + metadata.
+            if (
+                getattr(config, "model_type")
+                in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
+            ):
+                # first check if model type is listed under seq2seq models, since some
+                # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
+                # these special cases should be treated as seq2seq models.
+                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
+            elif (
+                getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+            ):
+                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
+            else:
+                if not trust_remote_code:
+                    eval_logger.warning(
+                        "HF model type is neither marked as CausalLM or Seq2SeqLM. \
+                    This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
+                    )
+                # if model type is neither in HF transformers causal or seq2seq model registries
+                # then we default to AutoModelForCausalLM
+                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
+        assert self.AUTO_MODEL_CLASS in [
+            transformers.AutoModelForCausalLM,
+            transformers.AutoModelForSeq2SeqLM,
+        ]
+        return None
+
+    def _get_config(
+        self,
+        pretrained: str,
+        revision: str = "main",
+        trust_remote_code: bool = False,
+    ) -> None:
+        self._config = transformers.AutoConfig.from_pretrained(
+            pretrained,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+        )
+
+    def _create_model(
+        self,
+        pretrained: str,
+        revision: Optional[str] = "main",
+        dtype: Optional[Union[str, torch.dtype]] = "auto",
+        trust_remote_code: Optional[bool] = False,
+        # arguments used for splitting a model across GPUs naively.
+        # only used if `parallelize=True`.
+        # (accelerate naive PP (device_map) options)
+        parallelize: Optional[bool] = False,
+        device_map_option: Optional[str] = "auto",
+        max_memory_per_gpu: Optional[Union[int, str]] = None,
+        max_cpu_memory: Optional[Union[int, str]] = None,
+        offload_folder: Optional[str] = "./offload",
+        # PEFT and quantization options
+        peft: Optional[str] = None,
+        autogptq: Optional[Union[bool, str]] = False,
+        **kwargs,
+    ) -> None:
+        """
+        Initializes an HF or HF-compatible PreTrainedModel from scratch
+        inside HFLM, using the kwargs passed into self.__init__().
+        Also handles functionality such as AutoGPTQ usage and PEFT wrapping.
+        For future similar extensions to AutoGPTQ that are not core to HF's ecosystem,
+        (such as PyTorch models that are nearly, but not quite, fully mirroring
+        HF's public interface relied on in this HFLM class)
+        please consider subclassing HFLM and overriding this and other methods as needed.
+        """
+        model_kwargs = kwargs if kwargs else {}
+        if parallelize:
+            model_kwargs.update(
+                _get_accelerate_args(
+                    device_map_option,
+                    max_memory_per_gpu,
+                    max_cpu_memory,
+                    offload_folder,
+                )
+            )
+        if not autogptq:
+            if model_kwargs.get("load_in_4bit", None):
+                assert (
+                    transformers.__version__ >= "4.30.0"
+                ), "load_in_4bit requires transformers >= 4.30.0"
+            if transformers.__version__ >= "4.30.0":
+                if model_kwargs.get("load_in_4bit", None):
+                    if model_kwargs.get("bnb_4bit_compute_dtype", None):
+                        model_kwargs["bnb_4bit_compute_dtype"] = utils.get_dtype(
+                            model_kwargs["bnb_4bit_compute_dtype"]
+                        )
+            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
+                pretrained,
+                revision=revision,
+                torch_dtype=utils.get_dtype(dtype),
+                trust_remote_code=trust_remote_code,
+                **model_kwargs,
+            )
+        else:
+            try:
+                from auto_gptq import AutoGPTQForCausalLM
+            except ModuleNotFoundError:
+                raise Exception(
+                    "Tried to load auto_gptq, but auto-gptq is not installed ",
+                    "please install auto-gptq via pip install lm-eval[gptq] or pip install -e .[gptq]",
+                )
+            self._model = AutoGPTQForCausalLM.from_quantized(
+                pretrained,
+                trust_remote_code=trust_remote_code,
+                model_basename=None if autogptq is True else Path(autogptq).stem,
+                use_safetensors=True
+                if autogptq is True
+                else autogptq.endswith(".safetensors"),
+                **model_kwargs,
+            )
+        if peft:
+            if model_kwargs.get("load_in_4bit", None):
+                assert PEFT_VERSION >= "0.4.0", "load_in_4bit requires peft >= 0.4.0"
+            self._model = PeftModel.from_pretrained(
+                self._model, peft, revision=revision
+            )
+        return None
+
+    def _create_tokenizer(
+        self,
+        pretrained: Union[str, transformers.PreTrainedModel],
+        tokenizer: Optional[
+            Union[
+                str,
+                transformers.PreTrainedTokenizer,
+                transformers.PreTrainedTokenizerFast,
+            ]
+        ],
+        revision: Optional[str] = "main",
+        trust_remote_code: Optional[bool] = False,
+        use_fast_tokenizer: Optional[bool] = True,
+    ) -> None:
+        """
+        Helper method during initialization.
+        Create a tokenizer object corresponding to the correct
+        tokenizer for value of `pretrained`, or use the pre-initialized tokenizer passed.
+        """
+        if tokenizer:
+            if isinstance(tokenizer, str):
+                self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+                    tokenizer,
+                    revision=revision,
+                    trust_remote_code=trust_remote_code,
+                    use_fast=use_fast_tokenizer,
+                )
+            else:
+                assert isinstance(
+                    tokenizer, transformers.PreTrainedTokenizer
+                ) or isinstance(tokenizer, transformers.PreTrainedTokenizerFast)
+                self.tokenizer = tokenizer
+        else:
+            # Get tokenizer based on 'pretrained'
+            if isinstance(pretrained, str):
+                model_name = pretrained
+            else:
+                # get the HF hub name via accessor on model
+                model_name = self.model.name_or_path
+            self.tokenizer = transformers.AutoTokenizer.from_pretrained(
+                model_name,
+                revision=revision,
+                trust_remote_code=trust_remote_code,
+                use_fast=use_fast_tokenizer,
+            )
+        return None
+
     def _detect_batch_size(self, requests=None, pos: int = 0):
         if requests:
             _, context_enc, continuation_enc = requests[pos]
@@ -509,7 +696,7 @@ class HFLM(LM):
     def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # we require users to pass do_sample=True explicitly
         # for non-greedy gen. This should be reevaluated when considering beam search.
-        if "do_sample" not in generation_kwargs.keys():
+        if "do_sample" not in generation_kwargs:
             generation_kwargs["do_sample"] = False
         # build stopping criteria
         stopping_criteria = stop_sequences_criteria(
@@ -519,7 +706,7 @@ class HFLM(LM):
             input_ids=context,
             max_length=max_length,
             stopping_criteria=stopping_criteria,
-            pad_token_id=self.eot_token_id,
+            pad_token_id=self.tokenizer.pad_token_id,
             use_cache=True,
             **generation_kwargs,
         )
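
The `_create_model` docstring above suggests subclassing `HFLM` for model loaders that are not core to the HF ecosystem. A minimal hypothetical sketch of that pattern (the subclass name and the plain `from_pretrained` stand-in are illustrative assumptions, not part of this commit):

from typing import Optional

import transformers

from lm_eval.models.huggingface import HFLM


class MyCustomLM(HFLM):
    """Hypothetical wrapper that swaps in a custom way of building the model."""

    def _create_model(
        self, pretrained: str, revision: Optional[str] = "main", **kwargs
    ) -> None:
        # A real subclass would call its own loader here (e.g. a non-HF
        # quantization backend); the assumed contract is simply that
        # self._model ends up holding an HF-compatible model object.
        self._model = transformers.AutoModelForCausalLM.from_pretrained(
            pretrained, revision=revision
        )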