"tests/python/common/test_random.py" did not exist on "00ba409440d098cce8ad1dd32de7cde38b4a47fe"
Unverified commit 40ab1f01, authored by Liangsheng Yin, committed by GitHub

Fix a possible decode out-of-memory bug (#36)

parent 199e82a1
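
The change this commit introduces: before each decode step, the router checks whether the KV cache has at least one free slot per running request. If not, it retracts the requests with the least decoding progress back to the waiting queue and raises the new-token estimation ratio so that prefill scheduling becomes more conservative; when memory is fine, the ratio slowly decays back toward a minimum. A minimal sketch of that flow (simplified names only; the actual logic lives in check_decode_mem, retract_decode, and forward_decode_batch in the diff below):

# Sketch only: condenses the decode-OOM handling added by this commit.
def decode_step(server, batch):
    if not batch.check_decode_mem():
        # Out of KV-cache slots: schedule more conservatively and shrink the batch.
        server.new_token_ratio = min(
            server.new_token_ratio + server.new_token_ratio_step[1], 1.0
        )
        retracted_reqs = batch.retract_decode()      # frees their KV cache, resets their state
        server.forward_queue.extend(retracted_reqs)  # they will be prefilled again later
    else:
        # Memory is fine: relax the ratio back toward its minimum.
        server.new_token_ratio = max(
            server.new_token_ratio - server.new_token_ratio_step[0],
            server.min_new_token_ratio,
        )
    batch.prepare_for_decode()
    # ... forward pass and sampling follow as before
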
......@@ -40,7 +40,7 @@ def extract_prefix_by_tracing(program, backend):
try:
with TracingScope(tracer):
tracer.ret_value = program.func(tracer, **arguments)
except (StopTracing, TypeError):
except (StopTracing, TypeError, AttributeError):
# Some exceptions may not be caught
pass
......
from dataclasses import dataclass
from enum import Enum, auto
from typing import List
......@@ -38,6 +39,7 @@ class Req:
self.adjust_input_len = 0
self.prefix_indices = []
self.last_node = None
self.normalized_logprob = None
......@@ -81,27 +83,56 @@ class Req:
return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
@dataclass
class Batch:
def __init__(
self,
reqs: List[Req],
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: TokenToKVPool,
tree_cache: RadixCache,
):
self.reqs = reqs
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.tree_cache = tree_cache
self.return_normalized_logprob = any(
req.return_normalized_logprob for req in reqs
reqs: List[Req]
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: TokenToKVPool
tree_cache: RadixCache
# batched arguments to model runner
input_ids: torch.Tensor = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
prefix_lens: torch.Tensor = None
position_ids_offsets: torch.Tensor = None
out_cache_loc: torch.Tensor = None
out_cache_cont_start: torch.Tensor = None
out_cache_cont_end: torch.Tensor = None
return_normalized_logprob: bool = False
# for multimodal
pixel_values: List[torch.Tensor] = None
image_offsets: List[int] = None
# other arguments for control
output_ids: torch.Tensor = None
extend_num_tokens: int = None
# batched sampling params
temperatures: torch.Tensor = None
top_ps: torch.Tensor = None
top_ks: torch.Tensor = None
frequency_penalties: torch.Tensor = None
presence_penalties: torch.Tensor = None
logit_bias: torch.Tensor = None
@classmethod
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
return_normalized_logprob = any(req.return_normalized_logprob for req in reqs)
return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache,
return_normalized_logprob=return_normalized_logprob,
)
def is_empty(self):
return len(self.reqs) == 0
def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
device = "cuda"
bs = len(self.reqs)
reqs = self.reqs
......@@ -141,7 +172,7 @@ class Batch:
out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
if out_cache_loc is None:
print("Prefill out of memory.")
print("Prefill out of memory. This should nerver happen.")
self.tree_cache.pretty_print()
exit()
......@@ -196,7 +227,50 @@ class Batch:
)
self.logit_bias = logit_bias
def update_for_decode(self, input_ids=None):
def check_decode_mem(self):
bs = len(self.reqs)
avai_size = self.token_to_kv_pool.available_size()
if avai_size >= bs:
return True
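# Not enough free KV-cache slots: evict unreferenced prefixes from the radix cache
# (their token indices are released through token_to_kv_pool.free), then re-check.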
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
if self.token_to_kv_pool.available_size() >= bs:
return True
return False
def retract_decode(self):
sorted_indices = [i for i in range(len(self.reqs))]
sorted_indices.sort(
key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
reverse=True,
)
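# Sorted descending by (#decoded tokens, -input length); pop() therefore retracts the
# request with the fewest decoded tokens first (ties broken by the longest input),
# i.e. the requests that have made the least decoding progress are evicted first.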
retracted_reqs = []
seq_lens_np = self.seq_lens.cpu().numpy()
req_pool_indices_np = self.req_pool_indices.cpu().numpy()
while self.token_to_kv_pool.available_size() < len(self.reqs):
idx = sorted_indices.pop()
req = self.reqs[idx]
retracted_reqs.append(req)
self.tree_cache.dec_ref_counter(req.last_node)
req.prefix_indices = None
req.last_node = None
req.adjust_input_len = 0
req.output_ids = []
# TODO: apply more fine-grained retraction
token_indices = self.req_to_token_pool.req_to_token[
req_pool_indices_np[idx]
][: seq_lens_np[idx]]
self.token_to_kv_pool.free(token_indices)
self.filter_batch(sorted_indices)
return retracted_reqs
def prepare_for_decode(self, input_ids=None):
if input_ids is None:
input_ids = [
r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs
......@@ -212,13 +286,9 @@ class Batch:
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
if self.out_cache_loc is None:
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
if self.out_cache_loc is None:
print("Decode out of memory.")
self.tree_cache.pretty_print()
exit()
print("Decode out of memory. This should nerver happen.")
self.tree_cache.pretty_print()
exit()
self.out_cache_cont_start = None
self.out_cache_cont_end = None
......@@ -240,6 +310,9 @@ class Batch:
self.prefix_lens = None
self.position_ids_offsets = self.position_ids_offsets[new_indices]
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
self.return_normalized_logprob = any(
req.return_normalized_logprob for req in self.reqs
)
for item in [
"temperatures",
......@@ -263,6 +336,9 @@ class Batch:
[self.position_ids_offsets, other.position_ids_offsets]
)
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
self.return_normalized_logprob = any(
req.return_normalized_logprob for req in self.reqs
)
for item in [
"temperatures",
......
......@@ -45,7 +45,6 @@ class ModelRpcServer(rpyc.Service):
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_heuristic = server_args.schedule_heuristic
self.schedule_conservativeness = server_args.schedule_conservativeness
# Init model and tokenizer
self.model_config = ModelConfig(
......@@ -114,6 +113,11 @@ class ModelRpcServer(rpyc.Service):
# Init the FSM cache for constrained generation
self.regex_fsm_cache = FSMCache(self.tokenizer)
# Init new token estimation
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
def exposed_step(self, recv_reqs):
if self.tp_size != 1:
recv_reqs = obtain(recv_reqs)
......@@ -209,11 +213,6 @@ class ModelRpcServer(rpyc.Service):
req.stream = recv_req.stream
req.tokenizer = self.tokenizer
# init the regex fsm
if req.sampling_params.regex is not None:
req.regex_fsm_state = 0
req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
# Truncate long prompts
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
req.sampling_params.max_new_tokens = min(
......@@ -249,13 +248,10 @@ class ModelRpcServer(rpyc.Service):
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
)
new_ratio = (
self.scheduler.new_token_estimation_ratio() * self.schedule_conservativeness
)
if self.running_batch:
available_size -= sum(
[
(r.max_new_tokens() - len(r.output_ids)) * new_ratio
(r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
for r in self.running_batch.reqs
]
)
......@@ -311,7 +307,7 @@ class ModelRpcServer(rpyc.Service):
f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
)
new_batch = Batch(
new_batch = Batch.init_new(
can_run_list,
self.req_to_token_pool,
self.token_to_kv_pool,
......@@ -322,7 +318,16 @@ class ModelRpcServer(rpyc.Service):
def forward_fill_batch(self, batch: Batch):
# Build batch tensors
batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias)
batch.prepare_for_extend(
self.model_config.vocab_size, self.int_token_logit_bias
)
# init the regex fsm before first sampling
for req in batch.reqs:
if req.sampling_params.regex is not None:
req.regex_fsm_state = 0
req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
if batch.extend_num_tokens != 0:
# Forward
logits, normalized_logprobs = self.model_runner.forward(
......@@ -350,9 +355,27 @@ class ModelRpcServer(rpyc.Service):
self.handle_finished_requests(batch)
def forward_decode_batch(self, batch: Batch):
# check if decode out of memory
if not batch.check_decode_mem():
old_ratio = self.new_token_ratio
self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
retracted_reqs = batch.retract_decode()
logger.info(
"decode out of memory happened, "
f"#retracted_reqs: {len(retracted_reqs)}, "
f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
)
self.forward_queue.extend(retracted_reqs)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_step[0],
self.min_new_token_ratio,
)
# Update batch tensors
self.decode_forward_ct += 1
batch.update_for_decode()
batch.prepare_for_decode()
# Forward
logits = self.model_runner.forward(batch, ForwardMode.DECODE)
......
......@@ -17,9 +17,6 @@ class Scheduler:
self.max_total_num_token = max_total_num_token
self.tree_cache = tree_cache
def new_token_estimation_ratio(self):
return 0.5 if self.schedule_heuristic != "fcfs" else 0.6
def get_priority_queue(self, forward_queue):
if self.schedule_heuristic == "lpm":
# longest prefix match
......
......@@ -119,7 +119,7 @@ class ServerArgs:
"--schedule-conservativeness",
type=float,
default=ServerArgs.schedule_conservativeness,
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see out-of-memory errors.",
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
)
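# For reference, in this commit schedule_conservativeness scales the decode
# new-token estimates in ModelRpcServer:
#   new_token_ratio     = min(0.4 * schedule_conservativeness, 1.0)
#   min_new_token_ratio = min(0.2 * schedule_conservativeness, 1.0)
# so a larger value reserves more KV-cache headroom and makes retraction rarer.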
parser.add_argument(
"--random-seed",
......
......@@ -34,8 +34,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
reqs.append(req)
# Prefill
batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
batch.init_extend_batch(model.model_config.vocab_size(), None)
batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
batch.prepare_for_extend(model.model_config.vocab_size, None)
logits, _ = model.forward(batch, ForwardMode.EXTEND)
next_token_ids, next_token_probs = batch.sample(logits)
print("extend logits (first)", logits)
......@@ -47,8 +47,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
req.prefix_indices = model.req_to_token_pool.req_to_token[
batch.req_pool_indices[i], :cut_num
]
batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
batch.init_extend_batch(model.model_config.vocab_size(), None)
batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
batch.prepare_for_extend(model.model_config.vocab_size, None)
logits, _ = model.forward(batch, ForwardMode.EXTEND)
next_token_ids, next_token_probs = batch.sample(logits)
......@@ -59,7 +59,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
# Decode
for i in range(6):
batch.update_for_decode(next_token_ids.cpu().numpy())
batch.prepare_for_decode(next_token_ids.cpu().numpy())
logits = model.forward(batch, ForwardMode.DECODE)
next_token_ids, next_token_probs = batch.sample(logits)
......
import argparse
import random
import string
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from vllm.transformers_utils.tokenizer import get_tokenizer
import sglang as sgl
TOKENIZER = None
RANDOM_PREFILL_LEN = None
RANDOM_DECODE_LEN = None
def gen_prompt(token_num):
if RANDOM_PREFILL_LEN:
token_num = random.randint(1, token_num)
cha_set = string.ascii_letters + string.digits
ret = "".join(random.choices(cha_set, k=token_num))
while len(TOKENIZER(ret).input_ids) < token_num:
ret += random.choice(cha_set)
return ret
def robust_test_dfs(s, d, args, leaf_states):
if d == 0:
s += "END"
leaf_states.append(s)
return
s += gen_prompt(args.len_prefill)
forks = s.fork(args.num_fork)
for fork_s in forks:
fork_s += gen_prompt(args.len_prefill)
new_tokens = (
args.len_decode
if not RANDOM_DECODE_LEN
else random.randint(1, args.len_decode)
)
fork_s += sgl.gen(
max_tokens=new_tokens,
ignore_eos=True,
)
for fork_s in forks:
robust_test_dfs(fork_s, d - 1, args, leaf_states)
def robust_test_bfs(s, args, leaf_states):
old_forks = [s]
new_forks = []
for _ in range(args.depth):
for old_fork in old_forks:
old_fork += gen_prompt(args.len_prefill)
forks = old_fork.fork(args.num_fork)
for fork_s in forks:
fork_s += gen_prompt(args.len_prefill)
new_tokens = (
args.len_decode
if not RANDOM_DECODE_LEN
else random.randint(1, args.len_decode)
)
fork_s += sgl.gen(
max_tokens=new_tokens,
ignore_eos=True,
)
new_forks.extend(forks)
old_forks = new_forks
new_forks = []
for old_fork in old_forks:
old_fork += "END"
leaf_states.append(old_fork)
@sgl.function
def robust_test(s, args):
leaf_states = []
if args.mode == "bfs":
robust_test_bfs(s, args, leaf_states)
else:
robust_test_dfs(s, args.depth, args, leaf_states)
return leaf_states
def main(args):
backend = select_sglang_backend(args)
arguments = [{"args": args} for _ in range(args.num_req)]
states = robust_test.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel
)
with open(f"tmp_robust_{args.mode}.txt", "w") as f:
for state in states:
leaf_states = state.ret_value
for leaf_state in leaf_states:
assert leaf_state.text()[-3:] == "END"
f.write(leaf_state.text()[:-3] + "\n")
if __name__ == "__main__":
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--num-req", type=int, default=2)
parser.add_argument("--depth", type=int, default=3)
parser.add_argument("--num-fork", type=int, default=2)
parser.add_argument("--len-prefill", type=int, default=128)
parser.add_argument("--len-decode", type=int, default=128)
parser.add_argument("--random-prefill-len", action="store_true")
parser.add_argument("--random-decode-len", action="store_true")
parser.add_argument("--mode", type=str, default="bfs", choices=["dfs", "bfs"])
parser.add_argument("--tokenizer", type=str, default = "meta-llama/Llama-2-7b-chat-hf")
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--seed", type=int, default=42)
args = add_common_sglang_args_and_parse(parser)
# fmt: on
RANDOM_PREFILL_LEN = args.random_prefill_len
RANDOM_DECODE_LEN = args.random_decode_len
TOKENIZER = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
random.seed(args.seed)
main(args)
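
For reference, a hypothetical invocation of this stress test (the backend and parallelism flags are added by add_common_sglang_args_and_parse and are not shown in this diff):

python tests/python/common/test_random.py --mode dfs --num-req 4 --depth 3 --num-fork 2 --random-prefill-len --random-decode-len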