Commit 3a6adddb authored by haileyschoelkopf, committed by lintangsutawika

discard ctx toks for causal greedy_until

parent b48f5205
@@ -17,7 +17,7 @@ from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 from accelerate import Accelerator


-@register_model("hf-auto")
+@register_model("hf-auto", "hf", "huggingface")
 class HFLM(LM):
     """
     An abstracted Huggingface model class. Enables usage with both models of
@@ -27,6 +27,7 @@ class HFLM(LM):
     """

     AUTO_MODEL_CLASS = None
+    _DEFAULT_MAX_LENGTH = 2048

     def __init__(
         self,
@@ -34,6 +35,7 @@ class HFLM(LM):
         pretrained="gpt2",
         revision="main",
         low_cpu_mem_usage=None,
+        max_length=None,
         subfolder=None,
         tokenizer=None,
         batch_size=1,
@@ -98,6 +100,8 @@ class HFLM(LM):
         self.vocab_size = self.tokenizer.vocab_size

+        self._max_length = max_length
+
         # multithreading and batching
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
@@ -105,6 +109,7 @@ class HFLM(LM):
         if gpus > 1:
             accelerator = Accelerator()
             if gpus > accelerator.num_processes:
+                # TODO: make sure there's still never an edge case where we unintentionally default to CPU
                 eval_logger.warning(
                     "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                     "If you would like to use data parallelism, please launch the script "
@@ -152,11 +157,17 @@ class HFLM(LM):

     @property
     def max_length(self):
-        try:
-            return self.model.config.n_ctx
-        except AttributeError:
-            # gptneoconfig doesn't have n_ctx apparently
-            return self.model.config.max_position_embeddings
+        if self._max_length:  # if max length manually set, return it
+            return self._max_length
+        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
+        for attr in seqlen_config_attrs:
+            if hasattr(self.model.config, attr):
+                return getattr(self.model.config, attr)
+        if hasattr(self.tokenizer, "model_max_length"):
+            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
+                return self._DEFAULT_MAX_LENGTH
+            return self.tokenizer.model_max_length
+        return self._DEFAULT_MAX_LENGTH

     @property
     def max_gen_toks(self):
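For illustration only, the resolution order implemented by the new max_length property can be sketched as a standalone helper. The function name resolve_max_length and its parameters are hypothetical stand-ins, not part of this commit; the attribute names and the HuggingFace "no limit" sentinel value are taken from the diff above.

# Hypothetical sketch of the new max_length resolution order (not part of this commit).
def resolve_max_length(manual_max_length, model_config, tokenizer, default=2048):
    # 1. A manually passed max_length always wins.
    if manual_max_length:
        return manual_max_length
    # 2. Otherwise, check common sequence-length attributes on the model config.
    for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
        if hasattr(model_config, attr):
            return getattr(model_config, attr)
    # 3. Fall back to the tokenizer, ignoring HF's "no limit" sentinel value.
    if hasattr(tokenizer, "model_max_length"):
        if tokenizer.model_max_length == 1000000000000000019884624838656:
            return default
        return tokenizer.model_max_length
    # 4. Last resort: the class-level default (_DEFAULT_MAX_LENGTH = 2048).
    return default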
@@ -485,7 +496,6 @@ class HFLM(LM):
            until = None
            if isinstance(gen_kwargs, dict):
                gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                print(gen_kwargs)
                if "until" in gen_kwargs.keys():
                    until = gen_kwargs.pop("until")
                    if isinstance(until, str):
@@ -517,8 +527,9 @@ class HFLM(LM):
            max_ctx_len = self.max_length

            context_enc = torch.tensor(
-                [self.tok_encode(context, left_truncate_len=max_ctx_len)]
-            ).to(self.device)
+                [self.tok_encode(context, left_truncate_len=max_ctx_len)],
+                device=self.device,
+            )

            cont = self._model_generate(
                context=context_enc,
@@ -526,7 +537,13 @@ class HFLM(LM):
                stop=primary_until,
                **gen_kwargs,
            )
-            s = self.tok_decode(cont[0].tolist())
+            cont_toks_list = cont[0].tolist()
+
+            # discard context toks if using causal decoder-only LM
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                cont_toks_list = cont_toks_list[context_enc.shape[1] :]
+
+            s = self.tok_decode(cont_toks_list)

            # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
            for term in until:
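As background for the main change in this commit: generate() on a decoder-only causal model returns the prompt tokens followed by the newly generated tokens, whereas a seq2seq model returns only new tokens, so decoding the raw output for a causal LM would echo the context back into the returned string. A minimal, self-contained sketch of the slicing, with plain Python lists standing in for real token tensors and all token id values hypothetical:

# Illustration of why context tokens are discarded for causal LMs (not part of this commit).
# A decoder-only model's generate() output echoes the prompt before the new tokens.
context_toks = [50256, 1212, 318, 257]           # stand-in prompt token ids
generated = context_toks + [1332, 13, 50256]     # generate() output: prompt + continuation

# Slice off the prompt (context_enc.shape[1] tokens in the real code) before decoding,
# so only the newly generated continuation is returned to the caller.
cont_toks = generated[len(context_toks):]
assert cont_toks == [1332, 13, 50256]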