Commit e377c47f authored by Nathan Habib

linting

parent 84f59a7f
@@ -13,7 +13,6 @@ from accelerate import (
InitProcessGroupKwargs,
find_executable_batch_size,
)
-from accelerate.utils import get_max_memory
from huggingface_hub import HfApi
from packaging import version
from peft import PeftModel
@@ -680,17 +679,25 @@ class HFLM(TemplateLM):
return None
def _detect_batch_size(self, requests=None, pos: int = 0) -> int:
-if len(requests[0]) == 3: # logprob evals
+if len(requests[0]) == 3: # logprob evals
_, context_enc, continuation_enc = requests[pos]
max_length = len(
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
)
max_context_enc = len(context_enc[-(self.max_length + 1) :])
max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
-security_margin_factor = 4 # batch sizes for log prob evals sometimes generate OOMs
-elif len(requests[0]) == 2: # generative evals
+security_margin_factor = (
+4 # batch sizes for log prob evals sometimes generate OOMs
+)
+elif len(requests[0]) == 2: # generative evals
# using rolling window with maximum context
-longest_context = max([len(self.tok_encode(request[0])) + request[1].get("max_gen_toks", self.max_length) for request in requests[pos:]])
+longest_context = max(
+[
+len(self.tok_encode(request[0]))
++ request[1].get("max_gen_toks", self.max_length)
+for request in requests[pos:]
+]
+)
if longest_context > self.max_length:
eval_logger.warning(
f"Longest context length of {longest_context} exceeds max_length of {self.max_length}. Truncating to max_length."
@@ -701,7 +708,6 @@ class HFLM(TemplateLM):
max_cont_enc = max_length
security_margin_factor = 4
# if OOM, then halves batch_size and tries again
@find_executable_batch_size(starting_batch_size=self.max_batch_size)
def forward_batch(batch_size):
@@ -711,7 +717,9 @@ class HFLM(TemplateLM):
batched_conts = torch.ones(
(batch_size + security_margin, length), device=self.device
).long()
-test_batch = torch.ones((batch_size + security_margin, length), device=self.device).long()
+test_batch = torch.ones(
+(batch_size + security_margin, length), device=self.device
+).long()
call_kwargs = {
"attn_mask": test_batch,
"labels": batched_conts,
@@ -722,7 +730,7 @@ class HFLM(TemplateLM):
(batch_size + security_margin, max_length), device=self.device
).long()
-for _ in range(5*security_margin_factor):
+for _ in range(5 * security_margin_factor):
logits = self._model_call(inps=test_batch, **call_kwargs).float()
scores = F.log_softmax(logits, dim=-1) # noqa: F841
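Reviewer note: `forward_batch` above is the OOM probe. It builds dummy input tensors with extra headroom (the security margin) and runs several forward passes; `accelerate`'s `find_executable_batch_size` halves the batch size and retries whenever that raises an out-of-memory error. A self-contained sketch of the pattern with a toy model (names and sizes are illustrative, not the harness's code):

```python
import torch
import torch.nn as nn
from accelerate import find_executable_batch_size

VOCAB, SEQ_LEN = 100, 128
model = nn.Sequential(nn.Embedding(VOCAB, 32), nn.Linear(32, VOCAB))  # toy LM head
SECURITY_MARGIN_FACTOR = 4  # extra probing headroom, as in the comments above

@find_executable_batch_size(starting_batch_size=64)
def probe(batch_size):
    # accelerate re-invokes this with batch_size halved after every OOM error
    dummy = torch.ones((batch_size, SEQ_LEN), dtype=torch.long)
    for _ in range(5 * SECURITY_MARGIN_FACTOR):  # repeat to surface late OOMs
        logits = model(dummy)
        torch.log_softmax(logits.float(), dim=-1)
    return batch_size

print(probe())  # largest batch size that survived the probe
```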
@@ -1122,7 +1130,9 @@ class HFLM(TemplateLM):
}
multi_logits = F.log_softmax(
-self._model_call(batched_inps, **call_kwargs), dim=-1, dtype=torch.float16
+self._model_call(batched_inps, **call_kwargs),
+dim=-1,
+dtype=torch.float16,
) # [batch, padding_length (inp or cont), vocab]
for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
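Reviewer note: the reflowed `log_softmax` call above turns the model's logits of shape [batch, padding_length, vocab] into per-token log-probabilities; the loop that follows slices out each request's continuation and gathers the target-token scores. A hedged, shape-only illustration for a single sequence (the real indexing also accounts for padding and context length):

```python
import torch
import torch.nn.functional as F

VOCAB = 50
seq_logits = torch.randn(10, VOCAB)  # stand-in logits for one sequence
log_probs = F.log_softmax(seq_logits, dim=-1, dtype=torch.float16)

cont_toks = torch.tensor([7, 3, 42])        # continuation to score
cont_scores = log_probs[-len(cont_toks):]   # positions predicting those tokens
token_lls = torch.gather(cont_scores, 1, cont_toks.unsqueeze(-1)).squeeze(-1)

greedy_match = bool((cont_scores.argmax(dim=-1) == cont_toks).all())
print(token_lls.sum().item(), greedy_match)  # (loglikelihood, is_greedy)
```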
@@ -1200,16 +1210,8 @@ class HFLM(TemplateLM):
disable=(disable_tqdm or (self.rank != 0)),
desc="Running generate_until requests",
)
-batch_size = (
-self.batch_size
-if self.batch_size != "auto"
-else 0
-)
-batch_fn = (
-self._batch_scheduler
-if self.batch_size == "auto"
-else None
-)
+batch_size = self.batch_size if self.batch_size != "auto" else 0
+batch_fn = self._batch_scheduler if self.batch_size == "auto" else None
# we group requests by their generation_kwargs,
# so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
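Reviewer note: the comment above is the key invariant for generation: requests are grouped by their generation kwargs so a single batch never mixes, e.g., greedy decoding with temperature sampling, and (with the new `reset_batch_fn`) the automatic batch size is re-detected per group. A rough, self-contained illustration of that grouping (toy data, not the Collator's implementation):

```python
from collections import defaultdict

requests = [
    ("2 + 2 =", {"do_sample": False}),
    ("The capital of France is", {"do_sample": False}),
    ("Write a poem about autumn", {"do_sample": True, "temperature": 0.8}),
]

groups = defaultdict(list)
for context, gen_kwargs in requests:
    key = tuple(sorted(gen_kwargs.items()))  # hashable view of the kwargs
    groups[key].append(context)

for key, contexts in groups.items():
    # each group can now be batched and generated with one shared kwargs dict
    print(dict(key), contexts)
```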
@@ -1221,7 +1223,9 @@ class HFLM(TemplateLM):
group_by="gen_kwargs",
group_fn=lambda x: x[1],
)
-chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn, reset_batch_fn=self._reset_batch_scheduler)
+chunks = re_ords.get_batched(
+n=batch_size, batch_fn=batch_fn, reset_batch_fn=self._reset_batch_scheduler
+)
for chunk in chunks:
contexts, all_gen_kwargs = zip(*chunk)
# we assume all gen kwargs in the batch are the same
@@ -1252,7 +1256,9 @@ class HFLM(TemplateLM):
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
-if max_gen_toks > self.max_length: # some model have low max length limit
+if (
+max_gen_toks > self.max_length
+): # some model have low max length limit
max_gen_toks = self.max_gen_toks
else:
max_gen_toks = self.max_gen_toks
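Reviewer note: the reflowed condition above guards against a task requesting more new tokens than a small-context model can hold; in that case the harness falls back to its default `max_gen_toks`. Tiny illustration with assumed values:

```python
DEFAULT_MAX_GEN_TOKS = 256
MAX_LENGTH = 512  # e.g. a model with a small context window

kwargs = {"max_gen_toks": 1024}  # a task asking for a long generation
max_gen_toks = kwargs.pop("max_gen_toks", DEFAULT_MAX_GEN_TOKS)
if max_gen_toks > MAX_LENGTH:  # some models have a low max length limit
    max_gen_toks = DEFAULT_MAX_GEN_TOKS
print(max_gen_toks)  # -> 256
```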
@@ -389,7 +389,12 @@ class Collator:
self._arr_with_indices, fn=self._group_fn, group_by="contexts"
)
-def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None, reset_batch_fn: Optional[Callable] = None) -> Iterator:
+def get_batched(
+self,
+n: int = 1,
+batch_fn: Optional[Callable] = None,
+reset_batch_fn: Optional[Callable] = None,
+) -> Iterator:
"""
Generates and yields batches from the reordered array. The method of grouping and batching
depends on the parameter `group_by`.
@@ -402,7 +407,7 @@ class Collator:
- n (int): The size of each batch. Defaults to 1.
- batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
each batch. Optional, defaults to None.
-- reset_batch_fn ([Callable[[int, Iterable], int]] | None): A function to reset the scheduler of
+- reset_batch_fn ([Callable[[int, Iterable], int]] | None): A function to reset the scheduler of
the batch_fn, if present, when we change group in generative mode.
Returns:
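Reviewer note: per the signature documented above, `batch_fn` receives the current position in the reordered requests plus the requests themselves and returns the batch size to use next; the fixed size `n` appears to act as a fallback when no `batch_fn` is supplied. A hedged example of a conforming `batch_fn` (the policy and numbers are purely illustrative):

```python
from typing import Iterable

def batch_fn(pos: int, reordered_requests: Iterable) -> int:
    # Illustrative policy: allow larger batches once the longest
    # (earliest-sorted) requests have been processed.
    return 4 if pos < 8 else 16

print(batch_fn(0, []), batch_fn(100, []))  # -> 4 16
```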
@@ -414,7 +419,9 @@ class Collator:
"""
if self._group_by == "gen_kwargs":
for key, values in self._arr_with_indices.items(): # type: ignore
-if reset_batch_fn is not None: # with each group change, we must recompute the batch size, so we restart the scheduler
+if (
+reset_batch_fn is not None
+): # with each group change, we must recompute the batch size, so we restart the scheduler
reset_batch_fn()
values = self._reorder(values)
batch = self.get_chunks(values, n=n, fn=batch_fn)
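Reviewer note: the reformatted condition above is where the new `reset_batch_fn` hook fires: each time iteration moves to a new `gen_kwargs` group, the caller's batch-size scheduler is reset so the size can be re-detected for that group. A self-contained toy version of this control flow (not the library's `Collator`, just the idea):

```python
from typing import Callable, Iterable, Iterator, Optional

def get_batched_toy(
    groups: dict,
    n: int = 1,
    batch_fn: Optional[Callable[[int, Iterable], int]] = None,
    reset_batch_fn: Optional[Callable] = None,
) -> Iterator[list]:
    for key, values in groups.items():
        if reset_batch_fn is not None:
            reset_batch_fn()  # new group -> let the caller reset its scheduler
        batch, pos = [], 0
        for item in values:
            batch.append(item)
            size = batch_fn(pos, values) if batch_fn else n
            if len(batch) == size:
                yield batch
                pos += len(batch)
                batch = []
        if batch:
            yield batch

groups = {"greedy": list(range(5)), "temp=0.8": list(range(3))}
batches = list(get_batched_toy(groups, n=2, reset_batch_fn=lambda: print("reset")))
print(batches)  # reset is printed once per gen_kwargs group
```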