Unverified Commit 0bfb0220 authored by Baber Abbasi, committed by GitHub

batch `loglikelihood_rolling` across requests (#2559)

* batch all rolling token windows

* nit

* copy to vllm

* fix max_length for `get_rolling_token_windows`

* bugfix

* bugfix

* add type hints
parent 976d8a0b
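For reference, a minimal sketch of the data this commit batches. `get_rolling_token_windows` splits one request's token list into (input, prediction) windows so that every token is predicted exactly once, and `make_disjoint_window` trims each input so it no longer overlaps the prediction span. Token ids, the prefix token, and the window sizes below are made up for illustration, and the import path assumes these helpers live in `lm_eval.utils` as the utils hunk suggests.

# Illustrative only: tiny made-up token ids and window sizes.
from lm_eval.utils import get_rolling_token_windows, make_disjoint_window

tokens = list(range(10))          # pretend token ids of one document
windows = list(
    map(
        make_disjoint_window,
        get_rolling_token_windows(
            tokens,
            prefix_token=-1,      # stand-in for the model's prefix/EOT token id
            max_seq_len=4,        # pretend the model context is 4 tokens
            context_len=1,
        ),
    )
)
# Each element is (context, continuation); every token is predicted once:
# ([-1], [0, 1, 2, 3])
# ([3], [4, 5, 6, 7])
# ([5, 6, 7], [8, 9])

Previously each request's windows were scored on their own; this commit flattens the windows of all requests into one list and scores them in fixed-size batches, as the diffs below show.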
@@ -905,8 +905,6 @@ class HFLM(TemplateLM):
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
     ) -> List[float]:
-        loglikelihoods = []
-
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
@@ -915,10 +913,17 @@ class HFLM(TemplateLM):
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
 
-        for (string,) in tqdm(
-            [req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
+        # First, collect all windows from all requests
+        all_windows = []  # List of (request_idx, window) tuples
+        request_window_counts = []  # Track number of windows per request
+
+        for req_idx, (string,) in enumerate(
+            tqdm(
+                [req.args for req in requests],
+                disable=(disable_tqdm or (self.rank != 0)),
+            )
         ):
-            rolling_token_windows = list(
+            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                 map(
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
@@ -931,37 +936,55 @@ class HFLM(TemplateLM):
                 )
             )
             # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
-            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
-
-            pad_amnt = 0
-            if self.world_size > 1:
-                # We pad out the external document-level iterator so the inner iterator doesn't hang
-                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
-                gathered = (
-                    self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
-                )
-                pad_amnt = max(gathered) - gathered[self.rank]
-                if pad_amnt > 0:
-                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
-
-            string_nll = self._loglikelihood_tokens(
-                requests=rolling_token_windows,
-                disable_tqdm=True,
-                override_bs=adaptive_batch_size,
-            )
-
-            if (self.world_size > 1) and (pad_amnt > 0):
-                string_nll = [x[0] for x in string_nll[:-pad_amnt]]
-            else:
-                # discard is_greedy
-                string_nll = [x[0] for x in string_nll]
-
-            string_nll = sum(string_nll)
-            loglikelihoods.append(string_nll)
-
-            # cache this loglikelihood_rolling request
-            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+            windows = [(None,) + x for x in rolling_token_windows]
+
+            # Store windows with their request index
+            all_windows.extend((req_idx, window) for window in windows)
+            request_window_counts.append(len(windows))
+
+        # Handle distributed case padding
+        pad_amnt = 0
+        if self.world_size > 1:
+            mytensor = torch.tensor(len(all_windows), device=self.device)
+            gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
+            pad_amnt = max(gathered) - gathered[self.rank]
+            if pad_amnt > 0:
+                all_windows += pad_amnt * [all_windows[0]]
+
+        all_nlls = []
+        batch_size = adaptive_batch_size or self.batch_size
+        for i in range(0, len(all_windows), batch_size):
+            batch = all_windows[i : i + batch_size]
+            # Extract just the windows for processing, keeping track of request indices
+            batch_indices, batch_windows = zip(*batch)
+
+            batch_nlls = self._loglikelihood_tokens(
+                requests=batch_windows,
+                disable_tqdm=False,
+                override_bs=len(batch_windows),
+            )
+            # Store results with their request indices
+            all_nlls.extend(zip(batch_indices, batch_nlls))
+
+        # Remove padding if necessary
+        if (self.world_size > 1) and (pad_amnt > 0):
+            all_nlls = all_nlls[:-pad_amnt]
+
+        # Reconstruct per-request loglikelihoods
+        loglikelihoods = []
+        current_idx = 0
+        for window_count in request_window_counts:
+            # Get all nlls for this request
+            request_nlls = all_nlls[current_idx : current_idx + window_count]
+            # Sum up the nlls for this request (discarding is_greedy)
+            request_total = sum(nll[0] for _, nll in request_nlls)
+            loglikelihoods.append(request_total)
+            current_idx += window_count
+
+            string = requests[len(loglikelihoods) - 1].args[0]
+            self.cache_hook.add_partial(
+                "loglikelihood_rolling", (string,), request_total
+            )
 
         return loglikelihoods
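The heart of the change, independent of any model backend, is this flatten, score, and regroup pattern. A minimal standalone sketch follows; `fake_loglikelihood_tokens` is a made-up stand-in for `self._loglikelihood_tokens`, which in the real code returns one `(loglikelihood, is_greedy)` pair per window.

# Standalone sketch of the batching pattern introduced above.
from typing import List, Tuple


def fake_loglikelihood_tokens(windows) -> List[Tuple[float, bool]]:
    # Hypothetical scorer: one (loglikelihood, is_greedy) pair per window.
    return [(-float(len(cont)), True) for _ctx, cont in windows]


requests_windows = [                   # windows already built per request
    [([0], [1, 2]), ([2], [3])],       # request 0 -> 2 windows
    [([0], [5, 6, 7])],                # request 1 -> 1 window
]

# 1) Flatten, remembering which request each window came from.
all_windows = [
    (req_idx, window)
    for req_idx, windows in enumerate(requests_windows)
    for window in windows
]
request_window_counts = [len(w) for w in requests_windows]

# 2) Score in fixed-size batches instead of one request at a time.
batch_size = 2
all_nlls = []
for i in range(0, len(all_windows), batch_size):
    batch = all_windows[i : i + batch_size]
    batch_indices, batch_windows = zip(*batch)
    all_nlls.extend(zip(batch_indices, fake_loglikelihood_tokens(batch_windows)))

# 3) Regroup: sum each request's window scores, discarding is_greedy.
loglikelihoods = []
current_idx = 0
for window_count in request_window_counts:
    request_nlls = all_nlls[current_idx : current_idx + window_count]
    loglikelihoods.append(sum(nll for _, (nll, _greedy) in request_nlls))
    current_idx += window_count

print(loglikelihoods)  # [-3.0, -3.0]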
@@ -102,7 +102,7 @@ class VLLM(TemplateLM):
         self.batch_size = (
             "auto"
             if isinstance(batch_size, str) and "auto" in batch_size
-            else batch_size
+            else int(batch_size)
         )
         if self.data_parallel_size <= 1:
             self.model = LLM(**self.model_args)
@@ -281,10 +281,21 @@ class VLLM(TemplateLM):
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
     ) -> List[float]:
-        loglikelihoods = []
+        adaptive_batch_size = None
+        if self.batch_size == "auto":
+            adaptive_batch_size = len(requests)
 
-        for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
-            rolling_token_windows = list(
+        # First, collect all windows from all requests
+        all_windows = []  # List of (request_idx, window) tuples
+        request_window_counts = []  # Track number of windows per request
+
+        for req_idx, (string,) in enumerate(
+            tqdm(
+                [req.args for req in requests],
+                disable=(disable_tqdm or (self.rank != 0)),
+            )
+        ):
+            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                 map(
                     make_disjoint_window,
                     get_rolling_token_windows(
@@ -297,20 +308,42 @@ class VLLM(TemplateLM):
                 )
             )
-            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
-
-            string_nll = self._loglikelihood_tokens(
-                rolling_token_windows,
-            )
-
-            # discard is_greedy
-            string_nll = [x[0] for x in string_nll]
-
-            string_nll = sum(string_nll)
-            loglikelihoods.append(string_nll)
-
-            # cache this loglikelihood_rolling request
-            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+            windows = [(None,) + x for x in rolling_token_windows]
+
+            # Store windows with their request index
+            all_windows.extend((req_idx, window) for window in windows)
+            request_window_counts.append(len(windows))
+
+        all_nlls = []
+        batch_size = adaptive_batch_size or int(self.batch_size)
+        for i in range(0, len(all_windows), batch_size):
+            batch = all_windows[i : i + batch_size]
+            # Extract just the windows for processing, keeping track of request indices
+            batch_indices, batch_windows = zip(*batch)
+
+            batch_nlls = self._loglikelihood_tokens(
+                requests=batch_windows,
+                disable_tqdm=False,
+            )
+            # Store results with their request indices
+            all_nlls.extend(zip(batch_indices, batch_nlls))
+
+        # Reconstruct per-request loglikelihoods
+        loglikelihoods = []
+        current_idx = 0
+        for window_count in request_window_counts:
+            # Get all nlls for this request
+            request_nlls = all_nlls[current_idx : current_idx + window_count]
+            # Sum up the nlls for this request (discarding is_greedy)
+            request_total = sum(nll[0] for _, nll in request_nlls)
+            loglikelihoods.append(request_total)
+            current_idx += window_count
+
+            string = requests[len(loglikelihoods) - 1].args[0]
+            self.cache_hook.add_partial(
+                "loglikelihood_rolling", (string,), request_total
+            )
 
         return loglikelihoods
@@ -10,7 +10,7 @@ import os
 import re
 from dataclasses import asdict, is_dataclass
 from itertools import islice
-from typing import Any, Callable, List
+from typing import Any, Callable, Generator, List, Tuple
 
 import numpy as np
 import yaml
@@ -201,7 +201,9 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
     return [f for f in filenames if "/samples_" in f and ".json" in f]
 
 
-def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
+def get_rolling_token_windows(
+    token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
+) -> Generator[Tuple[List[int], List[int]], None, None]:
     """
     - context_len allows for a rolling window context, allowing each prediction window to potentially
     condition on some context
@@ -228,7 +230,7 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
 
     # Special handling for first window: predict all tokens
     first_seq_len = min(max_seq_len, len(token_list))
-    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
+    yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len]
     predicted += first_seq_len
 
     while predicted < len(token_list):
@@ -242,7 +244,9 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
         predicted += window_pred_len
 
 
-def make_disjoint_window(pair):
+def make_disjoint_window(
+    pair: Tuple[List[int], List[int]],
+) -> Tuple[List[int], List[int]]:
     """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
     a, b = pair
     return a[: len(a) - (len(b) - 1)], b
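As a quick illustration of the newly annotated `make_disjoint_window` (token values made up, import path assumed to be `lm_eval.utils`): the input tokens overlap the prediction span except at its last position, and the context is trimmed down to the non-overlapping prefix.

from lm_eval.utils import make_disjoint_window

make_disjoint_window(([10, 11, 12, 13], [11, 12, 13, 14]))
# -> ([10], [11, 12, 13, 14])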