Unverified commit f4355f93, authored by Stella Biderman, committed by GitHub

Merge pull request #472 from juletx/patch-1

Fix bugs introduced in #394 and #406, and a max length bug
parents 9a877197 cede13c6
@@ -340,6 +340,8 @@ class HuggingFaceAutoLM(BaseLM):
             if hasattr(self._config, attr):
                 return getattr(self._config, attr)
         if hasattr(self.tokenizer, "model_max_length"):
+            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
+                return self._DEFAULT_MAX_LENGTH
             return self.tokenizer.model_max_length
         return self._DEFAULT_MAX_LENGTH
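For context on the magic number: transformers assigns tokenizer.model_max_length = int(1e30) when a tokenizer ships without a configured maximum length, and int(1e30) is exactly 1000000000000000019884624838656 after float rounding, so the added check treats that sentinel as "no limit configured" and falls back to the default. Below is a minimal standalone sketch of the same resolution order; the resolve_max_length name, the 2048 fallback, and the exact tuple of config attributes are illustrative, not part of this PR.

from transformers import AutoConfig, AutoTokenizer

# Sentinel transformers assigns to model_max_length when no limit is configured.
VERY_LARGE_INTEGER = int(1e30)  # == 1000000000000000019884624838656
DEFAULT_MAX_LENGTH = 2048  # illustrative stand-in for self._DEFAULT_MAX_LENGTH

def resolve_max_length(pretrained: str) -> int:
    """Resolve a usable context length: model config first, then tokenizer, then default."""
    config = AutoConfig.from_pretrained(pretrained)
    tokenizer = AutoTokenizer.from_pretrained(pretrained)
    for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
        if hasattr(config, attr):
            return getattr(config, attr)
    if hasattr(tokenizer, "model_max_length"):
        if tokenizer.model_max_length == VERY_LARGE_INTEGER:
            return DEFAULT_MAX_LENGTH  # tokenizer reports "unbounded", i.e. unset
        return tokenizer.model_max_length
    return DEFAULT_MAX_LENGTH

print(resolve_max_length("gpt2"))  # 1024, via config.n_positions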
@@ -371,13 +373,9 @@ class HuggingFaceAutoLM(BaseLM):
         def _collate(x):
             tokens = self.tok_encode(x[0])
             return len(tokens), x[0]

         results = []
         reorder = utils.Reorderer(requests, _collate)
-        _, context_enc, continuation_enc = reorder.get_reordered()[0]
-        max_context = len((context_enc + continuation_enc)[-(self.max_length + 1) :][:-1])
         adaptive_batch_size = None
         if self.batch_size == 'auto':
...@@ -385,7 +383,7 @@ class HuggingFaceAutoLM(BaseLM): ...@@ -385,7 +383,7 @@ class HuggingFaceAutoLM(BaseLM):
print('Passed argument batch_size = auto. Detecting largest batch size') print('Passed argument batch_size = auto. Detecting largest batch size')
@find_executable_batch_size(starting_batch_size=512) # if OOM, then halves batch_size and tries again @find_executable_batch_size(starting_batch_size=512) # if OOM, then halves batch_size and tries again
def forward_batch(batch_size): def forward_batch(batch_size):
test_batch = torch.ones((batch_size, max_context), device=self.device).long() test_batch = torch.ones((batch_size, self.max_length), device=self.device).long()
for _ in range(5): for _ in range(5):
out = F.log_softmax(self._model_call(test_batch), dim = -1).cpu() out = F.log_softmax(self._model_call(test_batch), dim = -1).cpu()
return batch_size return batch_size
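The removed max_context lines unpacked the first reordered request into three values, a shape that loglikelihood requests have; here in greedy_until the reordered items are (context, args) pairs, so that unpack appears to be the bug this hunk removes. After the change, the auto batch-size probe simply assumes the worst case: a batch of full self.max_length-token sequences. Below is a hedged sketch of the same probing pattern using accelerate's find_executable_batch_size; the detect_batch_size wrapper and the direct model(...).logits call are illustrative assumptions (the harness routes the forward pass through self._model_call).

import torch
import torch.nn.functional as F
from accelerate import find_executable_batch_size
from transformers import AutoModelForCausalLM

def detect_batch_size(model, max_length: int, device: str = "cuda") -> int:
    """Find the largest batch size that survives a worst-case forward pass."""

    # The decorator halves batch_size and retries whenever the body raises CUDA OOM.
    @find_executable_batch_size(starting_batch_size=512)
    def forward_batch(batch_size: int) -> int:
        # Worst case: every sequence in the batch fills the entire context window.
        test_batch = torch.ones((batch_size, max_length), device=device).long()
        for _ in range(5):
            _ = F.log_softmax(model(test_batch).logits, dim=-1).cpu()
        return batch_size

    return forward_batch()

# Illustrative usage (requires a CUDA device):
# model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")
# batch_size = detect_batch_size(model, max_length=1024)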
...@@ -400,7 +398,7 @@ class HuggingFaceAutoLM(BaseLM): ...@@ -400,7 +398,7 @@ class HuggingFaceAutoLM(BaseLM):
context = [c[0] for c in chunk] context = [c[0] for c in chunk]
request_args = chunk[0][1] request_args = chunk[0][1]
stop = request_args.get('until', None) stop = request_args.get('until', None)
stop_sequences = [stop] if isinstance(stop, list) else stop stop_sequences = stop if isinstance(stop, list) else [stop]
max_generation_length = request_args.get("max_length", None) max_generation_length = request_args.get("max_length", None)
assert ( assert (
......
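The last one-liner fixes an inverted conditional: the old expression wrapped an already-list `until` argument in another list and left a bare string unwrapped, while the new expression yields a flat list of stop strings in both cases. A tiny illustration of the corrected normalization; the helper name is made up for the example.

from typing import List, Union

def normalize_stop_sequences(until: Union[str, List[str]]) -> List[str]:
    """Normalize a task's `until` argument into a flat list of stop strings."""
    # Fixed logic from this PR: keep lists as-is, wrap a lone string.
    # The old expression, `[until] if isinstance(until, list) else until`, did the opposite.
    return until if isinstance(until, list) else [until]

assert normalize_stop_sequences("\n\n") == ["\n\n"]
assert normalize_stop_sequences(["</s>", "\n\n"]) == ["</s>", "\n\n"]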