Commit ff1b649e authored by Nathan Habib

cleanup

parent 3c390c43
@@ -111,11 +111,13 @@ class HFLM(TemplateLM):
gpus = torch.cuda.device_count()
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
self.accelerator = accelerator
if accelerator.num_processes > 1:
self.accelerator = accelerator
if "npu" in accelerator.device.type:
gpus = torch.npu.device_count()
# using one process with no model parallelism
if not (parallelize or accelerator.num_processes > 1):
# use user-passed device
device_list = set(
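
For context, a self-contained sketch of the accelerator setup seen in this hunk, assuming the accelerate library and, for the NPU branch, a torch_npu install; the very long InitProcessGroupKwargs timeout keeps the process group from aborting during slow one-off work such as large checkpoint downloads.

from datetime import timedelta

import torch
from accelerate import Accelerator, InitProcessGroupKwargs

# Build an Accelerator whose process group uses a very long timeout, so slow
# one-off work on a single rank does not abort the whole distributed job.
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])

# Count visible devices; Ascend NPU backends expose their own counter.
if "npu" in accelerator.device.type:
    gpus = torch.npu.device_count()  # requires the torch_npu plugin
else:
    gpus = torch.cuda.device_count()

print(f"rank {accelerator.process_index}/{accelerator.num_processes} "
      f"on {accelerator.device}, {gpus} device(s) visible")
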
@@ -155,7 +157,7 @@ class HFLM(TemplateLM):
self._get_config(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
trust_remote_code=trust_remote_code,
)
# determine which of 'causal' and 'seq2seq' backends to use
@@ -513,13 +515,11 @@ class HFLM(TemplateLM):
revision: str = "main",
trust_remote_code: bool = False,
) -> None:
with self.accelerator.main_process_first():
self._config = transformers.AutoConfig.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
force_download=False,
)
self._config = transformers.AutoConfig.from_pretrained(
pretrained,
revision=revision,
trust_remote_code=trust_remote_code,
)
def _create_model(
self,
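
For reference, a minimal sketch of the main_process_first pattern that the config loading in this hunk previously wrapped around AutoConfig.from_pretrained (the same pattern appears around the model loading further down); "gpt2" is used purely as a placeholder checkpoint. The main process enters the block first and fills the local Hugging Face cache, and the remaining ranks run the same code afterwards and read from that cache.

import transformers
from accelerate import Accelerator

accelerator = Accelerator()

# Main process downloads and caches first; the other ranks wait at the
# barrier and then load the same files from the local cache.
with accelerator.main_process_first():
    config = transformers.AutoConfig.from_pretrained(
        "gpt2",  # placeholder model id, for illustration only
        revision="main",
        trust_remote_code=False,
    )
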
@@ -578,16 +578,13 @@ class HFLM(TemplateLM):
model_kwargs["bnb_4bit_compute_dtype"]
)
with self.accelerator.main_process_first():
#model_kwargs["device_map"] = "balanced_low_0"
self._model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained,
revision=revision,
torch_dtype=get_dtype(dtype),
trust_remote_code=trust_remote_code,
force_download=False,
**model_kwargs,
)
self._model = self.AUTO_MODEL_CLASS.from_pretrained(
pretrained,
revision=revision,
torch_dtype=get_dtype(dtype),
trust_remote_code=trust_remote_code,
**model_kwargs,
)
else:
try:
from auto_gptq import AutoGPTQForCausalLM
@@ -679,7 +676,6 @@ class HFLM(TemplateLM):
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast_tokenizer,
force_download=False
)
else:
assert isinstance(
@@ -709,7 +705,7 @@ class HFLM(TemplateLM):
)
max_context_enc = len(context_enc[-(self.max_length + 1) :])
max_cont_enc = len(continuation_enc[-(self.max_length + 1) :])
security_margin_factor = 6 # batch sizes for log prob evals sometimes generate OOMs
security_margin_factor = 4 # batch sizes for log prob evals sometimes generate OOMs
elif len(requests[0]) == 2: # generative evals
# using rolling window with maximum context
longest_context = max([len(self.tok_encode(request[0])) + request[1].get("max_gen_toks", self.max_length) for request in requests[pos:]])
@@ -721,7 +717,7 @@ class HFLM(TemplateLM):
max_length = longest_context
max_context_enc = max_length
max_cont_enc = max_length
security_margin_factor = 6
security_margin_factor = 4
# if OOM, then halves batch_size and tries again
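
The halving-on-OOM retry described by the comment above matches what accelerate's find_executable_batch_size decorator provides; a minimal sketch, in which the body of forward_batch is a trivial stand-in for the real forward pass:

import torch
from accelerate.utils import find_executable_batch_size

@find_executable_batch_size(starting_batch_size=512)
def forward_batch(batch_size):
    # Stand-in forward pass sized by batch_size; on a CUDA OOM the decorator
    # catches the error, halves batch_size and calls the function again.
    _ = torch.zeros((batch_size, 2048), dtype=torch.long)
    return batch_size

batch_size = forward_batch()  # largest batch size that ran without OOM
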
@@ -751,7 +747,6 @@ class HFLM(TemplateLM):
return batch_size
try:
print(f"finding batch size on process {self.accelerator.local_process_index}")
batch_size = forward_batch()
except RuntimeError as e:
if "No executable batch size found" in str(e):
@@ -762,7 +757,6 @@ class HFLM(TemplateLM):
if self.world_size > 1:
# if multi-GPU, always take minimum over all selected batch sizes
max_rnk_bs = torch.tensor([batch_size], device=self.device)
print(f"gathering on process {self.accelerator.local_process_index}")
gathered = (
self.accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
)
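
A sketch of the cross-rank reduction shown above, wrapped in a helper (sync_batch_size is an illustrative name, not part of the source): each rank contributes its locally selected batch size and every rank adopts the minimum, as the comment states, so all processes step through batches together.

import torch
from accelerate import Accelerator

def sync_batch_size(accelerator: Accelerator, batch_size: int) -> int:
    # One scalar per rank; gather to all ranks, then keep the global minimum.
    max_rnk_bs = torch.tensor([batch_size], device=accelerator.device)
    gathered = accelerator.gather(max_rnk_bs).cpu().detach().numpy().tolist()
    return min(gathered)
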
@@ -1044,7 +1038,7 @@ class HFLM(TemplateLM):
else None
)
chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn, accelerator=self.accelerator)
chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
pbar = tqdm(
total=len(requests),
disable=(disable_tqdm or (self.rank != 0)),
@@ -1064,8 +1058,6 @@ class HFLM(TemplateLM):
# tensors, then we pack them together into a batch, call the model, and then pick it all apart
# again because vectorizing is annoying
from pprint import pprint
for _, context_enc, continuation_enc in chunk:
# sanity check
assert len(context_enc) > 0
@@ -1210,8 +1202,6 @@ class HFLM(TemplateLM):
) -> List[str]:
res = []
self.accelerator.wait_for_everyone()
def _collate(req: Tuple[str, dict]):
"""Defines the key for the sorted method"""
# the negative sign on len(toks) sorts descending - this has a few advantages:
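
A small self-contained sketch of such a sort key; toy_tok_encode is a stand-in tokenizer used only for illustration. Returning -len(toks) sorts the longest requests first, which is the descending order the comment refers to.

from typing import List, Tuple

def toy_tok_encode(text: str) -> List[int]:
    # stand-in tokenizer: one fake token id per whitespace-separated word
    return list(range(len(text.split())))

def _collate(req: Tuple[str, dict]) -> Tuple[int, str]:
    """Key for sorted(): longest prompt first, ties broken by the prompt text."""
    toks = toy_tok_encode(req[0])
    return -len(toks), req[0]

requests = [("a short prompt", {}), ("a noticeably longer prompt than the first", {})]
requests.sort(key=_collate)  # the longer prompt now comes first
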
@@ -1235,7 +1225,7 @@ class HFLM(TemplateLM):
)
batch_fn = (
self._batch_scheduler
if self.batch_size == "auto" # and not adaptive_batch_size
if self.batch_size == "auto"
else None
)
@@ -1277,9 +1267,10 @@ class HFLM(TemplateLM):
until = [eos]
else:
until.append(eos)
if "max_gen_toks" in kwargs.keys():
max_gen_toks = kwargs.pop("max_gen_toks")
if max_gen_toks > self.max_length:
                    if max_gen_toks > self.max_length: # some models have a low max length limit
max_gen_toks = self.max_gen_toks
else:
max_gen_toks = self.max_gen_toks
@@ -1287,6 +1278,8 @@ class HFLM(TemplateLM):
# set the max length in tokens of inputs ("context_enc")
if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
# max len for inputs = max length, minus room to generate the max new tokens
# if the max new tokens is too large, halve it until it fits as we cannot change
# the max model length
max_ctx_len = self.max_length - max_gen_toks
while max_ctx_len <= 0:
max_gen_toks = max_gen_toks // 2
......
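
A self-contained sketch of the clamping described in the comment above (assuming the collapsed part of the loop recomputes max_ctx_len, as the comment implies): the generation budget is halved until at least one input token fits in the model's context window.

from typing import Tuple

def fit_generation_budget(max_length: int, max_gen_toks: int) -> Tuple[int, int]:
    # Reserve room in the context window for the tokens to be generated;
    # halve the generation budget until some input tokens still fit.
    max_ctx_len = max_length - max_gen_toks
    while max_ctx_len <= 0:
        max_gen_toks = max_gen_toks // 2
        max_ctx_len = max_length - max_gen_toks
    return max_ctx_len, max_gen_toks

# e.g. a 2048-token model asked for 4096 new tokens ends up with
# max_gen_toks = 1024 and max_ctx_len = 1024
print(fit_generation_budget(2048, 4096))
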
@@ -389,7 +389,7 @@ class Collator:
self._arr_with_indices, fn=self._group_fn, group_by="contexts"
)
def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None, reset_batch_fn: Optional[Callable] = None, accelerator=None) -> Iterator:
def get_batched(self, n: int = 1, batch_fn: Optional[Callable] = None, reset_batch_fn: Optional[Callable] = None) -> Iterator:
"""
Generates and yields batches from the reordered array. The method of grouping and batching
depends on the parameter `group_by`.
......
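
A simplified, self-contained sketch of the batching behaviour that get_batched exposes (not the Collator implementation itself, and with batch_fn assumed here to take only the current position): fixed-size chunks of n, unless a batch_fn is supplied, in which case the batch size is recomputed as the iteration advances.

from typing import Callable, Iterator, List, Optional, Sequence

def get_batched(items: Sequence, n: int = 1,
                batch_fn: Optional[Callable[[int], int]] = None) -> Iterator[List]:
    # Yield successive batches; batch_fn(pos) may override the fixed size n.
    pos = 0
    while pos < len(items):
        size = batch_fn(pos) if batch_fn is not None else n
        yield list(items[pos : pos + size])
        pos += size

print(list(get_batched(list(range(7)), n=3)))                           # [[0, 1, 2], [3, 4, 5], [6]]
print(list(get_batched(list(range(7)), batch_fn=lambda pos: pos + 1)))  # [[0], [1, 2], [3, 4, 5, 6]]
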