Unverified commit 0bfb0220 authored by Baber Abbasi, committed by GitHub

batch `loglikelihood_rolling` across requests (#2559)

* batch all rolling token windows

* nit

* copy to vllm

* fix max_length for `get_rolling_token_windows`

* bugfix

* bugfix

* add type hints
parent 976d8a0b
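
In short: instead of scoring one document's rolling windows per `_loglikelihood_tokens` call, all requests' windows are flattened into a single list, scored in fixed-size batches, and the per-window NLLs are regrouped and summed per request at the end. A minimal sketch of that reshaping, independent of the harness (names such as `docs_as_windows` and `score_batch` are illustrative, not harness API):

```python
from typing import Callable, List, Sequence, Tuple

Window = Tuple[List[int], List[int]]  # (context tokens, continuation tokens)


def batched_rolling_nll(
    docs_as_windows: List[List[Window]],
    score_batch: Callable[[Sequence[Window]], List[float]],
    batch_size: int,
) -> List[float]:
    # Flatten: tag every window with the index of the document it came from.
    all_windows = [
        (doc_idx, window)
        for doc_idx, windows in enumerate(docs_as_windows)
        for window in windows
    ]
    window_counts = [len(ws) for ws in docs_as_windows]

    # Score in fixed-size batches that may span document boundaries.
    all_nlls: List[float] = []
    for i in range(0, len(all_windows), batch_size):
        _, batch_windows = zip(*all_windows[i : i + batch_size])
        all_nlls.extend(score_batch(list(batch_windows)))

    # Regroup: a document's loglikelihood is the sum over its windows.
    totals, cursor = [], 0
    for count in window_counts:
        totals.append(sum(all_nlls[cursor : cursor + count]))
        cursor += count
    return totals
```

The win is largest for corpora of short documents: previously each one's handful of windows formed its own tiny batch, regardless of the configured batch size.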
@@ -905,8 +905,6 @@ class HFLM(TemplateLM):
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
     ) -> List[float]:
-        loglikelihoods = []
-
         adaptive_batch_size = None
         if self.batch_size == "auto":
             # using rolling window with maximum context
@@ -915,10 +913,17 @@
             print(f"Determined Largest batch size: {batch_size}")
             adaptive_batch_size = batch_size
 
-        for (string,) in tqdm(
-            [req.args for req in requests], disable=(disable_tqdm or (self.rank != 0))
+        # First, collect all windows from all requests
+        all_windows = []  # List of (request_idx, window) tuples
+        request_window_counts = []  # Track number of windows per request
+
+        for req_idx, (string,) in enumerate(
+            tqdm(
+                [req.args for req in requests],
+                disable=(disable_tqdm or (self.rank != 0)),
+            )
         ):
-            rolling_token_windows = list(
+            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                 map(
                     utils.make_disjoint_window,
                     utils.get_rolling_token_windows(
@@ -931,37 +936,55 @@
                 )
             )
 
             # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
-            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
-
-            pad_amnt = 0
-            if self.world_size > 1:
-                # We pad out the external document-level iterator so the inner iterator doesn't hang
-                mytensor = torch.tensor(len(rolling_token_windows), device=self.device)
-                gathered = (
-                    self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
-                )
-                pad_amnt = max(gathered) - gathered[self.rank]
-                if pad_amnt > 0:
-                    rolling_token_windows += pad_amnt * [rolling_token_windows[0]]
-
-            string_nll = self._loglikelihood_tokens(
-                requests=rolling_token_windows,
-                disable_tqdm=True,
-                override_bs=adaptive_batch_size,
-            )
-
-            if (self.world_size > 1) and (pad_amnt > 0):
-                string_nll = [x[0] for x in string_nll[:-pad_amnt]]
-            else:
-                # discard is_greedy
-                string_nll = [x[0] for x in string_nll]
-
-            string_nll = sum(string_nll)
-            loglikelihoods.append(string_nll)
-
-            # cache this loglikelihood_rolling request
-            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+            windows = [(None,) + x for x in rolling_token_windows]
+
+            # Store windows with their request index
+            all_windows.extend((req_idx, window) for window in windows)
+            request_window_counts.append(len(windows))
+
+        # Handle distributed case padding
+        pad_amnt = 0
+        if self.world_size > 1:
+            mytensor = torch.tensor(len(all_windows), device=self.device)
+            gathered = self.accelerator.gather(mytensor).cpu().detach().numpy().tolist()
+            pad_amnt = max(gathered) - gathered[self.rank]
+            if pad_amnt > 0:
+                all_windows += pad_amnt * [all_windows[0]]
+
+        all_nlls = []
+        batch_size = adaptive_batch_size or self.batch_size
+        for i in range(0, len(all_windows), batch_size):
+            batch = all_windows[i : i + batch_size]
+            # Extract just the windows for processing, keeping track of request indices
+            batch_indices, batch_windows = zip(*batch)
+
+            batch_nlls = self._loglikelihood_tokens(
+                requests=batch_windows,
+                disable_tqdm=False,
+                override_bs=len(batch_windows),
+            )
+            # Store results with their request indices
+            all_nlls.extend(zip(batch_indices, batch_nlls))
+
+        # Remove padding if necessary
+        if (self.world_size > 1) and (pad_amnt > 0):
+            all_nlls = all_nlls[:-pad_amnt]
+
+        # Reconstruct per-request loglikelihoods
+        loglikelihoods = []
+        current_idx = 0
+        for window_count in request_window_counts:
+            # Get all nlls for this request
+            request_nlls = all_nlls[current_idx : current_idx + window_count]
+            # Sum up the nlls for this request (discarding is_greedy)
+            request_total = sum(nll[0] for _, nll in request_nlls)
+            loglikelihoods.append(request_total)
+            current_idx += window_count
+
+            # cache this loglikelihood_rolling request
+            string = requests[len(loglikelihoods) - 1].args[0]
+            self.cache_hook.add_partial(
+                "loglikelihood_rolling", (string,), request_total
+            )
 
         return loglikelihoods
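One wrinkle in the HF path above: with `world_size > 1`, padding now happens once over the flattened window list rather than per document, so every rank runs the same number of forward batches and the `accelerator.gather` collective cannot hang. The pattern in isolation (plain Python, no distributed setup; `gathered_counts` stands in for the gathered per-rank window counts):

```python
from typing import List, Tuple


def pad_windows(
    windows: List[tuple], gathered_counts: List[int], rank: int
) -> Tuple[List[tuple], int]:
    """Pad this rank's window list up to the largest count across ranks,
    reusing the first window as a dummy; the returned pad amount lets the
    caller trim the dummy results after scoring."""
    pad_amnt = max(gathered_counts) - gathered_counts[rank]
    return windows + pad_amnt * [windows[0]], pad_amnt


# e.g. rank 1 holds 3 windows while rank 0 holds 5:
padded, pad_amnt = pad_windows([("ctx", "cont")] * 3, [5, 3], rank=1)
assert len(padded) == 5 and pad_amnt == 2
# after scoring: nlls = nlls[:-pad_amnt] if pad_amnt > 0 else nlls
```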
@@ -102,7 +102,7 @@ class VLLM(TemplateLM):
         self.batch_size = (
             "auto"
             if isinstance(batch_size, str) and "auto" in batch_size
-            else batch_size
+            else int(batch_size)
         )
         if self.data_parallel_size <= 1:
             self.model = LLM(**self.model_args)
@@ -281,10 +281,21 @@
     def loglikelihood_rolling(
         self, requests: List[Instance], disable_tqdm: bool = False
     ) -> List[float]:
-        loglikelihoods = []
-
-        for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
-            rolling_token_windows = list(
+        adaptive_batch_size = None
+        if self.batch_size == "auto":
+            adaptive_batch_size = len(requests)
+
+        # First, collect all windows from all requests
+        all_windows = []  # List of (request_idx, window) tuples
+        request_window_counts = []  # Track number of windows per request
+
+        for req_idx, (string,) in enumerate(
+            tqdm(
+                [req.args for req in requests],
+                disable=(disable_tqdm or (self.rank != 0)),
+            )
+        ):
+            rolling_token_windows: List[Tuple[List[int], List[int]]] = list(
                 map(
                     make_disjoint_window,
                     get_rolling_token_windows(
@@ -297,20 +308,42 @@
                 )
             )
 
-            rolling_token_windows = [(None,) + x for x in rolling_token_windows]
+            # TODO: Right now, we pass single EOT token to the Encoder and the full context to the decoder, in seq2seq case
+            windows = [(None,) + x for x in rolling_token_windows]
 
-            string_nll = self._loglikelihood_tokens(
-                rolling_token_windows,
-            )
-
-            # discard is_greedy
-            string_nll = [x[0] for x in string_nll]
-
-            string_nll = sum(string_nll)
-            loglikelihoods.append(string_nll)
-
-            # cache this loglikelihood_rolling request
-            self.cache_hook.add_partial("loglikelihood_rolling", (string,), string_nll)
+            # Store windows with their request index
+            all_windows.extend((req_idx, window) for window in windows)
+            request_window_counts.append(len(windows))
+
+        all_nlls = []
+        batch_size = adaptive_batch_size or int(self.batch_size)
+        for i in range(0, len(all_windows), batch_size):
+            batch = all_windows[i : i + batch_size]
+            # Extract just the windows for processing, keeping track of request indices
+            batch_indices, batch_windows = zip(*batch)
+
+            batch_nlls = self._loglikelihood_tokens(
+                requests=batch_windows,
+                disable_tqdm=False,
+            )
+            # Store results with their request indices
+            all_nlls.extend(zip(batch_indices, batch_nlls))
+
+        # Reconstruct per-request loglikelihoods
+        loglikelihoods = []
+        current_idx = 0
+        for window_count in request_window_counts:
+            # Get all nlls for this request
+            request_nlls = all_nlls[current_idx : current_idx + window_count]
+            # Sum up the nlls for this request (discarding is_greedy)
+            request_total = sum(nll[0] for _, nll in request_nlls)
+            loglikelihoods.append(request_total)
+            current_idx += window_count
+
+            # cache this loglikelihood_rolling request
+            string = requests[len(loglikelihoods) - 1].args[0]
+            self.cache_hook.add_partial(
+                "loglikelihood_rolling", (string,), request_total
+            )
 
         return loglikelihoods
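The vLLM changes mirror the HF ones, minus the distributed padding (the engine schedules its own work). Note the constructor now coerces a fixed `batch_size` to `int` up front, since the window loop uses it as a stride, while `"auto"` resolves to the number of requests so windows go out in large calls and vLLM's continuous batching handles the rest. Roughly (illustrative helper, not harness API):

```python
def resolve_batch_size(batch_size, n_requests: int) -> int:
    # "auto": no memory probing here; hand vLLM large batches and let its
    # scheduler micro-batch internally. Fixed sizes must be ints so they
    # can stride the flattened window list.
    if isinstance(batch_size, str) and "auto" in batch_size:
        return n_requests
    return int(batch_size)
```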
@@ -10,7 +10,7 @@ import os
 import re
 from dataclasses import asdict, is_dataclass
 from itertools import islice
-from typing import Any, Callable, List
+from typing import Any, Callable, Generator, List, Tuple
 
 import numpy as np
 import yaml
@@ -201,7 +201,9 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
     return [f for f in filenames if "/samples_" in f and ".json" in f]
 
 
-def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
+def get_rolling_token_windows(
+    token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
+) -> Generator[Tuple[List[int], List[int]], None, None]:
     """
     - context_len allows for a rolling window context, allowing each prediction window to potentially
       condition on some context
@@ -228,7 +230,7 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
 
     # Special handling for first window: predict all tokens
     first_seq_len = min(max_seq_len, len(token_list))
-    yield ([prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len])
+    yield [prefix_token] + token_list[: first_seq_len - 1], token_list[:first_seq_len]
     predicted += first_seq_len
 
     while predicted < len(token_list):
@@ -242,7 +244,9 @@ def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len
         predicted += window_pred_len
 
 
-def make_disjoint_window(pair):
+def make_disjoint_window(
+    pair: Tuple[List[int], List[int]],
+) -> Tuple[List[int], List[int]]:
     """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
     a, b = pair
     return a[: len(a) - (len(b) - 1)], b
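To make the window arithmetic concrete, a small worked example with the two helpers as annotated above (`-1` stands in for the prefix/EOT token id; the values follow from the definitions):

```python
from lm_eval.utils import get_rolling_token_windows, make_disjoint_window

tokens = list(range(10))  # a 10-token "document": [0, 1, ..., 9]
windows = list(
    map(
        make_disjoint_window,
        get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-1,
            max_seq_len=4,
            context_len=1,
        ),
    )
)
# Each pair is (context, continuation); every token is predicted exactly once:
#   ([-1],      [0, 1, 2, 3])  <- first window predicts all of its tokens
#   ([3],       [4, 5, 6, 7])
#   ([5, 6, 7], [8, 9])        <- final window keeps extra left context
assert [t for _, cont in windows for t in cont] == tokens
```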