Commit e377c47f authored by Nathan Habib

linting

parent 84f59a7f
@@ -13,7 +13,6 @@ from accelerate import (
     InitProcessGroupKwargs,
     find_executable_batch_size,
 )
-from accelerate.utils import get_max_memory
 from huggingface_hub import HfApi
 from packaging import version
 from peft import PeftModel
@@ -680,17 +679,25 @@ class HFLM(TemplateLM):
         return None
 
     def _detect_batch_size(self, requests=None, pos: int = 0) -> int:
         if len(requests[0]) == 3:  # logprob evals
             _, context_enc, continuation_enc = requests[pos]
             max_length = len(
                 (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
             )
             max_context_enc = len(context_enc[-(self.max_length + 1) :])
             max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
-            security_margin_factor = 4 # batch sizes for log prob evals sometimes generate OOMs
-        elif len(requests[0]) == 2: # generative evals
+            security_margin_factor = (
+                4  # batch sizes for log prob evals sometimes generate OOMs
+            )
+        elif len(requests[0]) == 2:  # generative evals
             # using rolling window with maximum context
-            longest_context = max([len(self.tok_encode(request[0])) + request[1].get("max_gen_toks", self.max_length) for request in requests[pos:]])
+            longest_context = max(
+                [
+                    len(self.tok_encode(request[0]))
+                    + request[1].get("max_gen_toks", self.max_length)
+                    for request in requests[pos:]
+                ]
+            )
             if longest_context > self.max_length:
                 eval_logger.warning(
                     f"Longest context length of {longest_context} exceeds max_length of {self.max_length}. Truncating to max_length."
@@ -701,7 +708,6 @@ class HFLM(TemplateLM):
             max_cont_enc = max_length
             security_margin_factor = 4
 
-
         # if OOM, then halves batch_size and tries again
         @find_executable_batch_size(starting_batch_size=self.max_batch_size)
         def forward_batch(batch_size):
@@ -711,7 +717,9 @@ class HFLM(TemplateLM):
                 batched_conts = torch.ones(
                     (batch_size + security_margin, length), device=self.device
                 ).long()
-                test_batch = torch.ones((batch_size + security_margin, length), device=self.device).long()
+                test_batch = torch.ones(
+                    (batch_size + security_margin, length), device=self.device
+                ).long()
                 call_kwargs = {
                     "attn_mask": test_batch,
                     "labels": batched_conts,
@@ -722,7 +730,7 @@ class HFLM(TemplateLM):
                     (batch_size + security_margin, max_length), device=self.device
                 ).long()
-            for _ in range(5*security_margin_factor):
+            for _ in range(5 * security_margin_factor):
                 logits = self._model_call(inps=test_batch, **call_kwargs).float()
                 scores = F.log_softmax(logits, dim=-1)  # noqa: F841
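The decorator applied to forward_batch comes from accelerate: it calls the wrapped function with starting_batch_size and, whenever a CUDA out-of-memory error is raised, halves the batch size and retries until something fits. A minimal sketch of the pattern, assuming a CUDA device is available; the dummy workload is invented for illustration:

```python
import torch
from accelerate import find_executable_batch_size


@find_executable_batch_size(starting_batch_size=512)
def probe(batch_size):
    # Allocate a dummy [batch, seq_len, hidden] activation; if this raises an
    # out-of-memory error, the decorator halves batch_size and calls probe again.
    _ = torch.ones((batch_size, 2048, 4096), device="cuda")
    return batch_size


largest_ok = probe()  # called without arguments; batch_size is injected by the decorator
```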
@@ -1122,7 +1130,9 @@ class HFLM(TemplateLM):
                 }
 
             multi_logits = F.log_softmax(
-                self._model_call(batched_inps, **call_kwargs), dim=-1, dtype=torch.float16
+                self._model_call(batched_inps, **call_kwargs),
+                dim=-1,
+                dtype=torch.float16,
             )  # [batch, padding_length (inp or cont), vocab]
 
             for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
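Regarding the dtype=torch.float16 argument above: F.log_softmax casts its input to the requested dtype before computing, so the [batch, padding_length, vocab] score tensor takes half the memory of a float32 result, at the cost of some numerical precision. A small standalone illustration with made-up shapes:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 8, 32000)  # [batch, seq_len, vocab_size], float32
scores = F.log_softmax(logits, dim=-1, dtype=torch.float16)
print(scores.dtype)  # torch.float16 -> half the memory of the float32 equivalent
```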
@@ -1200,16 +1210,8 @@ class HFLM(TemplateLM):
             disable=(disable_tqdm or (self.rank != 0)),
             desc="Running generate_until requests",
         )
-        batch_size = (
-            self.batch_size
-            if self.batch_size != "auto"
-            else 0
-        )
-        batch_fn = (
-            self._batch_scheduler
-            if self.batch_size == "auto"
-            else None
-        )
+        batch_size = self.batch_size if self.batch_size != "auto" else 0
+        batch_fn = self._batch_scheduler if self.batch_size == "auto" else None
 
         # we group requests by their generation_kwargs,
         # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
@@ -1221,7 +1223,9 @@ class HFLM(TemplateLM):
             group_by="gen_kwargs",
             group_fn=lambda x: x[1],
         )
-        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn, reset_batch_fn=self._reset_batch_scheduler)
+        chunks = re_ords.get_batched(
+            n=batch_size, batch_fn=batch_fn, reset_batch_fn=self._reset_batch_scheduler
+        )
         for chunk in chunks:
             contexts, all_gen_kwargs = zip(*chunk)
             # we assume all gen kwargs in the batch are the same
@@ -1252,7 +1256,9 @@ class HFLM(TemplateLM):
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
-                if max_gen_toks > self.max_length: # some model have low max length limit
+                if (
+                    max_gen_toks > self.max_length
+                ):  # some model have low max length limit
                     max_gen_toks = self.max_gen_toks
             else:
                 max_gen_toks = self.max_gen_toks
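The hunk above clamps the requested generation length: if a task asks for more new tokens than the model's context window allows, the model's default max_gen_toks is used instead. A hedged restatement with a hypothetical helper name:

```python
def resolve_max_gen_toks(requested, model_max_length, default_max_gen_toks):
    # Some models have a low max length limit; requesting more generated tokens
    # than the window supports would overflow it, so fall back to the default.
    if requested > model_max_length:
        return default_max_gen_toks
    return requested
```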
...
@@ -389,7 +389,12 @@ class Collator:
             self._arr_with_indices, fn=self._group_fn, group_by="contexts"
         )
 
-    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None, reset_batch_fn: Optional[Callable] = None) -> Iterator:
+    def get_batched(
+        self,
+        n: int = 1,
+        batch_fn: Optional[Callable] = None,
+        reset_batch_fn: Optional[Callable] = None,
+    ) -> Iterator:
         """
         Generates and yields batches from the reordered array. The method of grouping and batching
         depends on the parameter `group_by`.
@@ -402,7 +407,7 @@ class Collator:
         - n (int): The size of each batch. Defaults to 1.
         - batch_fn ([Callable[[int, Iterable], int]] | None): A function to determine the size of
           each batch. Optional, defaults to None.
         - reset_batch_fn ([Callable[[int, Iterable], int]] | None): A function to reset the scheduler of
           the batch_fn, if present, when we change group in generative mode.
 
         Returns:
@@ -414,7 +419,9 @@ class Collator:
         """
         if self._group_by == "gen_kwargs":
             for key, values in self._arr_with_indices.items():  # type: ignore
-                if reset_batch_fn is not None: # with each group change, we must recompute the batch size, so we restart the scheduler
+                if (
+                    reset_batch_fn is not None
+                ):  # with each group change, we must recompute the batch size, so we restart the scheduler
                     reset_batch_fn()
                 values = self._reorder(values)
                 batch = self.get_chunks(values, n=n, fn=batch_fn)
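The reset_batch_fn hook exists because an auto-detected batch size is only valid within a single gen_kwargs group: requests in another group can have different prompt lengths and max_gen_toks, so the cached value must be discarded and re-detected when the group changes. A hypothetical sketch of a scheduler compatible with the documented batch_fn signature ([Callable[[int, Iterable], int]]); the class and method names are invented, not the harness's actual implementation:

```python
from typing import Callable, Iterable, Optional


class CachedBatchScheduler:
    """Caches an auto-detected batch size and forgets it on demand."""

    def __init__(self, detect_fn: Callable[[Iterable, int], int]):
        self._detect_fn = detect_fn
        self._cached: Optional[int] = None

    def __call__(self, pos: int, reordered: Iterable) -> int:
        # Matches the (index, iterable) -> int shape expected of batch_fn.
        if self._cached is None:
            self._cached = self._detect_fn(reordered, pos)
        return self._cached

    def reset(self) -> None:
        # Passed as reset_batch_fn: the next group triggers a fresh detection.
        self._cached = None
```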
...