"vscode:/vscode.git/clone" did not exist on "aafb5eb18781f1ac9e06a17c3e53d968dd53dcc0"
Unverified commit ab2c46c3, authored by Baber Abbasi, committed by GitHub

HF: switch conditional checks to `self.backend` from `AUTO_MODEL_CLASS` (#2353)



* switch conditional checks to `self.backend`

* nit

* nit

* commit feedback

* fix test; update precommit hooks

* add escape hatch for custom self.AUTO_MODEL_CLASS

* add escape hatch for custom self.AUTO_MODEL_CLASS

* fix

* move assertion

* add logging messages

* update AUTO_MODEL_CLASS behavior in _get_backend

---------
Co-authored-by: haileyschoelkopf <hailey@eleuther.ai>
parent 4cec66e4
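
The net effect of the change: HFLM's internal branches now compare the string `self.backend` against `"causal"` / `"seq2seq"` instead of comparing `self.AUTO_MODEL_CLASS` against the transformers auto classes. A minimal sketch of the resulting usage, assuming `HFLM` is importable from `lm_eval.models.huggingface`; the checkpoint name is only an example:

import transformers
from lm_eval.models.huggingface import HFLM

# Force the encoder-decoder code paths explicitly; with backend="default",
# the backend is inferred from the model's HF config instead.
lm = HFLM(pretrained="google/flan-t5-small", backend="seq2seq")

# Conditional checks inside HFLM now key off this string...
assert lm.backend == "seq2seq"
# ...and AUTO_MODEL_CLASS is only derived from it when not already set by a subclass.
assert lm.AUTO_MODEL_CLASS is transformers.AutoModelForSeq2SeqLM
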
@@ -55,7 +55,7 @@ class HFLM(TemplateLM):
     def __init__(
         self,
         pretrained: Union[str, transformers.PreTrainedModel],
-        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
+        backend: Literal["default", "causal", "seq2seq"] = "default",
         # override whether the model should be treated as decoder-only (causal) or encoder-decoder (seq2seq)
         revision: Optional[str] = "main",
         subfolder: Optional[str] = None,
@@ -90,7 +90,6 @@ class HFLM(TemplateLM):
         **kwargs,
     ) -> None:
         super().__init__()
-
         # optionally: take in an already-initialized transformers.PreTrainedModel
         if not isinstance(pretrained, str):
             eval_logger.warning(
@@ -164,7 +163,7 @@ class HFLM(TemplateLM):
                 trust_remote_code=trust_remote_code,
             )

-        # determine which of 'causal' and 'seq2seq' backends to use
+        # determine which of 'causal' and 'seq2seq' backends to use for HF models
         self._get_backend(
             config=self.config, backend=backend, trust_remote_code=trust_remote_code
         )
@@ -287,7 +286,7 @@ class HFLM(TemplateLM):
     def _get_accelerate_args(
         self,
-        parallelize: bool = None,
+        parallelize: Optional[bool] = None,
         device_map: Optional[str] = "auto",
         max_memory_per_gpu: Optional[Union[int, str]] = None,
         max_cpu_memory: Optional[Union[int, str]] = None,
@@ -441,31 +440,26 @@ class HFLM(TemplateLM):
     def _get_backend(
         self,
         config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
-        backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
+        backend: Literal["default", "causal", "seq2seq"] = "default",
         trust_remote_code: Optional[bool] = False,
     ) -> None:
         """
         Helper method during initialization.
-        Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder))
-        model type to be used.
+        Determines the backend ("causal" (decoder-only) or "seq2seq" (encoder-decoder)) model type to be used.
         sets `self.AUTO_MODEL_CLASS` appropriately if not already set.
+
+        **If not calling HFLM.__init__() or HFLM._get_backend() within a subclass of HFLM,
+        user must set `self.backend` to be either "causal" or "seq2seq" manually!**
         """
-        # escape hatch: if we're using a subclass that shouldn't follow
-        # the default _get_backend logic,
-        # then skip over the method.
-        # TODO: this seems very much undesirable in some cases--our code in HFLM
-        # references AutoModelForCausalLM at times to check for equality
-        if self.AUTO_MODEL_CLASS is not None:
-            return
         assert backend in ["default", "causal", "seq2seq"]

         if backend != "default":
             # if we've settled on non-default backend, use that manually
             if backend == "causal":
-                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
+                self.backend = backend
             elif backend == "seq2seq":
-                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
+                self.backend = backend
             eval_logger.info(
                 f"Overrode HF model backend type, and using type '{backend}'"
             )
@@ -478,26 +472,32 @@ class HFLM(TemplateLM):
                 # first check if model type is listed under seq2seq models, since some
                 # models like MBart are listed in both seq2seq and causal mistakenly in HF transformers.
                 # these special cases should be treated as seq2seq models.
-                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM
+                self.backend = "seq2seq"
+                eval_logger.info(f"Using model type '{backend}'")
             elif (
                 getattr(self.config, "model_type") in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
             ):
-                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
+                self.backend = "causal"
+                eval_logger.info(f"Using model type '{backend}'")
             else:
                 if not trust_remote_code:
                     eval_logger.warning(
                         "HF model type is neither marked as CausalLM or Seq2SeqLM. \
                     This is expected if your model requires `trust_remote_code=True` but may be an error otherwise."
+                        "Setting backend to causal"
                     )
                 # if model type is neither in HF transformers causal or seq2seq model registries
-                # then we default to AutoModelForCausalLM
-                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
+                # then we default to assuming AutoModelForCausalLM
+                self.backend = "causal"
+                eval_logger.info(
+                    f"Model type cannot be determined. Using default model type '{backend}'"
+                )

-        assert self.AUTO_MODEL_CLASS in [
-            transformers.AutoModelForCausalLM,
-            transformers.AutoModelForSeq2SeqLM,
-        ]
-        return None
+        if self.AUTO_MODEL_CLASS is None:
+            if self.backend == "causal":
+                self.AUTO_MODEL_CLASS = transformers.AutoModelForCausalLM
+            elif self.backend == "seq2seq":
+                self.AUTO_MODEL_CLASS = transformers.AutoModelForSeq2SeqLM

     def _get_config(
         self,
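
Per the docstring and escape-hatch change above, a subclass that has already pinned `AUTO_MODEL_CLASS` keeps its value, while `self.backend` is still resolved through `__init__()` / `_get_backend()`. A hypothetical subclass sketch (the class name and the choice of auto class are illustrative, not code from this repository):

import transformers
from lm_eval.models.huggingface import HFLM


class MyMultimodalLM(HFLM):
    # Pre-set by the subclass: _get_backend() now leaves this untouched and only
    # fills AUTO_MODEL_CLASS in when it is still None.
    AUTO_MODEL_CLASS = transformers.AutoModelForVision2Seq

    def __init__(self, *args, **kwargs):
        # HFLM.__init__() still resolves self.backend to "causal" or "seq2seq";
        # subclasses that bypass __init__()/_get_backend() would instead have to
        # set self.backend manually, as the new docstring warns.
        kwargs.setdefault("backend", "causal")
        super().__init__(*args, **kwargs)
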
@@ -505,6 +505,7 @@ class HFLM(TemplateLM):
         revision: str = "main",
         trust_remote_code: bool = False,
     ) -> None:
+        """Return the model config for HuggingFace models"""
         self._config = transformers.AutoConfig.from_pretrained(
             pretrained,
             revision=revision,
@@ -703,7 +704,7 @@ class HFLM(TemplateLM):
         # if OOM, then halves batch_size and tries again
         @find_executable_batch_size(starting_batch_size=self.max_batch_size)
         def forward_batch(batch_size):
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+            if self.backend == "seq2seq":
                 length = max(max_context_enc, max_cont_enc)
                 batched_conts = torch.ones(
                     (batch_size, length), device=self.device
@@ -754,7 +755,7 @@ class HFLM(TemplateLM):
         # by default for CausalLM - false or self.add_bos_token is set
         if add_special_tokens is None:
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+            if self.backend == "causal":
                 special_tokens_kwargs = {
                     "add_special_tokens": False or self.add_bos_token
                 }
@@ -782,7 +783,7 @@ class HFLM(TemplateLM):
         self.tokenizer.padding_side = padding_side

         add_special_tokens = {}
-        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+        if self.backend == "causal":
             add_special_tokens = {"add_special_tokens": False or self.add_bos_token}

         encoding = self.tokenizer(
@@ -860,14 +861,14 @@ class HFLM(TemplateLM):
     def _select_cont_toks(
         self, logits: torch.Tensor, contlen: int = None, inplen: int = None
     ) -> torch.Tensor:
-        if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+        if self.backend == "causal":
             assert (
                 contlen and inplen
             ), "Must pass input len and cont. len to select scored logits for causal LM"
             # discard right-padding.
             # also discard the input/context tokens. we'll only score continuations.
             logits = logits[inplen - contlen : inplen]
-        elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+        elif self.backend == "seq2seq":
             assert (
                 contlen and not inplen
             ), "Selecting scored logits for Seq2SeqLM requires only cont. len"
@@ -990,8 +991,7 @@ class HFLM(TemplateLM):
             requests,
             sort_fn=_collate,
             group_by="contexts"
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
-            and self.logits_cache
+            if self.backend == "causal" and self.logits_cache
             else None,
             group_fn=_lookup_one_token_cont,
         )
@@ -1048,14 +1048,14 @@ class HFLM(TemplateLM):
             # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice
             # when too long to fit in context, truncate from the left
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+            if self.backend == "causal":
                 inp = torch.tensor(
                     (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                     dtype=torch.long,
                     device=self.device,
                 )
                 (inplen,) = inp.shape
-            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+            elif self.backend == "seq2seq":
                 inp = torch.tensor(
                     (context_enc)[-self.max_length :],
                     dtype=torch.long,
@@ -1095,11 +1095,11 @@ class HFLM(TemplateLM):
             # create encoder attn mask and batched conts, if seq2seq
             call_kwargs = {}
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+            if self.backend == "causal":
                 batched_inps = pad_and_concat(
                     padding_len_inp, inps, padding_side="right"
                 )  # [batch, padding_len_inp]
-            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+            elif self.backend == "seq2seq":
                 # TODO: left-pad encoder inps and mask?
                 batched_inps = pad_and_concat(
                     padding_len_inp, inps
@@ -1130,7 +1130,7 @@ class HFLM(TemplateLM):
                 # from prompt/prefix tuning tokens, if applicable
                 ctx_len = (
                     inplen + (logits.shape[0] - padding_len_inp)
-                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
+                    if self.backend == "causal"
                     else None
                 )
                 logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
@@ -1265,10 +1265,10 @@ class HFLM(TemplateLM):
             max_gen_toks = self.max_gen_toks

             # set the max length in tokens of inputs ("context_enc")
-            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+            if self.backend == "causal":
                 # max len for inputs = max length, minus room to generate the max new tokens
                 max_ctx_len = self.max_length - max_gen_toks
-            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+            elif self.backend == "seq2seq":
                 # max len for inputs = encoder's whole max_length
                 max_ctx_len = self.max_length
@@ -1295,7 +1295,7 @@ class HFLM(TemplateLM):
             cont_toks_list = cont.tolist()
             for cont_toks, context in zip(cont_toks_list, contexts):
                 # discard context + left-padding toks if using causal decoder-only LM
-                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                if self.backend == "causal":
                     cont_toks = cont_toks[context_enc.shape[1] :]

                 s = self.tok_decode(cont_toks)
...