Unverified Commit 994bdb3f authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

Hf: minor egde cases (#1380)

* edge cases where variable might not be assigned.

* type hint
parent f5408b6b
......@@ -108,8 +108,8 @@ class HFLM(LM):
assert not parallelize, "`parallelize=True` is not compatible with passing pre-initialized model to `pretrained`"
self._model = pretrained
self._device = self._model.device
self._config = self._model.config
gpus = 0
if tokenizer:
assert isinstance(
......@@ -372,7 +372,7 @@ class HFLM(LM):
def _get_backend(
self,
config: transformers.AutoConfig,
config: Union[transformers.PretrainedConfig, transformers.AutoConfig],
backend: Optional[Literal["default", "causal", "seq2seq"]] = "default",
trust_remote_code: Optional[bool] = False,
) -> None:
......@@ -1059,6 +1059,7 @@ class HFLM(LM):
return -len(toks), x[0]
pbar = tqdm(total=len(requests), disable=(self.rank != 0))
adaptive_batch_size = None
if self.batch_size == "auto":
# using rolling window with maximum context
print("Passed argument batch_size = auto. Detecting largest batch size")
......@@ -1103,7 +1104,7 @@ class HFLM(LM):
)
else:
raise ValueError(
f"Expected `kwargs` to be of type `dict` but got {kwargs}"
f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
)
if not until:
until = [self.tok_decode(self.eot_token_id)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment