Commit 6a6068f8 authored by Nathan Habib's avatar Nathan Habib
Browse files

security margins are now constants

parent 548ec466
......@@ -679,6 +679,9 @@ class HFLM(TemplateLM):
return None
def _detect_batch_size(self, requests=None, pos: int = 0) -> int:
SECURITY_MARGIN_FACTOR_LOG_PROBS = 4
SECURITY_MARGIN_FACTOR_GENERATE_UNTIL = 1
if len(requests[0]) == 3: # logprob evals
_, context_enc, continuation_enc = requests[pos]
max_length = len(
......@@ -686,9 +689,7 @@ class HFLM(TemplateLM):
)
max_context_enc = len(context_enc[-(self.max_length + 1) :])
max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
security_margin_factor = (
4 # batch sizes for log prob evals sometimes generate OOMs
)
security_margin_factor = SECURITY_MARGIN_FACTOR_LOG_PROBS
elif len(requests[0]) == 2: # generative evals
# using rolling window with maximum context
longest_context = max(
......@@ -706,7 +707,7 @@ class HFLM(TemplateLM):
max_length = longest_context
max_context_enc = max_length
max_cont_enc = max_length
security_margin_factor = 4
security_margin_factor = SECURITY_MARGIN_FACTOR_GENERATE_UNTIL
# if OOM, then halves batch_size and tries again
@find_executable_batch_size(starting_batch_size=self.max_batch_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment