".github/vscode:/vscode.git/clone" did not exist on "6aebf44f47bc73ac34344fb7b5d941790c11d39d"
Unverified Commit 7fa54a1a authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Make `req_pool_indices` on CPU (#960)
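The change replaces the CUDA tensor of request-pool indices with plain Python ints that stay on the CPU: `ReqToTokenPool.alloc` now returns a list, and each `Req` remembers its slot in `req.req_pool_idx`, so the scheduler no longer round-trips through `.cpu().numpy()` on every batch. A minimal sketch of the before/after (the surrounding `ScheduleBatch` state is assumed, not shown):

    # Before (removed): the pool returned a CUDA int32 tensor that had to be
    # synchronized back to the host whenever plain indices were needed.
    #   req_pool_indices = self.req_to_token_pool.alloc(bs)      # CUDA tensor
    #   req_pool_indices_cpu = req_pool_indices.cpu().numpy()    # GPU -> CPU sync

    # After (added): the pool hands out a Python list of ints and each request
    # records its own slot, so downstream code reads req.req_pool_idx directly.
    req_pool_indices_cpu = self.alloc_req_slots(bs)              # List[int]
    for i, req in enumerate(reqs):
        req.req_pool_idx = req_pool_indices_cpu[i]               # plain int on CPU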

parent 05abd126
......@@ -19,7 +19,6 @@ class GlobalConfig:
self.init_new_token_ratio = 0.7
self.base_min_new_token_ratio = 0.1
self.new_token_ratio_decay = 0.001
self.new_token_ratio_recovery = 0.05
# Runtime constants: The threshold (number of tokens) to trigger layer-wise cuda sync.
# This can improve the speed for large batch sizes during prefill.
......
......@@ -100,6 +100,9 @@ class Req:
self.output_ids = [] # Each decode stage's output ids
self.input_ids = None # input_ids = origin_input_ids + output_ids
# Memory info
self.req_pool_idx = None
# For incremental decoding
# ----- | --------- read_ids -------|
# ----- | surr_ids |
......@@ -321,6 +324,9 @@ class ScheduleBatch:
return_logprob=return_logprob,
)
def batch_size(self):
return len(self.reqs) if self.reqs is not None else 0
def is_empty(self):
return len(self.reqs) == 0
......@@ -328,118 +334,127 @@ class ScheduleBatch:
# Return whether batch has at least 1 streaming request
return any(r.stream for r in self.reqs)
def alloc_req_slots(self, num_reqs):
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
"Out of memory. "
"Please set a smaller number for `--max-running-requests`."
)
return req_pool_indices
def alloc_token_slots(self, num_tokens: int):
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
if out_cache_loc is None:
if self.tree_cache is not None:
self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
if out_cache_loc is None:
logger.error("Prefill out of memory. Try to lower your batch size.")
if self.tree_cache is not None:
self.tree_cache.pretty_print()
exit(1)
return out_cache_loc
def batch_sampling_params(self, vocab_size, int_token_logit_bias):
device = "cuda"
bs, reqs = self.batch_size(), self.reqs
self.temperatures = torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
device=device,
).view(-1, 1)
self.top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
)
self.top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
)
self.frequency_penalties = torch.tensor(
[r.sampling_params.frequency_penalty for r in reqs],
dtype=torch.float,
device=device,
)
self.presence_penalties = torch.tensor(
[r.sampling_params.presence_penalty for r in reqs],
dtype=torch.float,
device=device,
)
# Handle logit bias but only allocate when needed
self.logit_bias = None
for i in range(bs):
if reqs[i].sampling_params.dtype == "int":
if self.logit_bias is None:
self.logit_bias = torch.zeros(
(bs, vocab_size), dtype=torch.float32, device=device
)
self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
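Taken together, the three new helpers pull slot allocation and sampling-parameter setup out of the extend/decode preparation paths. A hedged usage sketch (the `batch`, `extend_num_tokens`, `vocab_size`, and `int_token_logit_bias` values are assumed to exist as elsewhere in this file):

    # One request slot per incoming request; raises RuntimeError when the
    # request-to-token pool is exhausted.
    req_slots = batch.alloc_req_slots(len(batch.reqs))          # List[int]

    # KV-cache token slots; on failure the helper first asks the tree cache
    # to evict before logging an error and exiting.
    out_cache_loc = batch.alloc_token_slots(extend_num_tokens)

    # Build the per-request sampling tensors (temperature, top_p, top_k,
    # penalties, optional logit bias) on the GPU in one place.
    batch.batch_sampling_params(vocab_size, int_token_logit_bias)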
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
device = "cuda"
bs = len(self.reqs)
bs = self.batch_size()
reqs = self.reqs
input_ids = [r.input_ids[len(r.prefix_indices) :] for r in reqs]
prefix_indices = [r.prefix_indices for r in reqs]
# Handle prefix
flatten_input_ids = []
extend_lens = []
prefix_lens = []
seq_lens = []
req_pool_indices = self.req_to_token_pool.alloc(bs)
req_pool_indices_cpu = self.alloc_req_slots(bs)
if req_pool_indices is None:
raise RuntimeError(
"Out of memory. "
"Please set a smaller number for `--max-running-requests`."
)
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
for i in range(bs):
flatten_input_ids.extend(input_ids[i])
for i, req in enumerate(reqs):
req.req_pool_idx = req_pool_indices_cpu[i]
extend_lens.append(len(input_ids[i]))
if len(prefix_indices[i]) == 0:
prefix_lens.append(0)
else:
prefix_lens.append(len(prefix_indices[i]))
self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
self.req_to_token_pool.req_to_token[req.req_pool_idx][
: len(prefix_indices[i])
] = prefix_indices[i]
seq_lens.append(prefix_lens[-1] + extend_lens[-1])
position_ids_offsets = torch.zeros((bs,), dtype=torch.int32, device=device)
# Allocate memory
seq_lens, prefix_lens = np.array(seq_lens), np.array(prefix_lens)
extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
if out_cache_loc is None:
if self.tree_cache is not None:
self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
if out_cache_loc is None:
logger.error("Prefill out of memory. Try to lower your batch size.")
if self.tree_cache is not None:
self.tree_cache.pretty_print()
exit(1)
out_cache_loc = self.alloc_token_slots(extend_num_tokens)
pt = 0
for i in range(bs):
self.req_to_token_pool.req_to_token[req_pool_indices_cpu[i]][
for i, req in enumerate(reqs):
self.req_to_token_pool.req_to_token[req.req_pool_idx][
prefix_lens[i] : prefix_lens[i] + extend_lens[i]
] = out_cache_loc[pt : pt + extend_lens[i]]
pt += extend_lens[i]
# Handle logit bias but only allocate when needed
logit_bias = None
for i in range(bs):
if reqs[i].sampling_params.dtype == "int":
if logit_bias is None:
logit_bias = torch.zeros(
(bs, vocab_size), dtype=torch.float32, device=device
)
logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias
# Set fields
self.input_ids = torch.tensor(
flatten_input_ids, dtype=torch.int32, device=device
)
with torch.device("cuda"):
self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32)
self.req_pool_indices = torch.tensor(req_pool_indices_cpu)
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32)
self.position_ids_offsets = torch.zeros((bs,), dtype=torch.int32)
self.pixel_values = [r.pixel_values for r in reqs]
self.image_sizes = [r.image_size for r in reqs]
self.image_offsets = [
r.image_offset - p_len for r, p_len in zip(reqs, prefix_lens)
]
self.req_pool_indices = req_pool_indices
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32, device=device)
self.prefix_lens = torch.tensor(prefix_lens, dtype=torch.int32, device=device)
self.position_ids_offsets = position_ids_offsets
self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.temperatures = torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
device=device,
).view(-1, 1)
self.top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
)
self.top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
)
self.frequency_penalties = torch.tensor(
[r.sampling_params.frequency_penalty for r in reqs],
dtype=torch.float,
device=device,
)
self.presence_penalties = torch.tensor(
[r.sampling_params.presence_penalty for r in reqs],
dtype=torch.float,
device=device,
)
self.logit_bias = logit_bias
self.batch_sampling_params(vocab_size, int_token_logit_bias)
def check_decode_mem(self):
bs = len(self.reqs)
bs = self.batch_size()
if self.token_to_kv_pool.available_size() >= bs:
return True
......@@ -464,7 +479,6 @@ class ScheduleBatch:
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
while (
self.token_to_kv_pool.available_size()
< len(sorted_indices) * global_config.retract_decode_steps
......@@ -482,20 +496,20 @@ class ScheduleBatch:
if isinstance(self.tree_cache, ChunkCache):
# ChunkCache does not have eviction
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][: seq_lens_cpu[idx]]
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
: seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
self.req_to_token_pool.free(req.req_pool_idx)
del self.tree_cache.entries[req.rid]
else:
# TODO: apply more fine-grained retraction
last_uncached_pos = len(req.prefix_indices)
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_cpu[idx]
][last_uncached_pos : seq_lens_cpu[idx]]
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx][
last_uncached_pos : seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
self.req_to_token_pool.free(req.req_pool_idx)
# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
......@@ -533,8 +547,6 @@ class ScheduleBatch:
jump_forward_reqs = []
filter_indices = [i for i in range(len(self.reqs))]
req_pool_indices_cpu = None
for i, req in enumerate(self.reqs):
if req.jump_forward_map is not None:
jump_forward_bytes = req.jump_forward_map.jump_forward_byte(
......@@ -584,13 +596,11 @@ class ScheduleBatch:
req.vid += 1
# insert the old request into tree_cache
if req_pool_indices_cpu is None:
req_pool_indices_cpu = self.req_pool_indices.tolist()
self.tree_cache.cache_req(
rid=req.rid,
token_ids=cur_all_ids,
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
req_pool_idx=req.req_pool_idx,
)
# unlock the last node
......@@ -626,14 +636,8 @@ class ScheduleBatch:
self.prefix_lens = None
# Alloc mem
bs = len(self.reqs)
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
if self.out_cache_loc is None:
logger.error("Decode out of memory. Try to lower your batch size.")
if self.tree_cache is not None:
self.tree_cache.pretty_print()
exit(1)
bs = self.batch_size()
self.out_cache_loc = self.alloc_token_slots(bs)
self.req_to_token_pool.req_to_token[
self.req_pool_indices, self.seq_lens - 1
......
......@@ -200,7 +200,6 @@ class ModelTpServer:
)
self.new_token_ratio = self.min_new_token_ratio
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
def exposed_step(self, recv_reqs):
try:
......@@ -625,13 +624,12 @@ class ModelTpServer:
req.output_top_logprobs.append(output.output_top_logprobs[i])
def cache_filled_batch(self, batch: ScheduleBatch):
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
for i, req in enumerate(batch.reqs):
new_prefix_indices, new_last_node = self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.input_ids),
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
req_pool_idx=req.req_pool_idx,
del_in_memory_pool=False,
old_last_node=req.last_node,
)
......@@ -639,7 +637,7 @@ class ModelTpServer:
if req is self.current_inflight_req:
# inflight request would get a new req idx
self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
self.req_to_token_pool.free(req.req_pool_idx)
def forward_decode_batch(self, batch: ScheduleBatch):
# Check if decode out of memory
......@@ -782,14 +780,13 @@ class ModelTpServer:
# Remove finished reqs
if finished_indices:
# Update radix cache
req_pool_indices_cpu = batch.req_pool_indices.tolist()
for i in finished_indices:
req = batch.reqs[i]
self.tree_cache.cache_req(
rid=req.rid,
token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
last_uncached_pos=len(req.prefix_indices),
req_pool_idx=req_pool_indices_cpu[i],
req_pool_idx=req.req_pool_idx,
)
self.tree_cache.dec_lock_ref(req.last_node)
......
......@@ -16,6 +16,7 @@ limitations under the License.
"""Memory pool."""
import logging
from typing import List
import torch
......@@ -27,34 +28,29 @@ class ReqToTokenPool:
def __init__(self, size: int, max_context_len: int):
self.size = size
self.mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")
self.free_slots = list(range(size))
self.req_to_token = torch.empty(
(size, max_context_len), dtype=torch.int32, device="cuda"
)
self.can_use_mem_size = size
def alloc(self, need_size: int):
if need_size > self.can_use_mem_size:
def alloc(self, need_size: int) -> List[int]:
if need_size > len(self.free_slots):
return None
select_index = (
torch.nonzero(self.mem_state).squeeze(1)[:need_size].to(torch.int32)
)
self.mem_state[select_index] = False
self.can_use_mem_size -= need_size
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
return select_index
def free(self, free_index):
self.mem_state[free_index] = True
if isinstance(free_index, (int,)):
self.can_use_mem_size += 1
self.free_slots.append(free_index)
else:
self.can_use_mem_size += free_index.shape[0]
self.free_slots.extend(free_index)
def clear(self):
self.mem_state.fill_(True)
self.can_use_mem_size = len(self.mem_state)
self.free_slots = list(range(self.size))
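`ReqToTokenPool` now tracks availability with a CPU-side free list instead of a CUDA bool mask, so `alloc` returns plain Python ints without a `torch.nonzero` call or device transfer. A standalone sketch of the same free-list pattern (an illustration of the idea, not the project's class):

    from typing import List, Optional

    class FreeSlotPool:
        # Minimal CPU-only slot allocator mirroring the free-list approach above.
        def __init__(self, size: int):
            self.size = size
            self.free_slots: List[int] = list(range(size))

        def alloc(self, need_size: int) -> Optional[List[int]]:
            if need_size > len(self.free_slots):
                return None  # caller decides how to handle the out-of-slots case
            selected = self.free_slots[:need_size]
            self.free_slots = self.free_slots[need_size:]
            return selected

        def free(self, index) -> None:
            if isinstance(index, int):
                self.free_slots.append(index)
            else:
                self.free_slots.extend(index)  # a list of slot indices

        def clear(self) -> None:
            self.free_slots = list(range(self.size))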
class BaseTokenToKVPool:
......