Commit bc103ce2 authored by lintangsutawika

fixes

parents 9f1cb1e7 71388a7e
@@ -104,6 +104,19 @@ python write_out.py \

This will write out one text file for each task.

## Multi-GPU Evaluation

Multi-GPU evaluation is supported through [accelerate](https://github.com/huggingface/accelerate). To initialize the distributed environment, run `accelerate config` in a terminal and follow the prompts. Once the environment is configured, evaluations can be launched with:

```bash
accelerate launch main.py \
    --model hf-causal \
    --tasks lambada_openai,arc_easy \
    --batch_size 16
```

**Warning**: Distributed evaluation requires launching multiple processes of the evaluation script. Running `python main.py *args*` instead of `accelerate launch main.py *args*` on a machine with multiple GPUs will only run the evaluation on a single device.

## Implementing new tasks

To implement a new task in the eval harness, see [this guide](./docs/task_guide.md).
```diff
@@ -72,17 +72,18 @@ class HFLM(LM):
         # multigpu support with accelerate
         if gpus > 1:
-            accelerator = Accelerator(device_placement=False)
+            # accelerator = Accelerator(device_placement=False)
+            accelerator = Accelerator()
             if gpus > accelerator.num_processes:
                 warning = (
-                    "WARNING: The number of total GPUs does not match the number of spawned processes. "
+                    "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                     "If you would like to use data parallelism, please launch the script "
                     "with 'accelerate launch *script*'. "
-                    "Current run will proceed with single device."
+                    f"Current run will proceed with {accelerator.num_processes} devices."
                 )
                 print(warning)
-                self._rank = 0
-                self._world_size = 1
+                self._rank = self.accelerator.local_process_index
+                self._world_size = self.accelerator.num_processes
             else:
                 self.gpt2 = accelerator.prepare(self.gpt2)
```
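For context on the accelerate API used above: `Accelerator()` picks up the distributed environment created by `accelerate launch`, exposes the process index and world size, and `prepare()` places the model on the current process's device (wrapping it for DDP when several processes are spawned). A minimal, self-contained sketch assuming the standard `accelerate`/`transformers` APIs; the model name is only an example:

```python
# Minimal sketch; run with `accelerate launch sketch.py` to spawn one process per GPU.
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

accelerator = Accelerator()               # reads the environment set up by `accelerate launch`
rank = accelerator.local_process_index    # index of this process on the current machine
world_size = accelerator.num_processes    # total number of spawned processes

model = AutoModelForCausalLM.from_pretrained("gpt2")  # example model
model = accelerator.prepare(model)        # moves the model to this process's device (wraps for DDP if distributed)

print(f"process {rank} of {world_size} on device {accelerator.device}")
```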
```diff
@@ -103,10 +104,18 @@ class HFLM(LM):
     @property
     def max_length(self):
         try:
-            return self.gpt2.config.n_ctx
+            if hasattr(self, "accelerator"):
+                return self.accelerator.unwrap_model(self.gpt2).config.n_ctx
+            else:
+                return self.gpt2.config.n_ctx
         except AttributeError:
             # gptneoconfig doesn't have n_ctx apparently
-            return self.gpt2.config.max_position_embeddings
+            if hasattr(self, "accelerator"):
+                return self.accelerator.unwrap_model(
+                    self.gpt2
+                ).config.max_position_embeddings
+            else:
+                return self.gpt2.config.max_position_embeddings
 
     @property
     def max_gen_toks(self):
```
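The `unwrap_model` calls above are the usual way to reach attributes of a model after `prepare()`: the prepared object may be a `DistributedDataParallel` wrapper that does not expose `config` directly. A minimal sketch of that pattern, assuming the standard accelerate API; the model name is illustrative:

```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

accelerator = Accelerator()
model = accelerator.prepare(AutoModelForCausalLM.from_pretrained("gpt2"))  # example model

# unwrap_model returns the underlying module whether or not prepare() wrapped it,
# so config attributes stay reachable in both single- and multi-process runs.
max_len = accelerator.unwrap_model(model).config.max_position_embeddings
print(max_len)
```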
```diff
@@ -173,94 +182,49 @@ class HFLM(LM):
         # TODO: Implement caching once we've confirmed the perplexity implementation
         # TODO: automatic batch size detection for vectorization
-        extra_pad = []
-        numpad_batches = 0
-
-        if self.world_size > 1:
-            cumulative_batches = 0  # balance token batches among iterators
-            # compute cumulative batches seen per host
-            for (string,) in tqdm([req.args for req in requests], disable=True):
-                rolling_token_windows = list(
-                    map(
-                        utils.make_disjoint_window,
-                        utils.get_rolling_token_windows(
-                            token_list=self.tok_encode(string),
-                            prefix_token=self.eot_token_id,
-                            max_seq_len=self.max_length,
-                            context_len=1,
-                        ),
-                    )
-                )
-                rolling_token_windows = [(None,) + x for x in rolling_token_windows]
-                cumulative_batches += len(rolling_token_windows)
-
-            cumul_batches_ranks = torch.tensor(cumulative_batches, device=self.device)
-            gathered_item = (
-                self.accelerator.gather(cumul_batches_ranks)
-                .cpu()
-                .detach()
-                .numpy()
-                .tolist()
-            )
-            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
-            numpad_batches = max(gathered_item) - gathered_item[self.rank]
-            # pad iterators with a pseudodocument
-            extra_pad = (
-                [("pad",)] if max(gathered_item) - min(gathered_item) > 0 else []
-            )
-
-        loglikelihoods = []
-        for (string,) in tqdm(
-            extra_pad + [req.args for req in requests], disable=(self.rank != 0)
-        ):
-            if numpad_batches > 0:
-                rolling_token_windows = list(
-                    map(
-                        utils.make_disjoint_window,
-                        utils.get_rolling_token_windows(
-                            token_list=[self.eot_token_id]
-                            * self.max_length
-                            * numpad_batches,
-                            prefix_token=self.eot_token_id,
-                            max_seq_len=self.max_length,
-                            context_len=1,
-                        ),
-                    )
-                )
-            else:
-                rolling_token_windows = list(
-                    map(
-                        utils.make_disjoint_window,
-                        utils.get_rolling_token_windows(
-                            token_list=self.tok_encode(string),
-                            prefix_token=self.eot_token_id,
-                            max_seq_len=self.max_length,
-                            context_len=1,
-                        ),
-                    )
-                )
-
-            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
-
-            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
-            # that
-            string_nll = self._loglikelihood_tokens(
-                rolling_token_windows, disable_tqdm=True
-            )
-
-            if (numpad_batches > 0) or (string == "pad"):
-                numpad_batches = 0
-            else:
-                # discard is_greedy
-                string_nll = [x[0] for x in string_nll]
-
-                string_nll = sum(string_nll)
-                loglikelihoods.append(string_nll)
+        loglikelihoods = []
+        for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
+            rolling_token_windows = list(
+                map(
+                    utils.make_disjoint_window,
+                    utils.get_rolling_token_windows(
+                        token_list=self.tok_encode(string),
+                        prefix_token=self.eot_token_id,
+                        max_seq_len=self.max_length,
+                        context_len=1,
+                    ),
+                )
+            )
+
+            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
+
+            # TODO: extract out this call so it only gets called once and also somehow figure out partial caching for
+            # that
+
+            pad_amnt = 0
+            if self.world_size > 1:
+                # TODO: Comment on what we do here
+                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
+                gathered = (
+                    self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
+                )
+
+                pad_amnt = max(gathered) - gathered[self.rank]
+                if pad_amnt > 0:
+                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
+
+            string_nll = self._loglikelihood_tokens(
+                rolling_token_windows, disable_tqdm=True
+            )
+
+            if (self.world_size > 1) and (pad_amnt > 0):
+                string_nll = [x[0] for x in string_nll[:-pad_amnt]]
+            else:
+                # discard is_greedy
+                string_nll = [x[0] for x in string_nll]
+
+            string_nll = sum(string_nll)
+            loglikelihoods.append(string_nll)
 
         return loglikelihoods
```
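The world-size branch above follows a standard even-batching trick for data-parallel evaluation: gather every rank's window count, pad the shorter ranks with copies of an existing window so all ranks execute the same number of forward passes (DDP/FSDP collectives otherwise hang), then discard the scores that came from padding. A standalone sketch of that idea; the helper name and argument shapes are illustrative, not part of the harness:

```python
import torch
from accelerate import Accelerator

def pad_to_longest(windows, accelerator: Accelerator):
    """Illustrative helper: make every rank's window list the same length."""
    if accelerator.num_processes == 1:
        return windows, 0
    # Each rank reports how many rolling windows it has; gather yields one count per rank.
    mine = torch.tensor([len(windows)], device=accelerator.device)
    counts = accelerator.gather(mine).cpu().tolist()
    pad_amnt = max(counts) - counts[accelerator.process_index]
    if pad_amnt > 0:
        # Reuse the first window as filler so all ranks run the same number of batches.
        windows = windows + pad_amnt * [windows[0]]
    # The caller drops the last `pad_amnt` scores after the forward passes.
    return windows, pad_amnt
```

Padding with a copy of an existing window keeps tensor shapes valid; the corresponding scores are sliced off afterwards, which is what the `string_nll[:-pad_amnt]` branch above does.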
```diff
@@ -285,6 +249,7 @@ class HFLM(LM):
             tqdm(re_ord.get_reordered(), disable=(disable_tqdm or (self.rank != 0))),
             self.batch_size,
         ):
             inps = []
             cont_toks_list = []
             inplens = []
```