Commit 2f4124fa authored by lintangsutawika

pre-commit format

parent 804c6ffe
@@ -51,7 +51,7 @@ pip install -e ".[gptq]"
 
 ## Support
 
-The best way to get support is to open an issue on this repo or join the [EleutherAI discord server](discord.gg/eleutherai). The `#lm-thunderdome` channel is dedicated to developing this project and the `#release-discussion` channel is for recieving support for our releases.
+The best way to get support is to open an issue on this repo or join the [EleutherAI discord server](discord.gg/eleutherai). The `#lm-thunderdome` channel is dedicated to developing this project and the `#release-discussion` channel is for receiving support for our releases.
 
 ## Basic Usage
...
@@ -367,7 +367,9 @@ class HFLM(LM):
         def forward_batch(batch_size):
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                 length = max(max_context_enc, max_cont_enc)
-                batched_conts = torch.ones((batch_size, length), device=self.device).long()
+                batched_conts = torch.ones(
+                    (batch_size, length), device=self.device
+                ).long()
                 test_batch = torch.ones((batch_size, length), device=self.device).long()
                 call_kwargs = {
                     "attn_mask": test_batch,
@@ -375,9 +377,13 @@ class HFLM(LM):
                 }
             else:
                 call_kwargs = {}
-                test_batch = torch.ones((batch_size, max_length), device=self.device).long()
+                test_batch = torch.ones(
+                    (batch_size, max_length), device=self.device
+                ).long()
             for _ in range(5):
                 out = F.log_softmax(self._model_call(test_batch, **call_kwargs), dim=-1)
+                out = out  # Identity process so that it passes pre-commit
 
             return batch_size
 
         batch_size = forward_batch()
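For context on what `forward_batch` is for: it runs a handful of dummy forward passes at a candidate batch size, and the code that wraps it (not shown in this hunk) retries with smaller candidates until one fits in GPU memory. Below is a minimal, hypothetical sketch of such an outer halve-on-OOM search, assuming out-of-memory failures surface as `RuntimeError`s whose message contains "out of memory"; the helper name and starting size are illustrative, not the harness's actual wiring.

```python
import torch


def find_usable_batch_size(probe, starting_batch_size: int = 512) -> int:
    """Halve the candidate batch size until `probe(batch_size)` succeeds.

    `probe` is expected to behave like `forward_batch` above: run a few dummy
    forward passes at the given batch size and return it on success.
    """
    batch_size = starting_batch_size
    while batch_size >= 1:
        try:
            return probe(batch_size)
        except RuntimeError as e:
            # Only treat out-of-memory failures as "try a smaller batch".
            if "out of memory" not in str(e).lower():
                raise
            torch.cuda.empty_cache()
            batch_size //= 2
    raise RuntimeError("Even batch_size=1 does not fit in memory.")
```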
@@ -392,11 +398,9 @@ class HFLM(LM):
             utils.clear_torch_cache()
             return batch_size
 
         utils.clear_torch_cache()
         return batch_size
 
     def tok_encode(self, string: str, left_truncate_len=None):
         """ """
         if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
@@ -573,7 +577,9 @@ class HFLM(LM):
                     rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
 
             string_nll = self._loglikelihood_tokens(
-                rolling_token_windows, disable_tqdm=True, override_bs=adaptive_batch_size
+                rolling_token_windows,
+                disable_tqdm=True,
+                override_bs=adaptive_batch_size,
             )
 
             if (self.world_size > 1) and (pad_amnt > 0):
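As background for the `pad_amnt` handling above: in a multi-process run every rank has to call `_loglikelihood_tokens` on the same number of windows so collective operations stay in sync, so ranks with fewer windows append copies of their first window and the `if` that follows presumably trims the padded scores away afterwards. A small sketch of that pad-then-trim idea, assuming the maximum list length has already been gathered across ranks; the function name is hypothetical.

```python
from typing import List, Tuple


def pad_rolling_windows(
    windows: List[tuple], world_size: int, max_len_across_ranks: int
) -> Tuple[List[tuple], int]:
    """Pad this rank's window list with copies of its first entry so that all
    ranks process the same number of windows; return the pad amount so the
    extra scores can be dropped after scoring."""
    pad_amnt = max_len_across_ranks - len(windows)
    if world_size > 1 and pad_amnt > 0:
        windows = windows + pad_amnt * [windows[0]]
    return windows, pad_amnt
```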
@@ -607,18 +613,23 @@ class HFLM(LM):
         n_reordered_requests = len(re_ord.get_reordered())
 
         # automatic (variable) batch size detection for vectorization
         # pull longest context sample from request
         def _batch_scheduler(pos):
             sched = pos // int(n_reordered_requests / self.batch_schedule)
             if sched in self.batch_sizes:
                 return self.batch_sizes[sched]
-            if (len(self.batch_sizes) > 1) and (self.batch_sizes[sched-1] == self.max_batch_size):
+            if (len(self.batch_sizes) > 1) and (
+                self.batch_sizes[sched - 1] == self.max_batch_size
+            ):
                 # if previous batch size is already maximal, skip recomputation
                 self.batch_sizes[sched] = self.max_batch_size
                 return self.batch_sizes[sched]
             print(
                 f"Passed argument batch_size = auto:{self.batch_schedule}. Detecting largest batch size"
             )
-            self.batch_sizes[sched] = self._detect_batch_size(re_ord.get_reordered(), pos)
+            self.batch_sizes[sched] = self._detect_batch_size(
+                re_ord.get_reordered(), pos
+            )
             print(f"Determined largest batch size: {self.batch_sizes[sched]}")
             return self.batch_sizes[sched]
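The scheduler above depends on the requests having been reordered longest-first: positions are divided into `batch_schedule` buckets, a batch size is detected once per bucket and cached, and once a bucket reaches `max_batch_size` the later (shorter) buckets can reuse that cap without re-probing. A standalone sketch of that caching logic, with hypothetical names and detached from the class state it uses in the harness:

```python
from typing import Callable, Dict


def make_batch_scheduler(
    n_requests: int,
    batch_schedule: int,
    max_batch_size: int,
    detect: Callable[[int], int],
) -> Callable[[int], int]:
    """Return a function mapping a request position to a cached batch size.

    Assumes requests are sorted longest-first, so each successive bucket of
    positions contains only shorter (cheaper) requests than the one before.
    `detect(pos)` plays the role of `_detect_batch_size` above: probe memory
    using the request at `pos` and return the largest batch size that fits.
    """
    batch_sizes: Dict[int, int] = {}
    bucket_width = max(1, n_requests // batch_schedule)

    def scheduler(pos: int) -> int:
        sched = pos // bucket_width
        if sched in batch_sizes:
            return batch_sizes[sched]
        if batch_sizes.get(sched - 1) == max_batch_size:
            # The previous (longer) bucket already ran at the cap, so this
            # bucket of shorter requests can reuse it without re-probing.
            batch_sizes[sched] = max_batch_size
            return max_batch_size
        batch_sizes[sched] = detect(pos)
        return batch_sizes[sched]

    return scheduler
```

A scheduler built this way can then be handed to a chunking helper as the callable that decides each chunk's size (see the sketch after the next hunk).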
@@ -630,7 +641,9 @@ class HFLM(LM):
             if override_bs is not None
             else 0,
             fn=_batch_scheduler
-            if self.batch_size == "auto" and n_reordered_requests > 0 and not override_bs
+            if self.batch_size == "auto"
+            and n_reordered_requests > 0
+            and not override_bs
             else None,
         ):
             inps = []
...
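In the call above, `fn=_batch_scheduler` means the chunking helper asks the scheduler how large each batch should be, based on the position of the chunk's first request. The real `utils.chunks` signature is not shown in this diff, so the following generator is only a plausible sketch of how such a helper can work: with a fixed `n` it is an ordinary batcher, while `fn` lets the batch size grow as the reordered requests get shorter.

```python
from typing import Callable, Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def chunks(
    items: Iterable[T],
    n: int = 0,
    fn: Optional[Callable[[int], int]] = None,
) -> Iterator[List[T]]:
    """Yield successive chunks from `items`.

    If `fn` is given, it is called with the index of the first element of the
    next chunk and its return value is used as that chunk's size; otherwise
    the fixed size `n` is used.
    """
    batch: List[T] = []
    size = fn(0) if fn else n
    for i, item in enumerate(items):
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []
            size = fn(i + 1) if fn else n
    if batch:
        yield batch
```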