Commit 3168fc00 authored by Benjamin Fattori

autobatching support for enc-dec

parent 81a11d6d
@@ -342,15 +342,27 @@ class HFLM(LM):
             max_length = len(
                 (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
             )
+            max_context_enc = len(context_enc[-(self.max_length + 1) :])
+            max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
         else:
             max_length = self.max_length

         # if OOM, then halves batch_size and tries again
         @find_executable_batch_size(starting_batch_size=self.max_batch_size)
         def forward_batch(batch_size):
-            test_batch = torch.ones((batch_size, max_length), device=self.device).long()
+            if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
+                length = max(max_context_enc, max_cont_enc)
+                batched_conts = torch.ones((batch_size, length), device=self.device).long()
+                test_batch = torch.ones((batch_size, length), device=self.device).long()
+                call_kwargs = {
+                    "attn_mask": test_batch,
+                    "labels": batched_conts,
+                }
+            else:
+                call_kwargs = {}
+                test_batch = torch.ones((batch_size, max_length), device=self.device).long()
             for _ in range(5):
-                out = F.log_softmax(self._model_call(test_batch), dim=-1)
+                out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)

             return batch_size
         batch_size = forward_batch()
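For reference, `find_executable_batch_size` here is the decorator from Hugging Face's accelerate library: it calls the wrapped function with the starting batch size and, on a CUDA out-of-memory error, frees cached memory and retries with the batch size halved until the body succeeds. A minimal standalone sketch of that pattern (the embedding stand-in, the sequence length of 128, and the starting size of 64 are illustrative placeholders, not values from this commit):

import torch
from accelerate.utils import find_executable_batch_size

# A small stand-in model so the sketch is self-contained; any forward
# pass that can OOM behaves the same way under the decorator.
model = torch.nn.Embedding(32000, 512).cuda()

@find_executable_batch_size(starting_batch_size=64)  # hypothetical starting size
def probe(batch_size):
    # The decorator injects batch_size; on CUDA OOM it frees cached
    # memory and re-runs this body with batch_size // 2.
    dummy = torch.ones((batch_size, 128), dtype=torch.long, device="cuda")
    model(dummy)
    return batch_size

batch_size = probe()  # largest batch size whose forward pass fit in memory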
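The seq2seq branch in the diff exists because encoder-decoder models cannot run a forward pass on encoder inputs alone; the dummy `labels` tensor gives the decoder something to consume. A quick sanity check of that behavior with the Hugging Face API (the `t5-small` checkpoint is just an example):

import torch
import transformers

model = transformers.AutoModelForSeq2SeqLM.from_pretrained("t5-small")
input_ids = torch.ones((2, 8), dtype=torch.long)
# Calling model(input_ids=...) without labels or decoder_input_ids raises a
# ValueError; passing dummy labels, as the test batch above does, avoids that.
out = model(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    labels=input_ids,
)
print(out.logits.shape)  # torch.Size([2, 8, vocab_size])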