Commit 3a6adddb authored by haileyschoelkopf, committed by lintangsutawika

discard ctx toks for causal greedy_until

parent b48f5205
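
For context on the main fix in this commit: with a decoder-only causal model, HF `generate()` returns the prompt tokens followed by the newly generated tokens, so decoding the raw output would echo the context back into the response. A minimal sketch of that behaviour and the slicing the diff below adds (model name and prompt are purely illustrative):

```python
# Sketch: causal (decoder-only) generate() output includes the prompt tokens,
# so the continuation has to be sliced off after the context length.
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # illustrative model
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

context = "The quick brown fox"
context_enc = tokenizer(context, return_tensors="pt").input_ids

out = model.generate(context_enc, max_new_tokens=5, do_sample=False)

# out[0] is context tokens + generated tokens; keep only the continuation
cont_toks = out[0].tolist()[context_enc.shape[1]:]
print(tokenizer.decode(cont_toks))
```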
@@ -17,7 +17,7 @@ from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 from accelerate import Accelerator
-@register_model("hf-auto")
+@register_model("hf-auto", "hf", "huggingface")
 class HFLM(LM):
     """
     An abstracted Huggingface model class. Enables usage with both models of
@@ -27,6 +27,7 @@ class HFLM(LM):
     """
     AUTO_MODEL_CLASS = None
+    _DEFAULT_MAX_LENGTH = 2048
     def __init__(
         self,
@@ -34,6 +35,7 @@ class HFLM(LM):
         pretrained="gpt2",
         revision="main",
         low_cpu_mem_usage=None,
+        max_length=None,
         subfolder=None,
         tokenizer=None,
         batch_size=1,
@@ -98,6 +100,8 @@ class HFLM(LM):
         self.vocab_size = self.tokenizer.vocab_size
+        self._max_length = max_length
         # multithreading and batching
         self.batch_size_per_gpu = batch_size  # todo: adaptive batch size
@@ -105,6 +109,7 @@
         if gpus > 1:
             accelerator = Accelerator()
             if gpus > accelerator.num_processes:
+                # TODO: make sure there's still never an edge case where we unintentionally default to CPU
                 eval_logger.warning(
                     "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                     "If you would like to use data parallelism, please launch the script "
@@ -152,11 +157,17 @@
     @property
     def max_length(self):
-        try:
-            return self.model.config.n_ctx
-        except AttributeError:
-            # gptneoconfig doesn't have n_ctx apparently
-            return self.model.config.max_position_embeddings
+        if self._max_length:  # if max length manually set, return it
+            return self._max_length
+        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
+        for attr in seqlen_config_attrs:
+            if hasattr(self.model.config, attr):
+                return getattr(self.model.config, attr)
+        if hasattr(self.tokenizer, "model_max_length"):
+            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
+                return self._DEFAULT_MAX_LENGTH
+            return self.tokenizer.model_max_length
+        return self._DEFAULT_MAX_LENGTH
     @property
     def max_gen_toks(self):
@@ -485,7 +496,6 @@
             until = None
             if isinstance(gen_kwargs, dict):
                 gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
-                print(gen_kwargs)
                 if "until" in gen_kwargs.keys():
                     until = gen_kwargs.pop("until")
                     if isinstance(until, str):
@@ -517,8 +527,9 @@
                 max_ctx_len = self.max_length
             context_enc = torch.tensor(
-                [self.tok_encode(context, left_truncate_len=max_ctx_len)]
-            ).to(self.device)
+                [self.tok_encode(context, left_truncate_len=max_ctx_len)],
+                device=self.device,
+            )
             cont = self._model_generate(
                 context=context_enc,
@@ -526,7 +537,13 @@
                 stop=primary_until,
                 **gen_kwargs,
             )
-            s = self.tok_decode(cont[0].tolist())
+            cont_toks_list = cont[0].tolist()
+            # discard context toks if using causal decoder-only LM
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                cont_toks_list = cont_toks_list[context_enc.shape[1] :]
+            s = self.tok_decode(cont_toks_list)
             # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
             for term in until:
......
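
One note on the new `max_length` property above: the hard-coded constant it compares against is the sentinel value transformers assigns to `tokenizer.model_max_length` when a tokenizer has no configured limit (`VERY_LARGE_INTEGER`, i.e. `int(1e30)` after float rounding), so in that case the property falls back to `_DEFAULT_MAX_LENGTH`. A quick check in plain Python:

```python
# The magic number in max_length is just int(1e30) once float rounding is applied,
# which is the "no tokenizer limit" sentinel used by transformers.
print(int(1e30))  # 1000000000000000019884624838656
```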