Unverified Commit f06e90c2 authored by Liangsheng Yin, committed by GitHub

Optimize retract (#440)

parent 2cea6146
# NOTE: Currently this can only be run through HTTP requests.
import json
from concurrent.futures import ThreadPoolExecutor

from json_decode import character_regex

from sglang.utils import http_request

character_names = ["Hermione Granger", "Ron Weasley", "Harry Potter"]

base_url = "http://localhost:30000"

prompt = "is a character in Harry Potter. Please fill in the following information about this character.\n"


def openai_api_request(name):
    data = {
        "model": "",
        "prompt": name + prompt,
        "temperature": 0,
        "max_tokens": 128,
        "regex": character_regex,
        "logprobs": 3,
    }
    res = http_request(base_url + "/v1/completions", json=data).json()
    # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
    #     fout.write(json.dumps(res, indent=4))

    logprobs = res["choices"][0]["logprobs"]
    usage = res["usage"]
    assert len(logprobs["token_logprobs"]) == len(logprobs["tokens"])
    assert len(logprobs["token_logprobs"]) == len(logprobs["top_logprobs"])
    assert len(logprobs["token_logprobs"]) == usage["completion_tokens"] - 1

    return res


def srt_api_request(name):
    data = {
        "text": name + prompt,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 128,
            "regex": character_regex,
        },
        "return_logprob": True,
        "logprob_start_len": 0,
        "top_logprobs_num": 3,
        "return_text_in_logprobs": True,
    }
    res = http_request(base_url + "/generate", json=data).json()
    # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
    #     fout.write(json.dumps(res, indent=4))

    meta_info = res["meta_info"]
    assert len(meta_info["prefill_token_logprobs"]) == len(
        meta_info["prefill_top_logprobs"]
    )
    assert len(meta_info["decode_token_logprobs"]) == len(
        meta_info["decode_top_logprobs"]
    )
    assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"]
    assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1

    return res


def pretty_print(res):
    meta_info = res["meta_info"]

    print("\n\n", "=" * 30, "Prefill", "=" * 30)
    for i in range(len(meta_info["prefill_token_logprobs"])):
        print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="")
        top_ks = (
            [str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]]
            if meta_info["prefill_top_logprobs"][i]
            else []
        )
        for top_k in top_ks:
            print(f"{top_k: <15}", end="")
        print()

    print("\n\n", "=" * 30, "Decode", "=" * 30)
    for i in range(len(meta_info["decode_token_logprobs"])):
        print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="")
        top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]]
        for top_k in top_ks:
            print(f"{top_k: <15}", end="")
        print()

    print(res["text"])


if __name__ == "__main__":
    with ThreadPoolExecutor() as executor:
        ress = executor.map(srt_api_request, character_names)
        for res in ress:
            pretty_print(res)

    openai_api_request("Hermione Granger")
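For reference, the assertions above expect each logprob entry to be a (logprob, token_id, token_text) triple when return_text_in_logprobs is set; the mocked response below is a hypothetical illustration of that shape, not real server output.

# Hypothetical illustration of the logprob shapes the assertions above rely on.
# All values are made up; a real response comes from the /generate endpoint.
mock_meta_info = {
    "prompt_tokens": 2,
    "completion_tokens": 3,
    "prefill_token_logprobs": [(None, 101, "Hermione"), (-0.7, 102, " Granger")],
    "prefill_top_logprobs": [
        None,
        [(-0.7, 102, " Granger"), (-1.9, 103, " Weasley"), (-2.3, 104, " Potter")],
    ],
    "decode_token_logprobs": [(-0.1, 105, " is"), (-0.4, 106, " a")],
    "decode_top_logprobs": [[(-0.1, 105, " is")], [(-0.4, 106, " a")]],
}
assert len(mock_meta_info["prefill_token_logprobs"]) == mock_meta_info["prompt_tokens"]
assert (
    len(mock_meta_info["decode_token_logprobs"])
    == mock_meta_info["completion_tokens"] - 1
)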
@@ -28,5 +28,11 @@ class GlobalConfig:
# Request dependency time due to network delay
self.request_dependency_time = 0.03
# New generation token ratio estimation
self.base_new_token_ratio = 0.4
self.base_min_new_token_ratio = 0.2
self.new_token_ratio_decay = 0.0001
self.new_token_ratio_recovery = 0.05
global_config = GlobalConfig()
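These four constants feed the retract/decode scheduling changes further down in this diff; a minimal sketch of how they combine with --schedule-conservativeness, with an illustrative step count and retraction point:

# Sketch of the new-token-ratio schedule wired up in ModelRpcServer below.
# Numbers are illustrative, not measurements.
base_new_token_ratio = 0.4
base_min_new_token_ratio = 0.2
new_token_ratio_decay = 0.0001
new_token_ratio_recovery = 0.05
schedule_conservativeness = 1.0  # server argument, assumed default

new_token_ratio = min(base_new_token_ratio * schedule_conservativeness, 1.0)
min_new_token_ratio = min(base_min_new_token_ratio * schedule_conservativeness, 1.0)

for step in range(5):
    if step == 2:  # pretend decode ran out of KV cache here and retracted
        new_token_ratio = min(new_token_ratio + new_token_ratio_recovery, 1.0)
    else:  # otherwise slowly relax back toward the floor
        new_token_ratio = max(
            new_token_ratio - new_token_ratio_decay, min_new_token_ratio
        )
    print(step, round(new_token_ratio, 4))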
@@ -50,21 +50,22 @@ class LogitsProcessor(nn.Module):
prefill_top_logprobs, decode_top_logprobs = [], []
pt = 0
# NOTE: the GPU-CPU overhead can be reduced
extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
for i in range(len(extend_seq_lens_cpu)):
if extend_seq_lens_cpu[i] == 0:
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0:
prefill_top_logprobs.append([])
decode_top_logprobs.append([])
continue
k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
prefill_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
)
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_lens_cpu[i]
pt += extend_seq_len
return prefill_top_logprobs, decode_top_logprobs
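The loop above splits one flat logprob tensor by each request's extend length and takes a per-position top-k; below is a self-contained sketch of the same splitting, with made-up tensors in place of the real InputMetadata fields.

# Stand-alone sketch of the per-request top-k split done above (dummy data).
import torch

all_logprobs = torch.log_softmax(torch.randn(7, 32), dim=-1)  # 7 positions, vocab 32
extend_seq_lens = [3, 0, 4]   # the middle request contributed no extend tokens
top_logprobs_nums = [2, 2, 3]

prefill_top_logprobs, decode_top_logprobs = [], []
pt = 0
for i, extend_seq_len in enumerate(extend_seq_lens):
    if extend_seq_len == 0:
        prefill_top_logprobs.append([])
        decode_top_logprobs.append([])
        continue
    k = top_logprobs_nums[i]
    t = all_logprobs[pt : pt + extend_seq_len].topk(k)
    vs, ps = t.values.tolist(), t.indices.tolist()
    # every position but the last belongs to the prefill part;
    # the last position predicts the first decode token
    prefill_top_logprobs.append([list(zip(vs[j], ps[j])) for j in range(len(vs) - 1)])
    decode_top_logprobs.append(list(zip(vs[-1], ps[-1])))
    pt += extend_seq_len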
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
@@ -145,7 +146,7 @@ class LogitsProcessor(nn.Module):
)
if __name__ == "__main__":
def test():
all_logprobs = torch.tensor(
# s s s
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
@@ -173,3 +174,7 @@ if __name__ == "__main__":
print("start", start)
print("end", end)
print("sum_logp", sum_logp)
if __name__ == "__main__":
test()
@@ -51,11 +51,6 @@ class DetokenizerManager:
# Trim stop str
# TODO(lmzheng): handle the case where multiple stop strs are hit
for i in range(len(output_strs)):
if recv_obj.hit_stop_str[i] is not None:
pos = output_strs[i].find(recv_obj.hit_stop_str[i])
if pos != -1:
output_strs[i] = output_strs[i][:pos]
if len(output_tokens[i]) > 0:
first_token = self.tokenizer.convert_ids_to_tokens(
int(output_tokens[i][0])
@@ -65,9 +60,12 @@ class DetokenizerManager:
if first_token.startswith("▁"):
output_strs[i] = " " + output_strs[i]
output_strs[i] = (
recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
)
output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
if recv_obj.hit_stop_str[i] is not None:
pos = output_strs[i].find(recv_obj.hit_stop_str[i])
if pos != -1:
output_strs[i] = output_strs[i][:pos]
self.send_to_tokenizer.send_pyobj(
BatchStrOut(
@@ -106,8 +106,8 @@ class TokenizedGenerateReqInput:
@dataclass
class BatchTokenIDOut:
rids: List[str]
prev_output_strs: List[str]
output_tokens: List[List[int]]
output_and_jump_forward_strs: List[str]
hit_stop_str: List[Optional[str]]
skip_special_tokens: List[bool]
spaces_between_special_tokens: List[bool]
@@ -36,15 +36,15 @@ class FinishReason(IntEnum):
class Req:
def __init__(self, rid, input_text, input_ids):
def __init__(self, rid, origin_input_text, origin_input_ids):
self.rid = rid
self.input_text = input_text
self.input_ids = input_ids
self.origin_input_text = origin_input_text
self.origin_input_ids = origin_input_ids
self.origin_input_ids_unpadded = origin_input_ids # before image padding
self.prev_output_str = ""
self.prev_output_ids = []
self.output_ids = []
# Since jump forward may retokenize the prompt with partial outputs,
# we maintain the original prompt length to report the correct usage.
self.prompt_tokens = len(input_ids)
self.input_ids = None # input_ids = origin_input_ids + prev_output_ids
# The number of decoded tokens for token usage report. Note that
# this does not include the jump forward tokens.
@@ -76,15 +76,24 @@ class Req:
self.top_logprobs_num = 0
self.normalized_prompt_logprob = None
self.prefill_token_logprobs = None
self.decode_token_logprobs = None
self.decode_token_logprobs = []
self.prefill_top_logprobs = None
self.decode_top_logprobs = None
self.decode_top_logprobs = []
# The tokens is prefilled but need to be considered as decode tokens
# and should be updated for the decode logprobs
self.last_update_decode_tokens = 0
# Constrained decoding
self.regex_fsm = None
self.regex_fsm_state = 0
self.jump_forward_map = None
self.output_and_jump_forward_str = ""
def partial_decode(self, ids):
first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
first_token = (
first_token.decode() if isinstance(first_token, bytes) else first_token
)
return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
def max_new_tokens(self):
return self.sampling_params.max_new_tokens
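partial_decode above compensates for SentencePiece-style tokenizers, which mark a word-initial space with "▁" and drop it when a token span is decoded in isolation; the toy tokenizer below is a made-up stand-in that only illustrates the leading-space fix.

# Toy stand-in tokenizer illustrating the leading-space handling in partial_decode.
class ToyTokenizer:
    _tokens = {7: "▁Hello", 8: "▁world"}

    def convert_ids_to_tokens(self, idx):
        return self._tokens[idx]

    def decode(self, ids):
        # mimics SentencePiece dropping the leading space of the decoded span
        return "".join(self._tokens[i] for i in ids).replace("▁", " ").lstrip()


tok = ToyTokenizer()
ids = [7, 8]
first_token = tok.convert_ids_to_tokens(ids[0])
decoded = (" " if first_token.startswith("▁") else "") + tok.decode(ids)
assert decoded == " Hello world"  # leading space restored for safe concatenation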
@@ -93,7 +102,10 @@ class Req:
if self.finished:
return
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
if (
len(self.prev_output_ids) + len(self.output_ids)
>= self.sampling_params.max_new_tokens
):
self.finished = True
self.finish_reason = FinishReason.LENGTH
return
@@ -112,60 +124,66 @@ class Req:
)
for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str:
# FIXME: (minor) try incremental match in prev_output_str
if stop_str in tail_str or stop_str in self.prev_output_str:
self.finished = True
self.finish_reason = FinishReason.STOP_STR
self.hit_stop_str = stop_str
return
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
old_output_str = self.tokenizer.decode(self.output_ids)
# FIXME: This logic does not really solve the problem of determining whether
# there should be a leading space.
first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
first_token = (
first_token.decode() if isinstance(first_token, bytes) else first_token
)
if first_token.startswith("▁"):
old_output_str = " " + old_output_str
if self.input_text is None:
# TODO(lmzheng): This can be wrong. Check with Liangsheng.
self.input_text = self.tokenizer.decode(self.input_ids)
new_input_string = (
self.input_text
+ self.output_and_jump_forward_str
+ old_output_str
cur_output_str = self.partial_decode(self.output_ids)
# TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
if self.origin_input_text is None:
# Recovering text can only use unpadded ids
self.origin_input_text = self.tokenizer.decode(
self.origin_input_ids_unpadded
)
all_text = (
self.origin_input_text
+ self.prev_output_str
+ cur_output_str
+ jump_forward_str
)
new_input_ids = self.tokenizer.encode(new_input_string)
if self.pixel_values is not None:
# NOTE: This is a hack because the old input_ids contains the image padding
jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
else:
jump_forward_tokens_len = (
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
)
all_ids = self.tokenizer.encode(all_text)
prompt_tokens = len(self.origin_input_ids_unpadded)
self.origin_input_ids = all_ids[:prompt_tokens]
self.origin_input_ids_unpadded = self.origin_input_ids
# NOTE: the output ids may not strictly correspond to the output text
old_prev_output_ids = self.prev_output_ids
self.prev_output_ids = all_ids[prompt_tokens:]
self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
self.output_ids = []
self.regex_fsm_state = next_state
if self.return_logprob:
# For fast-forward part's logprobs
k = 0
for i, old_id in enumerate(old_prev_output_ids):
if old_id == self.prev_output_ids[i]:
k = k + 1
else:
break
self.decode_token_logprobs = self.decode_token_logprobs[:k]
self.decode_top_logprobs = self.decode_top_logprobs[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.prev_output_ids) - k
# print("=" * 100)
# print(f"Catch jump forward:\n{jump_forward_str}")
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
# print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
self.input_ids = new_input_ids
self.output_ids = []
self.sampling_params.max_new_tokens = max(
self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
)
self.regex_fsm_state = next_state
self.output_and_jump_forward_str = (
self.output_and_jump_forward_str + old_output_str + jump_forward_str
)
# print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
# print("*" * 100)
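After a jump-forward re-tokenization, previously reported decode logprobs stay valid only up to the longest common prefix of the old and new output ids; the snippet below is a minimal sketch of that bookkeeping with invented ids.

# Sketch of the common-prefix bookkeeping for incremental logprobs (invented ids).
old_prev_output_ids = [11, 12, 13, 14]
new_prev_output_ids = [11, 12, 99, 98, 97]  # re-tokenization changed the tail

k = 0
for old_id, new_id in zip(old_prev_output_ids, new_prev_output_ids):
    if old_id != new_id:
        break
    k += 1

# logprobs for the first k decode tokens are still valid ...
decode_token_logprobs = [(-0.1, 11), (-0.2, 12), (-0.3, 13), (-0.4, 14)][:k]
# ... and the re-tokenized tail must be re-scored at the next prefill
last_update_decode_tokens = len(new_prev_output_ids) - k
assert (k, last_update_decode_tokens) == (2, 3)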
def __repr__(self):
return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
@dataclass
@@ -336,6 +354,7 @@ class Batch:
def retract_decode(self):
sorted_indices = [i for i in range(len(self.reqs))]
# TODO(lsyin): improve the priority of retraction
sorted_indices.sort(
key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
reverse=True,
@@ -356,18 +375,27 @@
][last_uncached_pos : seq_lens_cpu[idx]]
self.token_to_kv_pool.dec_refs(token_indices)
# release the last node
self.tree_cache.dec_lock_ref(req.last_node)
cur_output_str = req.partial_decode(req.output_ids)
req.prev_output_str = req.prev_output_str + cur_output_str
req.prev_output_ids.extend(req.output_ids)
req.prefix_indices = None
req.last_node = None
req.extend_input_len = 0
req.output_ids = []
req.regex_fsm_state = 0
# For incremental logprobs
req.last_update_decode_tokens = 0
req.logprob_start_len = 10**9
self.filter_batch(sorted_indices)
return retracted_reqs
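Retraction picks victims by sorting requests so that those with the most generated tokens are evicted first, with shorter prompts breaking ties; a small sketch of that ordering with made-up (output_len, input_len) pairs:

# Sketch of the retraction ordering used above (made-up request sizes).
reqs = [(5, 100), (20, 30), (20, 300), (1, 10)]  # (output_len, input_len)

sorted_indices = list(range(len(reqs)))
sorted_indices.sort(
    key=lambda i: (reqs[i][0], -reqs[i][1]),
    reverse=True,
)
# request 1 (20 output tokens, short prompt) is retracted before request 2
# (20 output tokens, long prompt), then 0, then 3
assert sorted_indices == [1, 2, 0, 3]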
def check_for_jump_forward(self):
def check_for_jump_forward(self, model_runner):
jump_forward_reqs = []
filter_indices = [i for i in range(len(self.reqs))]
@@ -397,6 +425,18 @@
# jump-forward
req.jump_forward_and_retokenize(jump_forward_str, next_state)
# re-applying image padding
if req.pixel_values is not None:
(
req.origin_input_ids,
req.image_offset,
) = model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)
jump_forward_reqs.append(req)
filter_indices.remove(i)
@@ -4,7 +4,7 @@ import multiprocessing
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import List, Optional
import rpyc
import torch
@@ -16,6 +16,7 @@ try:
except ImportError:
from vllm.logger import logger as vllm_default_logger
from sglang.global_config import global_config
from sglang.srt.constrained.fsm_cache import FSMCache
from sglang.srt.constrained.jump_forward import JumpForwardCache
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -106,7 +107,8 @@ class ModelRpcServer:
set_random_seed(server_args.random_seed)
# Print info
logger.info(f"[rank={self.tp_rank}] "
logger.info(
f"[rank={self.tp_rank}] "
f"max_total_num_token={self.max_total_num_token}, "
f"max_prefill_num_token={self.max_prefill_num_token}, "
f"context_len={self.model_config.context_len}, "
@@ -151,9 +153,20 @@ class ModelRpcServer:
self.jump_forward_cache = JumpForwardCache()
# Init new token estimation
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
assert (
server_args.schedule_conservativeness >= 0
), "Invalid schedule_conservativeness"
self.new_token_ratio = min(
global_config.base_new_token_ratio * server_args.schedule_conservativeness,
1.0,
)
self.min_new_token_ratio = min(
global_config.base_min_new_token_ratio
* server_args.schedule_conservativeness,
1.0,
)
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
def exposed_step(self, recv_reqs):
if self.tp_size != 1:
@@ -256,8 +269,13 @@ class ModelRpcServer:
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
]
req.image_size = recv_req.image_size
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
req.origin_input_ids, req.image_offset = (
self.model_runner.model.pad_input_ids(
req.origin_input_ids_unpadded,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)
)
req.sampling_params = recv_req.sampling_params
req.return_logprob = recv_req.return_logprob
@@ -275,11 +293,11 @@ class ModelRpcServer:
)
# Truncate prompts that are too long
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
req.sampling_params.max_new_tokens = min(
req.sampling_params.max_new_tokens,
self.model_config.context_len - 1 - len(req.input_ids),
self.max_total_num_token - 128 - len(req.input_ids),
self.model_config.context_len - 1 - len(req.origin_input_ids),
self.max_total_num_token - 128 - len(req.origin_input_ids),
)
self.forward_queue.append(req)
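The cap above keeps a request from overrunning either the model context window or the KV-cache pool; a tiny numeric illustration with made-up sizes:

# Illustrative numbers only: cap max_new_tokens as done above.
requested_max_new_tokens = 4096
context_len = 2048
max_total_num_token = 100_000
prompt_len = 500

max_new_tokens = min(
    requested_max_new_tokens,
    context_len - 1 - prompt_len,            # fit the model context window
    max_total_num_token - 128 - prompt_len,  # leave headroom in the KV pool
)
assert max_new_tokens == 1547  # limited by the context window here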
@@ -292,6 +310,10 @@ class ModelRpcServer:
# Compute matched prefix length
for req in self.forward_queue:
assert (
len(req.output_ids) == 0
), "The output ids should be empty when prefilling"
req.input_ids = req.origin_input_ids + req.prev_output_ids
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
if req.return_logprob:
prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -319,7 +341,7 @@ class ModelRpcServer:
)
for req in self.forward_queue:
if req.return_logprob:
if req.return_logprob and req.normalized_prompt_logprob is None:
# Need at least two tokens to compute normalized logprob
if req.extend_input_len < 2:
delta = 2 - req.extend_input_len
@@ -441,28 +463,53 @@ class ModelRpcServer:
req.check_finished()
if req.return_logprob:
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
req.prefill_token_logprobs = list(
zip(
prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
req.input_ids[-req.extend_input_len + 1 :],
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
if req.prefill_token_logprobs is None:
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
req.prefill_token_logprobs = list(
zip(
prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
req.input_ids[-req.extend_input_len + 1 :],
)
)
)
if req.logprob_start_len == 0:
req.prefill_token_logprobs = [
(None, req.input_ids[0])
] + req.prefill_token_logprobs
req.decode_token_logprobs = [
if req.logprob_start_len == 0:
req.prefill_token_logprobs = [
(None, req.input_ids[0])
] + req.prefill_token_logprobs
if req.last_update_decode_tokens != 0:
req.decode_token_logprobs.extend(
list(
zip(
prefill_token_logprobs[
pt
+ req.extend_input_len
- req.last_update_decode_tokens : pt
+ req.extend_input_len
- 1
],
req.input_ids[-req.last_update_decode_tokens + 1 :],
)
)
)
req.decode_token_logprobs.append(
(last_token_logprobs[i], next_token_ids[i])
]
)
if req.top_logprobs_num > 0:
req.prefill_top_logprobs = prefill_top_logprobs[i]
if req.logprob_start_len == 0:
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
req.decode_top_logprobs = [decode_top_logprobs[i]]
if req.prefill_top_logprobs is None:
req.prefill_top_logprobs = prefill_top_logprobs[i]
if req.logprob_start_len == 0:
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
if req.last_update_decode_tokens != 0:
req.decode_top_logprobs.extend(
prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
)
req.decode_top_logprobs.append(decode_top_logprobs[i])
pt += req.extend_input_len
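When a retracted or jump-forwarded request is prefilled again, its last last_update_decode_tokens positions are really decode tokens whose logprobs were invalidated, and the slicing above recovers them from the prefill logprobs; the toy index check below (with made-up lengths) only shows which slice is taken.

# Toy index check for the incremental-logprob slicing above (made-up lengths).
pt = 0
extend_input_len = 10
last_update_decode_tokens = 4

prefill_token_logprobs = [f"lp{j}" for j in range(extend_input_len - 1)]  # 9 entries
tail = prefill_token_logprobs[
    pt + extend_input_len - last_update_decode_tokens : pt + extend_input_len - 1
]
# the slice covers the re-scored decode tokens except the newly sampled one,
# which is appended separately from last_token_logprobs
assert tail == ["lp6", "lp7", "lp8"]
assert len(tail) == last_update_decode_tokens - 1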
@@ -484,7 +531,7 @@ class ModelRpcServer:
# check if decode out of memory
if not batch.check_decode_mem():
old_ratio = self.new_token_ratio
self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
retracted_reqs = batch.retract_decode()
logger.info(
@@ -495,26 +542,13 @@ class ModelRpcServer:
self.forward_queue.extend(retracted_reqs)
else:
self.new_token_ratio = max(
self.new_token_ratio - self.new_token_ratio_step[0],
self.new_token_ratio - self.new_token_ratio_decay,
self.min_new_token_ratio,
)
if not self.disable_regex_jump_forward:
# check for jump-forward
jump_forward_reqs = batch.check_for_jump_forward()
# check for image jump-forward
for req in jump_forward_reqs:
if req.pixel_values is not None:
(
req.input_ids,
req.image_offset,
) = self.model_runner.model.pad_input_ids(
req.input_ids,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)
jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
self.forward_queue.extend(jump_forward_reqs)
if batch.is_empty():
@@ -557,8 +591,8 @@ class ModelRpcServer:
def handle_finished_requests(self, batch: Batch):
output_rids = []
prev_output_strs = []
output_tokens = []
output_and_jump_forward_strs = []
output_hit_stop_str = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
@@ -582,8 +616,8 @@ class ModelRpcServer:
)
):
output_rids.append(req.rid)
prev_output_strs.append(req.prev_output_str)
output_tokens.append(req.output_ids)
output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
output_hit_stop_str.append(req.hit_stop_str)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
@@ -593,10 +627,8 @@ class ModelRpcServer:
)
meta_info = {
"prompt_tokens": req.prompt_tokens,
"completion_tokens": len(req.input_ids)
+ len(req.output_ids)
- req.prompt_tokens,
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"finish_reason": FinishReason.to_str(req.finish_reason),
"hit_stop_str": req.hit_stop_str,
@@ -623,8 +655,8 @@ class ModelRpcServer:
self.out_pyobjs.append(
BatchTokenIDOut(
output_rids,
prev_output_strs,
output_tokens,
output_and_jump_forward_strs,
output_hit_stop_str,
output_skip_special_tokens,
output_spaces_between_special_tokens,