Unverified Commit 40ab1f01 authored by Liangsheng Yin, committed by GitHub

Fix the possible bug of decode out of memory (#36)

parent 199e82a1
@@ -40,7 +40,7 @@ def extract_prefix_by_tracing(program, backend):
     try:
         with TracingScope(tracer):
             tracer.ret_value = program.func(tracer, **arguments)
-    except (StopTracing, TypeError):
+    except (StopTracing, TypeError, AttributeError):
         # Some exceptions may not be caught
         pass
+from dataclasses import dataclass
 from enum import Enum, auto
 from typing import List
@@ -38,6 +39,7 @@ class Req:
         self.adjust_input_len = 0
         self.prefix_indices = []
+        self.last_node = None
         self.normalized_logprob = None
@@ -81,27 +83,56 @@ class Req:
         return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "


+@dataclass
 class Batch:
-    def __init__(
-        self,
-        reqs: List[Req],
-        req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool: TokenToKVPool,
-        tree_cache: RadixCache,
-    ):
-        self.reqs = reqs
-        self.req_to_token_pool = req_to_token_pool
-        self.token_to_kv_pool = token_to_kv_pool
-        self.tree_cache = tree_cache
-
-        self.return_normalized_logprob = any(
-            req.return_normalized_logprob for req in reqs
-        )
+    reqs: List[Req]
+    req_to_token_pool: ReqToTokenPool
+    token_to_kv_pool: TokenToKVPool
+    tree_cache: RadixCache
+
+    # batched arguments to model runner
+    input_ids: torch.Tensor = None
+    req_pool_indices: torch.Tensor = None
+    seq_lens: torch.Tensor = None
+    prefix_lens: torch.Tensor = None
+    position_ids_offsets: torch.Tensor = None
+    out_cache_loc: torch.Tensor = None
+    out_cache_cont_start: torch.Tensor = None
+    out_cache_cont_end: torch.Tensor = None
+    return_normalized_logprob: bool = False
+
+    # for multimodal
+    pixel_values: List[torch.Tensor] = None
+    image_offsets: List[int] = None
+
+    # other arguments for control
+    output_ids: torch.Tensor = None
+    extend_num_tokens: int = None
+
+    # batched sampling params
+    temperatures: torch.Tensor = None
+    top_ps: torch.Tensor = None
+    top_ks: torch.Tensor = None
+    frequency_penalties: torch.Tensor = None
+    presence_penalties: torch.Tensor = None
+    logit_bias: torch.Tensor = None
+
+    @classmethod
+    def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
+        return_normalized_logprob = any(req.return_normalized_logprob for req in reqs)
+        return cls(
+            reqs=reqs,
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool=token_to_kv_pool,
+            tree_cache=tree_cache,
+            return_normalized_logprob=return_normalized_logprob,
+        )

     def is_empty(self):
         return len(self.reqs) == 0
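
Not part of the commit: the new Batch follows a dataclass-plus-factory pattern, with the required pools as plain fields, every batched tensor defaulting to None until prepare_for_extend / prepare_for_decode fills it, and the aggregate flag derived in a classmethod rather than in __init__. A minimal, self-contained sketch of that pattern, using made-up MiniBatch/Item names:

# Illustrative sketch only; MiniBatch and Item are hypothetical stand-ins for
# Batch and Req. Required fields come first, lazily-filled fields default to
# None, and the factory derives the aggregate flag from its members.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Item:
    return_normalized_logprob: bool = False


@dataclass
class MiniBatch:
    items: List[Item]
    input_ids: Optional[list] = None  # filled later, like the tensors above
    return_normalized_logprob: bool = False

    @classmethod
    def init_new(cls, items):
        return cls(
            items=items,
            return_normalized_logprob=any(i.return_normalized_logprob for i in items),
        )


batch = MiniBatch.init_new([Item(), Item(return_normalized_logprob=True)])
assert batch.return_normalized_logprob and batch.input_ids is None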
-    def init_extend_batch(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
+    def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor):
         device = "cuda"
         bs = len(self.reqs)
         reqs = self.reqs

@@ -141,7 +172,7 @@ class Batch:
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            print("Prefill out of memory.")
+            print("Prefill out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()

@@ -196,7 +227,50 @@ class Batch:
         )
         self.logit_bias = logit_bias
-    def update_for_decode(self, input_ids=None):
+    def check_decode_mem(self):
+        bs = len(self.reqs)
+        avai_size = self.token_to_kv_pool.available_size()
+        if avai_size >= bs:
+            return True
+
+        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        if self.token_to_kv_pool.available_size() >= bs:
+            return True
+
+        return False
+
+    def retract_decode(self):
+        sorted_indices = [i for i in range(len(self.reqs))]
+        sorted_indices.sort(
+            key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
+            reverse=True,
+        )
+
+        retracted_reqs = []
+        seq_lens_np = self.seq_lens.cpu().numpy()
+        req_pool_indices_np = self.req_pool_indices.cpu().numpy()
+        while self.token_to_kv_pool.available_size() < len(self.reqs):
+            idx = sorted_indices.pop()
+            req = self.reqs[idx]
+            retracted_reqs.append(req)
+
+            self.tree_cache.dec_ref_counter(req.last_node)
+            req.prefix_indices = None
+            req.last_node = None
+            req.adjust_input_len = 0
+            req.output_ids = []
+
+            # TODO: apply more fine-grained retraction
+            token_indices = self.req_to_token_pool.req_to_token[
+                req_pool_indices_np[idx]
+            ][: seq_lens_np[idx]]
+            self.token_to_kv_pool.free(token_indices)
+
+        self.filter_batch(sorted_indices)
+        return retracted_reqs
+
+    def prepare_for_decode(self, input_ids=None):
         if input_ids is None:
             input_ids = [
                 r.output_ids[-1] if r.output_ids else r.input_ids[-1] for r in self.reqs

@@ -212,11 +286,7 @@ class Batch:
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
         if self.out_cache_loc is None:
-            self.tree_cache.evict(bs, self.token_to_kv_pool.free)
-            self.out_cache_loc = self.token_to_kv_pool.alloc(bs)
-            if self.out_cache_loc is None:
-                print("Decode out of memory.")
+            print("Decode out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()
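
Not part of the commit: a small, runnable sketch of the retraction order implied by the sort key above. Because the list is sorted in descending order and victims are taken with pop(), the request retracted first is the one with the fewest decoded tokens, ties broken toward the longer prompt. The request sizes below are invented.

# Hypothetical request sizes; only the sort key and the pop() order mirror retract_decode().
reqs = {
    "A": {"output_len": 5, "input_len": 10},
    "B": {"output_len": 1, "input_len": 100},
    "C": {"output_len": 1, "input_len": 20},
}

order = sorted(
    reqs,
    key=lambda name: (reqs[name]["output_len"], -reqs[name]["input_len"]),
    reverse=True,
)
print(order)        # ['A', 'C', 'B']
print(order.pop())  # 'B': fewest decoded tokens, longest prompt -> retracted first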
@@ -240,6 +310,9 @@ class Batch:
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.return_normalized_logprob = any(
+            req.return_normalized_logprob for req in self.reqs
+        )

         for item in [
             "temperatures",

@@ -263,6 +336,9 @@ class Batch:
             [self.position_ids_offsets, other.position_ids_offsets]
         )
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.return_normalized_logprob = any(
+            req.return_normalized_logprob for req in self.reqs
+        )

         for item in [
             "temperatures",
@@ -45,7 +45,6 @@ class ModelRpcServer(rpyc.Service):
         self.tp_rank = tp_rank
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
-        self.schedule_conservativeness = server_args.schedule_conservativeness

         # Init model and tokenizer
         self.model_config = ModelConfig(

@@ -114,6 +113,11 @@ class ModelRpcServer(rpyc.Service):
         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(self.tokenizer)

+        # Init new token estimation
+        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
+        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
+        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
+
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
             recv_reqs = obtain(recv_reqs)
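
Not part of the commit: the initial estimates above scale linearly with --schedule-conservativeness and are capped at 1.0. A tiny sketch of the arithmetic, with arbitrary sample values:

# Formula copied from the diff; the sampled conservativeness values are arbitrary.
for c in (0.5, 1.0, 3.0):
    new_token_ratio = min(0.4 * c, 1.0)      # 0.2, 0.4, then capped at 1.0
    min_new_token_ratio = min(0.2 * c, 1.0)  # 0.1, 0.2, 0.6
    print(c, new_token_ratio, min_new_token_ratio)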
@@ -209,11 +213,6 @@ class ModelRpcServer(rpyc.Service):
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer

-        # init the regex fsm
-        if req.sampling_params.regex is not None:
-            req.regex_fsm_state = 0
-            req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)
-
         # Truncate long prompts
         req.input_ids = req.input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
@@ -249,13 +248,10 @@ class ModelRpcServer(rpyc.Service):
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
-        new_ratio = (
-            self.scheduler.new_token_estimation_ratio() * self.schedule_conservativeness
-        )
         if self.running_batch:
             available_size -= sum(
                 [
-                    (r.max_new_tokens() - len(r.output_ids)) * new_ratio
+                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
                     for r in self.running_batch.reqs
                 ]
             )
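
Not part of the commit: a self-contained illustration of the admission estimate above. Free KV-cache slots plus evictable prefix slots are discounted by the tokens that running requests are still expected to produce, weighted by new_token_ratio. All sizes below are invented; only the formula comes from the diff.

# Only the formula is taken from the diff; pool and request sizes are hypothetical.
new_token_ratio = 0.4
free_kv_slots = 2000
evictable_prefix_slots = 500

running = [  # (max_new_tokens, tokens_already_decoded) per in-flight request
    (256, 50),
    (128, 100),
]

available = free_kv_slots + evictable_prefix_slots
available -= sum((max_new - done) * new_token_ratio for max_new, done in running)
print(available)  # 2500 - (206 + 28) * 0.4 = 2406.4 slots assumed usable for new prefills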
@@ -311,7 +307,7 @@ class ModelRpcServer(rpyc.Service):
             f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
         )

-        new_batch = Batch(
+        new_batch = Batch.init_new(
             can_run_list,
             self.req_to_token_pool,
             self.token_to_kv_pool,

@@ -322,7 +318,16 @@ class ModelRpcServer(rpyc.Service):

     def forward_fill_batch(self, batch: Batch):
         # Build batch tensors
-        batch.init_extend_batch(self.model_config.vocab_size, self.int_token_logit_bias)
+        batch.prepare_for_extend(
+            self.model_config.vocab_size, self.int_token_logit_bias
+        )
+
+        # init the regex fsm before first sampling
+        for req in batch.reqs:
+            if req.sampling_params.regex is not None:
+                req.regex_fsm_state = 0
+                req.regex_fsm = self.regex_fsm_cache.get_fsm(req.sampling_params.regex)

         if batch.extend_num_tokens != 0:
             # Forward
             logits, normalized_logprobs = self.model_runner.forward(

@@ -350,9 +355,27 @@ class ModelRpcServer(rpyc.Service):
         self.handle_finished_requests(batch)
     def forward_decode_batch(self, batch: Batch):
+        # check if decode out of memory
+        if not batch.check_decode_mem():
+            old_ratio = self.new_token_ratio
+            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
+
+            retracted_reqs = batch.retract_decode()
+            logger.info(
+                "decode out of memory happened, "
+                f"#retracted_reqs: {len(retracted_reqs)}, "
+                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+            )
+
+            self.forward_queue.extend(retracted_reqs)
+        else:
+            self.new_token_ratio = max(
+                self.new_token_ratio - self.new_token_ratio_step[0],
+                self.min_new_token_ratio,
+            )
+
         # Update batch tensors
         self.decode_forward_ct += 1
-        batch.update_for_decode()
+        batch.prepare_for_decode()

         # Forward
         logits = self.model_runner.forward(batch, ForwardMode.DECODE)
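
Not part of the commit: a sketch of how the adaptive ratio behaves under the asymmetric step sizes above. A retraction bumps the estimate up by 0.05 in one step, while every healthy decode step decays it by 0.0001 toward the floor. The single simulated OOM event below is invented.

# Step sizes, cap, and floor are from the diff; the OOM schedule is hypothetical.
new_token_ratio, min_new_token_ratio = 0.4, 0.2
step_down, step_up = 0.0001, 0.05

for step in range(3000):
    decode_oom = step == 100  # pretend one retraction happens at step 100
    if decode_oom:
        new_token_ratio = min(new_token_ratio + step_up, 1.0)
    else:
        new_token_ratio = max(new_token_ratio - step_down, min_new_token_ratio)

print(round(new_token_ratio, 4))  # back at the 0.2 floor after the bump decays away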
@@ -17,9 +17,6 @@ class Scheduler:
         self.max_total_num_token = max_total_num_token
         self.tree_cache = tree_cache

-    def new_token_estimation_ratio(self):
-        return 0.5 if self.schedule_heuristic != "fcfs" else 0.6
-
     def get_priority_queue(self, forward_queue):
         if self.schedule_heuristic == "lpm":
             # longest prefix match
@@ -119,7 +119,7 @@ class ServerArgs:
             "--schedule-conservativeness",
             type=float,
             default=ServerArgs.schedule_conservativeness,
-            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see out-of-memory errors.",
+            help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
             "--random-seed",
@@ -34,8 +34,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
         reqs.append(req)

     # Prefill
-    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
-    batch.init_extend_batch(model.model_config.vocab_size(), None)
+    batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
+    batch.prepare_for_extend(model.model_config.vocab_size, None)
     logits, _ = model.forward(batch, ForwardMode.EXTEND)
     next_token_ids, next_token_probs = batch.sample(logits)
     print("extend logits (first)", logits)

@@ -47,8 +47,8 @@ def test_generate_worker(model_path, tp_rank, tp_size):
         req.prefix_indices = model.req_to_token_pool.req_to_token[
             batch.req_pool_indices[i], :cut_num
         ]
-    batch = Batch(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
-    batch.init_extend_batch(model.model_config.vocab_size(), None)
+    batch = Batch.init_new(reqs, model.req_to_token_pool, model.token_to_kv_pool, None)
+    batch.prepare_for_extend(model.model_config.vocab_size, None)
     logits, _ = model.forward(batch, ForwardMode.EXTEND)
     next_token_ids, next_token_probs = batch.sample(logits)

@@ -59,7 +59,7 @@ def test_generate_worker(model_path, tp_rank, tp_size):
     # Decode
     for i in range(6):
-        batch.update_for_decode(next_token_ids.cpu().numpy())
+        batch.prepare_for_decode(next_token_ids.cpu().numpy())
         logits = model.forward(batch, ForwardMode.DECODE)
         next_token_ids, next_token_probs = batch.sample(logits)
import argparse
import random
import string

from sglang.test.test_utils import (
    add_common_sglang_args_and_parse,
    select_sglang_backend,
)
from vllm.transformers_utils.tokenizer import get_tokenizer

import sglang as sgl

TOKENIZER = None
RANDOM_PREFILL_LEN = None
RANDOM_DECODE_LEN = None


def gen_prompt(token_num):
    if RANDOM_PREFILL_LEN:
        token_num = random.randint(1, token_num)

    cha_set = string.ascii_letters + string.digits
    ret = "".join(random.choices(cha_set, k=token_num))
    while len(TOKENIZER(ret).input_ids) < token_num:
        ret += random.choice(cha_set)

    return ret
def robust_test_dfs(s, d, args, leaf_states):
    if d == 0:
        s += "END"
        leaf_states.append(s)
        return

    s += gen_prompt(args.len_prefill)
    forks = s.fork(args.num_fork)
    for fork_s in forks:
        fork_s += gen_prompt(args.len_prefill)
        new_tokens = (
            args.len_decode
            if not RANDOM_DECODE_LEN
            else random.randint(1, args.len_decode)
        )
        fork_s += sgl.gen(
            max_tokens=new_tokens,
            ignore_eos=True,
        )

    for fork_s in forks:
        robust_test_dfs(fork_s, d - 1, args, leaf_states)
def robust_test_bfs(s, args, leaf_states):
    old_forks = [s]
    new_forks = []
    for _ in range(args.depth):
        for old_fork in old_forks:
            old_fork += gen_prompt(args.len_prefill)
            forks = old_fork.fork(args.num_fork)
            for fork_s in forks:
                fork_s += gen_prompt(args.len_prefill)
                new_tokens = (
                    args.len_decode
                    if not RANDOM_DECODE_LEN
                    else random.randint(1, args.len_decode)
                )
                fork_s += sgl.gen(
                    max_tokens=new_tokens,
                    ignore_eos=True,
                )
            new_forks.extend(forks)

        old_forks = new_forks
        new_forks = []

    for old_fork in old_forks:
        old_fork += "END"
        leaf_states.append(old_fork)
@sgl.function
def robust_test(s, args):
    leaf_states = []
    if args.mode == "bfs":
        robust_test_bfs(s, args, leaf_states)
    else:
        robust_test_dfs(s, args.depth, args, leaf_states)
    return leaf_states


def main(args):
    backend = select_sglang_backend(args)

    arguments = [{"args": args} for _ in range(args.num_req)]

    states = robust_test.run_batch(
        arguments, temperature=0, backend=backend, num_threads=args.parallel
    )

    with open(f"tmp_robust_{args.mode}.txt", "w") as f:
        for state in states:
            leaf_states = state.ret_value
            for leaf_state in leaf_states:
                assert leaf_state.text()[-3:] == "END"
                f.write(leaf_state.text()[:-3] + "\n")
if __name__ == "__main__":
    # fmt: off
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-req", type=int, default=2)
    parser.add_argument("--depth", type=int, default=3)
    parser.add_argument("--num-fork", type=int, default=2)
    parser.add_argument("--len-prefill", type=int, default=128)
    parser.add_argument("--len-decode", type=int, default=128)
    parser.add_argument("--random-prefill-len", action="store_true")
    parser.add_argument("--random-decode-len", action="store_true")
    parser.add_argument("--mode", type=str, default="bfs", choices=["dfs", "bfs"])
    parser.add_argument("--tokenizer", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--trust-remote-code", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    args = add_common_sglang_args_and_parse(parser)
    # fmt: on

    RANDOM_PREFILL_LEN = args.random_prefill_len
    RANDOM_DECODE_LEN = args.random_decode_len
    TOKENIZER = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)

    random.seed(args.seed)
    main(args)
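
Not part of the commit: the benchmark above stresses the scheduler through sglang's fork/gen frontend. A minimal sketch of that pattern outside the test harness; running it still needs a configured backend, which is omitted here.

import sglang as sgl


# Sketch only: the same fork/gen calls as the benchmark above, without the
# recursion, random prompts, or backend/argument plumbing.
@sgl.function
def tiny_tree(s):
    s += "Write two short color names.\n"
    forks = s.fork(2)
    for fork_s in forks:
        fork_s += sgl.gen(max_tokens=8, ignore_eos=True)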