Commit ff1b649e authored by Nathan Habib

cleanup

parent 3c390c43
@@ -111,11 +111,13 @@ class HFLM(TemplateLM):
             gpus = torch.cuda.device_count()
             accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
             accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
-            self.accelerator = accelerator
+            if accelerator.num_processes > 1:
+                self.accelerator = accelerator
             if "npu" in accelerator.device.type:
                 gpus = torch.npu.device_count()
+            # using one process with no model parallelism
             if not (parallelize or accelerator.num_processes > 1):
                 # use user-passed device
                 device_list = set(
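For reference, a minimal, hedged sketch of the `accelerate` pattern this hunk gates on: the very long `InitProcessGroupKwargs` timeout keeps ranks that block on a slow peer (for example while one rank downloads a checkpoint) from hitting the process group's default timeout, and the accelerator handle is only kept when more than one process is running. Standalone illustration, not the harness code:

```python
# Sketch of the setup used above; not the exact HFLM.__init__ code.
from datetime import timedelta

import torch
from accelerate import Accelerator, InitProcessGroupKwargs

# A very long timeout so ranks waiting on a slow peer are not killed
# by the process group's default limit.
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])

gpus = torch.cuda.device_count()
if accelerator.num_processes > 1:
    # multi-process run: keep the accelerator around for gather/sync calls
    shared_accelerator = accelerator
else:
    # single process with no model parallelism: rely on the user-passed device
    print(f"running on a single process, device: {accelerator.device}")
```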
@@ -155,7 +157,7 @@ class HFLM(TemplateLM):
         self._get_config(
             pretrained,
             revision=revision,
             trust_remote_code=trust_remote_code,
         )
         # determine which of 'causal' and 'seq2seq' backends to use
@@ -513,13 +515,11 @@ class HFLM(TemplateLM):
         revision: str = "main",
         trust_remote_code: bool = False,
     ) -> None:
-        with self.accelerator.main_process_first():
-            self._config = transformers.AutoConfig.from_pretrained(
-                pretrained,
-                revision=revision,
-                trust_remote_code=trust_remote_code,
-                force_download=False,
-            )
+        self._config = transformers.AutoConfig.from_pretrained(
+            pretrained,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+        )

     def _create_model(
         self,
@@ -578,16 +578,13 @@ class HFLM(TemplateLM):
                                 model_kwargs["bnb_4bit_compute_dtype"]
                             )
-            with self.accelerator.main_process_first():
-                #model_kwargs["device_map"] = "balanced_low_0"
-                self._model = self.AUTO_MODEL_CLASS.from_pretrained(
-                    pretrained,
-                    revision=revision,
-                    torch_dtype=get_dtype(dtype),
-                    trust_remote_code=trust_remote_code,
-                    force_download=False,
-                    **model_kwargs,
-                )
+            self._model = self.AUTO_MODEL_CLASS.from_pretrained(
+                pretrained,
+                revision=revision,
+                torch_dtype=get_dtype(dtype),
+                trust_remote_code=trust_remote_code,
+                **model_kwargs,
+            )
         else:
             try:
                 from auto_gptq import AutoGPTQForCausalLM
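The wrapper removed in the two hunks above, `accelerator.main_process_first()`, is a standard `accelerate` context manager: the main process runs the block first while the other ranks wait, which avoids several ranks downloading the same checkpoint at once. A hedged sketch of that pattern (the checkpoint name is a placeholder, not the model evaluated here):

```python
# Sketch of the main_process_first pattern that this cleanup removes.
# "gpt2" is a placeholder checkpoint.
import transformers
from accelerate import Accelerator

accelerator = Accelerator()

with accelerator.main_process_first():
    # rank 0 enters first and populates the Hugging Face cache; the other
    # ranks enter afterwards and load from the cache instead of re-downloading
    config = transformers.AutoConfig.from_pretrained("gpt2")
    model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
```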
@@ -679,7 +676,6 @@ class HFLM(TemplateLM):
                 revision=revision,
                 trust_remote_code=trust_remote_code,
                 use_fast=use_fast_tokenizer,
-                force_download=False
             )
         else:
             assert isinstance(
@@ -709,7 +705,7 @@ class HFLM(TemplateLM):
                 )
                 max_context_enc = len(context_enc[-(self.max_length + 1) :])
                 max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
-                security_margin_factor = 6  # batch sizes for log prob evals sometimes generate OOMs
+                security_margin_factor = 4  # batch sizes for log prob evals sometimes generate OOMs
             elif len(requests[0]) == 2:  # generative evals
                 # using rolling window with maximum context
                 longest_context = max([len(self.tok_encode(request[0])) + request[1].get("max_gen_toks", self.max_length) for request in requests[pos:]])
@@ -721,7 +717,7 @@ class HFLM(TemplateLM):
                 max_length = longest_context
                 max_context_enc = max_length
                 max_cont_enc = max_length
-                security_margin_factor = 6
+                security_margin_factor = 4
         # if OOM, then halves batch_size and tries again
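`security_margin_factor` is specific to this fork, and the commit lowers it from 6 to 4 for both log-prob and generative evals. The code that consumes it is not part of this diff, so the following is only an assumed reading: the auto-detected batch size is shrunk by the factor to leave headroom against OOMs later in the run.

```python
# Assumption: one plausible way a safety margin could be applied to an
# auto-detected batch size; the exact formula used by this fork is not shown here.
def apply_security_margin(detected_batch_size: int, security_margin_factor: int = 4) -> int:
    # keep some slack so longer or less favourable batches later in the
    # evaluation do not run out of memory
    return max(1, detected_batch_size // security_margin_factor)


print(apply_security_margin(64, 4))  # 16 with the new factor
print(apply_security_margin(64, 6))  # 10 with the previous factor
```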
@@ -751,7 +747,6 @@ class HFLM(TemplateLM):
             return batch_size

         try:
-            print(f"finding batch size on process {self.accelerator.local_process_index}")
             batch_size = forward_batch()
         except RuntimeError as e:
             if "No executable batch size found" in str(e):
@@ -762,7 +757,6 @@ class HFLM(TemplateLM):
         if self.world_size > 1:
             # if multi-GPU, always take minimum over all selected batch sizes
             max_rnk_bs = torch.tensor([batch_size], device=self.device)
-            print(f"gathering on process {self.accelerator.local_process_index}")
             gathered = (
                 self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
             )
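The surrounding block (unchanged apart from the dropped debug print) agrees on one batch size across data-parallel ranks by gathering each rank's value and taking the minimum. A standalone sketch of that reduction with `accelerate`:

```python
# Sketch: make every data-parallel rank use the same, safest batch size.
import torch
from accelerate import Accelerator

accelerator = Accelerator()

local_batch_size = 32  # whatever the local batch-size search settled on
max_rnk_bs = torch.tensor([local_batch_size], device=accelerator.device)

# gather one value per rank, then take the minimum so that no rank OOMs
gathered = accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
batch_size = min(gathered)
```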
@@ -1044,7 +1038,7 @@ class HFLM(TemplateLM):
             else None
         )

-        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn, accelerator=self.accelerator)
+        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
         pbar = tqdm(
             total=len(requests),
             disable=(disable_tqdm or (self.rank != 0)),
@@ -1064,8 +1058,6 @@ class HFLM(TemplateLM):
             # tensors, then we pack them together into a batch, call the model, and then pick it all apart
             # again because vectorizing is annoying
-            from pprint import pprint
             for _, context_enc, continuation_enc in chunk:
                 # sanity check
                 assert len(context_enc) > 0
@@ -1210,8 +1202,6 @@ class HFLM(TemplateLM):
     ) -> List[str]:
         res = []

-        self.accelerator.wait_for_everyone()
         def _collate(req: Tuple[str, dict]):
             """Defines the key for the sorted method"""
             # the negative sign on len(toks) sorts descending - this has a few advantages:
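The `_collate` helper referenced here sorts requests longest-first: any OOM shows up on the very first batch, and grouping similar lengths reduces padding. A hedged sketch of such a key function (the tokenizer is passed in explicitly here, whereas the harness uses its own `tok_encode`):

```python
# Sketch of a longest-first sort key for generation requests.
from typing import Callable, List, Tuple


def collate_key(req: Tuple[str, dict], tok_encode: Callable[[str], List[int]]) -> Tuple[int, str]:
    """Key for sorted(): the negative length sorts descending."""
    context, _gen_kwargs = req
    toks = tok_encode(context)
    # longest requests first: memory problems surface immediately, and the
    # first element of each batch bounds the padded length of that batch
    return -len(toks), context


# usage (illustrative): sorted(requests, key=lambda r: collate_key(r, tokenizer.encode))
```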
@@ -1235,7 +1225,7 @@ class HFLM(TemplateLM):
         )
         batch_fn = (
             self._batch_scheduler
-            if self.batch_size == "auto"  # and not adaptive_batch_size
+            if self.batch_size == "auto"
             else None
         )
@@ -1277,9 +1267,10 @@ class HFLM(TemplateLM):
                 until = [eos]
             else:
                 until.append(eos)
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
-                if max_gen_toks > self.max_length:
+                if max_gen_toks > self.max_length:  # some models have a low max length limit
                     max_gen_toks = self.max_gen_toks
             else:
                 max_gen_toks = self.max_gen_toks
@@ -1287,6 +1278,8 @@ class HFLM(TemplateLM):
             # set the max length in tokens of inputs ("context_enc")
             if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                 # max len for inputs = max length, minus room to generate the max new tokens
+                # if the max new tokens is too large, halve it until it fits, as we cannot change
+                # the max model length
                 max_ctx_len = self.max_length - max_gen_toks
                 while max_ctx_len <= 0:
                     max_gen_toks = max_gen_toks // 2
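The two comments added here document a small loop worth spelling out: the model's context window is fixed, so if the requested number of new tokens leaves no room for the prompt, the generation budget is halved until some prompt space remains. A worked sketch with made-up numbers:

```python
# Sketch of the halving logic described by the new comments; the numbers are
# illustrative, not taken from any real model configuration.
max_length = 2048    # fixed model context window
max_gen_toks = 4096  # requested new tokens exceed the whole window

# max len for inputs = max length, minus room to generate the max new tokens;
# if nothing is left for the prompt, halve the generation budget until it fits
max_ctx_len = max_length - max_gen_toks
while max_ctx_len <= 0:
    max_gen_toks = max_gen_toks // 2
    max_ctx_len = max_length - max_gen_toks

print(max_gen_toks, max_ctx_len)  # -> 1024 1024
```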
......
@@ -389,7 +389,7 @@ class Collator:
                 self._arr_with_indices, fn=self._group_fn, group_by="contexts"
             )

-    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None, reset_batch_fn: Optional[Callable] = None, accelerator=None) -> Iterator:
+    def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None, reset_batch_fn: Optional[Callable] = None) -> Iterator:
         """
         Generates and yields batches from the reordered array. The method of grouping and batching
         depends on the parameter `group_by`.
......
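With the `accelerator` argument dropped, `get_batched` keeps a batch size `n` plus an optional `batch_fn` callback that can pick the size per batch (this is what `_batch_scheduler` plugs into above). A small standalone sketch of that callback-driven batching idea, not the actual `Collator` implementation:

```python
# Sketch: yield batches whose size is chosen per batch by an optional callback,
# mirroring the get_batched(n=..., batch_fn=...) idea. Not the harness Collator.
from typing import Callable, Iterator, List, Optional, Sequence


def get_batched(
    items: Sequence, n: int = 1, batch_fn: Optional[Callable[[int], int]] = None
) -> Iterator[List]:
    pos = 0
    while pos < len(items):
        # the callback may look at the current position (e.g. how long the
        # remaining requests are) and return a batch size; else use the fixed n
        size = batch_fn(pos) if batch_fn is not None else n
        yield list(items[pos : pos + size])
        pos += size


# usage: fixed batches of 4 over ten dummy items
for batch in get_batched(list(range(10)), n=4):
    print(batch)
```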