Commit 4d21ab6b authored by Benjamin Fattori

address PR comments

parent 9142999e
@@ -12,7 +12,6 @@ from tqdm import tqdm
import torch
import torch.nn.functional as F
from accelerate import find_executable_batch_size
import gc
from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
from lm_eval import utils
@@ -6,6 +6,7 @@ from typing import List, Mapping, NewType, Optional, Tuple, Union
from tqdm import tqdm
from transformers import BatchEncoding
from accelerate import find_executable_batch_size
from lm_eval import utils
from lm_eval.base import BaseLM
@@ -313,10 +314,30 @@ class HuggingFaceAutoLM(BaseLM):
            tokens = self.tok_encode(x[0])
            return len(tokens), x[0]
        results = []
        reorder = utils.Reorderer(requests, _collate)
        _, context_enc, continuation_enc = reorder.get_reordered()[0]
        max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
        adaptive_batch_size = None
        if self.batch_size == 'auto':
            # using rolling window with maximum context
            print('Passed argument batch_size = auto. Detecting largest batch size')
            @find_executable_batch_size(starting_batch_size=512)  # if OOM, then halves batch_size and tries again
            def forward_batch(batch_size):
                test_batch = torch.ones((batch_size, max_context), device=self.device).long()
                for _ in range(5):
                    out = F.log_softmax(self._model_call(test_batch), dim=-1).cpu()
                return batch_size
            batch_size = forward_batch()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        for chunk in utils.chunks(
            tqdm(reorder.get_reordered(), disable=False), self.batch_size
            tqdm(reorder.get_reordered(), disable=False), self.batch_size if self.batch_size != "auto" else adaptive_batch_size
        ):
            context = [c[0] for c in chunk]
            request_args = chunk[0][1]
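For reference, the auto-detection pattern in this diff can be reproduced as a small standalone sketch. This is illustrative only, not part of the commit: detect_batch_size and model_call are hypothetical names, with model_call standing in for BaseLM._model_call (a callable mapping a batch of token ids to logits). accelerate's find_executable_batch_size retries the decorated function, halving batch_size whenever a CUDA out-of-memory error is raised, so the returned value is the largest batch size that survived the trial forward passes.

import torch
import torch.nn.functional as F
from accelerate import find_executable_batch_size

def detect_batch_size(model_call, max_context, device, starting_batch_size=512):
    # model_call: hypothetical stand-in for BaseLM._model_call (token ids -> logits).
    @find_executable_batch_size(starting_batch_size=starting_batch_size)
    def forward_batch(batch_size):
        # Worst-case input: every sequence filled to the maximum context length.
        test_batch = torch.ones((batch_size, max_context), device=device).long()
        for _ in range(5):
            F.log_softmax(model_call(test_batch), dim=-1).cpu()
        return batch_size

    # accelerate catches OOM, halves batch_size, and calls forward_batch again.
    return forward_batch()

Under those assumptions, a caller on a single GPU might use it as
batch_size = detect_batch_size(lambda t: model(t).logits, max_context=2048, device="cuda"),
then feed the result to the chunking loop in place of a fixed batch size.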