Unverified Commit f06e90c2 authored by Liangsheng Yin, committed by GitHub

Optimize retract (#440)

parent 2cea6146
# NOTE: Currently this can only be run through HTTP requests.
import json
from concurrent.futures import ThreadPoolExecutor

from json_decode import character_regex

from sglang.utils import http_request

character_names = ["Hermione Granger", "Ron Weasley", "Harry Potter"]

base_url = "http://localhost:30000"

prompt = "is a character in Harry Potter. Please fill in the following information about this character.\n"


def openai_api_request(name):
    data = {
        "model": "",
        "prompt": name + prompt,
        "temperature": 0,
        "max_tokens": 128,
        "regex": character_regex,
        "logprobs": 3,
    }
    res = http_request(base_url + "/v1/completions", json=data).json()
    # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
    #     fout.write(json.dumps(res, indent=4))

    logprobs = res["choices"][0]["logprobs"]
    usage = res["usage"]
    assert len(logprobs["token_logprobs"]) == len(logprobs["tokens"])
    assert len(logprobs["token_logprobs"]) == len(logprobs["top_logprobs"])
    assert len(logprobs["token_logprobs"]) == usage["completion_tokens"] - 1
    return res


def srt_api_request(name):
    data = {
        "text": name + prompt,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 128,
            "regex": character_regex,
        },
        "return_logprob": True,
        "logprob_start_len": 0,
        "top_logprobs_num": 3,
        "return_text_in_logprobs": True,
    }
    res = http_request(base_url + "/generate", json=data).json()
    # with open(f"json_logprobs_{name.replace(' ', '_')}_tmp.json", "w") as fout:
    #     fout.write(json.dumps(res, indent=4))

    meta_info = res["meta_info"]
    assert len(meta_info["prefill_token_logprobs"]) == len(
        meta_info["prefill_top_logprobs"]
    )
    assert len(meta_info["decode_token_logprobs"]) == len(
        meta_info["decode_top_logprobs"]
    )
    assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"]
    assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1
    return res


def pretty_print(res):
    meta_info = res["meta_info"]

    print("\n\n", "=" * 30, "Prefill", "=" * 30)
    for i in range(len(meta_info["prefill_token_logprobs"])):
        print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="")
        top_ks = (
            [str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]]
            if meta_info["prefill_top_logprobs"][i]
            else []
        )
        for top_k in top_ks:
            print(f"{top_k: <15}", end="")
        print()

    print("\n\n", "=" * 30, "Decode", "=" * 30)
    for i in range(len(meta_info["decode_token_logprobs"])):
        print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="")
        top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]]
        for top_k in top_ks:
            print(f"{top_k: <15}", end="")
        print()

    print(res["text"])


if __name__ == "__main__":
    with ThreadPoolExecutor() as executor:
        ress = executor.map(srt_api_request, character_names)
        for res in ress:
            pretty_print(res)

    openai_api_request("Hermione Granger")
@@ -28,5 +28,11 @@ class GlobalConfig:
         # Request dependency time due to network delay
         self.request_dependency_time = 0.03
 
+        # New generation token ratio estimation
+        self.base_new_token_ratio = 0.4
+        self.base_min_new_token_ratio = 0.2
+        self.new_token_ratio_decay = 0.0001
+        self.new_token_ratio_recovery = 0.05
+
 
 global_config = GlobalConfig()
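These four knobs feed the scheduler's estimate of how many more tokens each running request will still generate. As the model_rpc changes further down show, the ratio starts from base_new_token_ratio (scaled by schedule_conservativeness), decays by new_token_ratio_decay after each normal decode step down to the base_min_new_token_ratio floor, and is bumped up by new_token_ratio_recovery whenever a retraction happens. A rough sketch of that update rule, illustrative only and written against the fields above:

def update_new_token_ratio(ratio, cfg, conservativeness=1.0, retracted=False):
    # Floor derived the same way the scheduler derives it from the config above.
    min_ratio = min(cfg.base_min_new_token_ratio * conservativeness, 1.0)
    if retracted:
        # Out of KV-cache memory: estimate more tokens per running request again.
        return min(ratio + cfg.new_token_ratio_recovery, 1.0)
    # Otherwise relax the estimate a little every decode step, down to the floor.
    return max(ratio - cfg.new_token_ratio_decay, min_ratio)


if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        base_new_token_ratio=0.4,
        base_min_new_token_ratio=0.2,
        new_token_ratio_decay=0.0001,
        new_token_ratio_recovery=0.05,
    )
    ratio = min(cfg.base_new_token_ratio, 1.0)
    ratio = update_new_token_ratio(ratio, cfg)                  # normal decode step
    ratio = update_new_token_ratio(ratio, cfg, retracted=True)  # after a retraction
    print(f"ratio after one decay and one recovery: {ratio:.4f}")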
@@ -50,21 +50,22 @@ class LogitsProcessor(nn.Module):
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
             # NOTE: the GPU-CPU overhead can be reduced
-            extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
-            for i in range(len(extend_seq_lens_cpu)):
-                if extend_seq_lens_cpu[i] == 0:
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
+            for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
+                if extend_seq_len == 0:
                     prefill_top_logprobs.append([])
                     decode_top_logprobs.append([])
                     continue
                 k = input_metadata.top_logprobs_nums[i]
-                t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
+                t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
                 prefill_top_logprobs.append(
                     [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                 )
                 decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
-                pt += extend_seq_lens_cpu[i]
+                pt += extend_seq_len
 
         return prefill_top_logprobs, decode_top_logprobs
 
     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
@@ -145,7 +146,7 @@ class LogitsProcessor(nn.Module):
     )
 
 
-if __name__ == "__main__":
+def test():
     all_logprobs = torch.tensor(
         # s s s
         [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
@@ -173,3 +174,7 @@ if __name__ == "__main__":
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)
+
+
+if __name__ == "__main__":
+    test()
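For context on the loop above: all_logprobs holds one row per extend token, concatenated across the requests in the batch, so the code walks it with the running offset pt. For each request, the top-k of every position except the last is reported as a prefill top-logprob, and the last position (whose logits produce the first sampled token) is reported as a decode top-logprob. A standalone sketch of the same split, assuming a flat (sum_extend_len, vocab) tensor of per-position logprobs:

def split_top_logprobs(all_logprobs, extend_seq_lens, top_ks):
    # all_logprobs: (sum(extend_seq_lens), vocab) tensor; one row per extend token.
    prefill_top, decode_top = [], []
    pt = 0
    for extend_len, k in zip(extend_seq_lens, top_ks):
        if extend_len == 0:
            prefill_top.append([])
            decode_top.append([])
            continue
        values, indices = all_logprobs[pt : pt + extend_len].topk(k)
        vs, ps = values.tolist(), indices.tolist()
        # All positions but the last belong to the prompt (prefill) side ...
        prefill_top.append([list(zip(vs[j], ps[j])) for j in range(extend_len - 1)])
        # ... and the last position's candidates belong to the first decoded token.
        decode_top.append(list(zip(vs[-1], ps[-1])))
        pt += extend_len
    return prefill_top, decode_top


if __name__ == "__main__":
    import torch

    logits = torch.randn(5, 8).log_softmax(dim=-1)
    prefill, decode = split_top_logprobs(logits, extend_seq_lens=[3, 0, 2], top_ks=[2, 2, 3])
    print(prefill, decode, sep="\n")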
@@ -51,11 +51,6 @@ class DetokenizerManager:
 
             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
             for i in range(len(output_strs)):
-                if recv_obj.hit_stop_str[i] is not None:
-                    pos = output_strs[i].find(recv_obj.hit_stop_str[i])
-                    if pos != -1:
-                        output_strs[i] = output_strs[i][:pos]
-
                 if len(output_tokens[i]) > 0:
                     first_token = self.tokenizer.convert_ids_to_tokens(
                         int(output_tokens[i][0])
@@ -65,9 +60,12 @@ class DetokenizerManager:
                     if first_token.startswith("▁"):
                         output_strs[i] = " " + output_strs[i]
 
-                output_strs[i] = (
-                    recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
-                )
+                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+
+                if recv_obj.hit_stop_str[i] is not None:
+                    pos = output_strs[i].find(recv_obj.hit_stop_str[i])
+                    if pos != -1:
+                        output_strs[i] = output_strs[i][:pos]
 
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
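The reordering above matters for correctness: the text accumulated before the last retract or jump-forward (prev_output_strs[i]) is now prepended before the stop-string search, so a stop string that straddles the boundary between the previous text and the newly detokenized chunk is still trimmed. A small illustration of the intended order (hypothetical helper, not part of the commit):

def finalize_output(prev_output_str, new_chunk, hit_stop_str):
    # 1) stitch the previously accumulated text and the new chunk together
    text = prev_output_str + new_chunk
    # 2) only then trim at the first occurrence of the stop string
    if hit_stop_str is not None:
        pos = text.find(hit_stop_str)
        if pos != -1:
            text = text[:pos]
    return text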
......
@@ -106,8 +106,8 @@ class TokenizedGenerateReqInput:
 
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
+    prev_output_strs: List[str]
     output_tokens: List[List[int]]
-    output_and_jump_forward_strs: List[str]
     hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
......
@@ -36,15 +36,15 @@ class FinishReason(IntEnum):
 
 class Req:
-    def __init__(self, rid, input_text, input_ids):
+    def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
-        self.input_text = input_text
-        self.input_ids = input_ids
+        self.origin_input_text = origin_input_text
+        self.origin_input_ids = origin_input_ids
+        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
+        self.prev_output_str = ""
+        self.prev_output_ids = []
         self.output_ids = []
+        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
 
-        # Since jump forward may retokenize the prompt with partial outputs,
-        # we maintain the original prompt length to report the correct usage.
-        self.prompt_tokens = len(input_ids)
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
@@ -76,15 +76,24 @@ class Req:
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
         self.prefill_token_logprobs = None
-        self.decode_token_logprobs = None
+        self.decode_token_logprobs = []
         self.prefill_top_logprobs = None
-        self.decode_top_logprobs = None
+        self.decode_top_logprobs = []
+        # These tokens were prefilled but need to be treated as decode tokens,
+        # so the decode logprobs should be updated for them.
+        self.last_update_decode_tokens = 0
 
         # Constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.jump_forward_map = None
-        self.output_and_jump_forward_str = ""
+
+    def partial_decode(self, ids):
+        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
+        first_token = (
+            first_token.decode() if isinstance(first_token, bytes) else first_token
+        )
+        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
 
     def max_new_tokens(self):
         return self.sampling_params.max_new_tokens
@@ -93,7 +102,10 @@ class Req:
         if self.finished:
             return
 
-        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+        if (
+            len(self.prev_output_ids) + len(self.output_ids)
+            >= self.sampling_params.max_new_tokens
+        ):
             self.finished = True
             self.finish_reason = FinishReason.LENGTH
             return
@@ -112,60 +124,66 @@ class Req:
             )
             for stop_str in self.sampling_params.stop_strs:
-                if stop_str in tail_str:
+                # FIXME: (minor) try incremental match in prev_output_str
+                if stop_str in tail_str or stop_str in self.prev_output_str:
                     self.finished = True
                     self.finish_reason = FinishReason.STOP_STR
                     self.hit_stop_str = stop_str
                     return
 
     def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        old_output_str = self.tokenizer.decode(self.output_ids)
         # FIXME: This logic does not really solve the problem of determining whether
         # there should be a leading space.
-        first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
-        first_token = (
-            first_token.decode() if isinstance(first_token, bytes) else first_token
-        )
-        if first_token.startswith("▁"):
-            old_output_str = " " + old_output_str
-        if self.input_text is None:
-            # TODO(lmzheng): This can be wrong. Check with Liangsheng.
-            self.input_text = self.tokenizer.decode(self.input_ids)
-        new_input_string = (
-            self.input_text
-            + self.output_and_jump_forward_str
-            + old_output_str
-            + jump_forward_str
-        )
-        new_input_ids = self.tokenizer.encode(new_input_string)
-        if self.pixel_values is not None:
-            # NOTE: This is a hack because the old input_ids contains the image padding
-            jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
-        else:
-            jump_forward_tokens_len = (
-                len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
-            )
+        cur_output_str = self.partial_decode(self.output_ids)
+
+        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
+        if self.origin_input_text is None:
+            # Recovering text can only use unpadded ids
+            self.origin_input_text = self.tokenizer.decode(
+                self.origin_input_ids_unpadded
+            )
+
+        all_text = (
+            self.origin_input_text
+            + self.prev_output_str
+            + cur_output_str
+            + jump_forward_str
+        )
+        all_ids = self.tokenizer.encode(all_text)
+        prompt_tokens = len(self.origin_input_ids_unpadded)
+        self.origin_input_ids = all_ids[:prompt_tokens]
+        self.origin_input_ids_unpadded = self.origin_input_ids
+        # NOTE: the output ids may not strictly correspond to the output text
+        old_prev_output_ids = self.prev_output_ids
+        self.prev_output_ids = all_ids[prompt_tokens:]
+        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
+        self.output_ids = []
+        self.regex_fsm_state = next_state
+
+        if self.return_logprob:
+            # For fast-forward part's logprobs
+            k = 0
+            for i, old_id in enumerate(old_prev_output_ids):
+                if old_id == self.prev_output_ids[i]:
+                    k = k + 1
+                else:
+                    break
+            self.decode_token_logprobs = self.decode_token_logprobs[:k]
+            self.decode_top_logprobs = self.decode_top_logprobs[:k]
+            self.logprob_start_len = prompt_tokens + k
+            self.last_update_decode_tokens = len(self.prev_output_ids) - k
 
         # print("=" * 100)
         # print(f"Catch jump forward:\n{jump_forward_str}")
         # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
         # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
-        self.input_ids = new_input_ids
-        self.output_ids = []
-        self.sampling_params.max_new_tokens = max(
-            self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
-        )
-        self.regex_fsm_state = next_state
-        self.output_and_jump_forward_str = (
-            self.output_and_jump_forward_str + old_output_str + jump_forward_str
-        )
         # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
         # print("*" * 100)
 
     def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
+        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
 
 
 @dataclass
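The jump_forward_and_retokenize rewrite above re-encodes origin_input_text + prev_output_str + cur_output_str + jump_forward_str in one pass and then keeps only the logprobs that are still valid: it finds the longest common prefix k between the old and the re-tokenized output ids, truncates decode_token_logprobs / decode_top_logprobs to k, and marks the remaining len(prev_output_ids) - k positions (via last_update_decode_tokens) to be refilled from the next prefill pass. A small sketch of the prefix computation it relies on (hypothetical helper; the commit inlines this loop):

def common_prefix_len(old_ids, new_ids):
    # Number of leading token ids that survived re-tokenization unchanged;
    # logprobs up to this point stay valid, everything after must be recomputed.
    k = 0
    for a, b in zip(old_ids, new_ids):
        if a != b:
            break
        k += 1
    return k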
@@ -336,6 +354,7 @@ class Batch:
 
     def retract_decode(self):
         sorted_indices = [i for i in range(len(self.reqs))]
+        # TODO(lsyin): improve the priority of retraction
        sorted_indices.sort(
             key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
             reverse=True,
@@ -356,18 +375,27 @@ class Batch:
             ][last_uncached_pos : seq_lens_cpu[idx]]
             self.token_to_kv_pool.dec_refs(token_indices)
 
+            # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)
+
+            cur_output_str = req.partial_decode(req.output_ids)
+            req.prev_output_str = req.prev_output_str + cur_output_str
+            req.prev_output_ids.extend(req.output_ids)
+
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
             req.output_ids = []
+            req.regex_fsm_state = 0
+
+            # For incremental logprobs
+            req.last_update_decode_tokens = 0
+            req.logprob_start_len = 10**9
 
         self.filter_batch(sorted_indices)
 
         return retracted_reqs
 
-    def check_for_jump_forward(self):
+    def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
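Retraction now folds everything a request has decoded so far into prev_output_str / prev_output_ids, releases its KV-cache references, and resets the per-request bookkeeping so the request can later be re-prefilled from origin_input_ids + prev_output_ids; the logprob_start_len = 10**9 sentinel appears to keep the full prefix reusable and avoid re-reporting prompt logprobs on that re-prefill. A condensed sketch of the per-request state reset (illustrative only; cache bookkeeping omitted, and req is any object with these attributes):

def reset_request_after_retract(req):
    # Fold the decoded-so-far tokens into the "previous output" that will be
    # re-prefilled together with the original prompt.
    req.prev_output_str += req.partial_decode(req.output_ids)
    req.prev_output_ids.extend(req.output_ids)
    req.output_ids = []
    # Scheduling and constrained-decoding state start over.
    req.prefix_indices = None
    req.last_node = None
    req.extend_input_len = 0
    req.regex_fsm_state = 0
    # Incremental logprob bookkeeping: nothing pending, and a sentinel start
    # position so prompt logprobs are not reported again.
    req.last_update_decode_tokens = 0
    req.logprob_start_len = 10**9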
@@ -397,6 +425,18 @@ class Batch:
                 # jump-forward
                 req.jump_forward_and_retokenize(jump_forward_str, next_state)
 
+                # re-applying image padding
+                if req.pixel_values is not None:
+                    (
+                        req.origin_input_ids,
+                        req.image_offset,
+                    ) = model_runner.model.pad_input_ids(
+                        req.origin_input_ids_unpadded,
+                        req.pad_value,
+                        req.pixel_values.shape,
+                        req.image_size,
+                    )
+
                 jump_forward_reqs.append(req)
                 filter_indices.remove(i)
......
@@ -4,7 +4,7 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import List, Optional
 
 import rpyc
 import torch
@@ -16,6 +16,7 @@ try:
 except ImportError:
     from vllm.logger import logger as vllm_default_logger
 
+from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -106,7 +107,8 @@ class ModelRpcServer:
         set_random_seed(server_args.random_seed)
 
         # Print info
-        logger.info(f"[rank={self.tp_rank}] "
+        logger.info(
+            f"[rank={self.tp_rank}] "
             f"max_total_num_token={self.max_total_num_token}, "
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
@@ -151,9 +153,20 @@ class ModelRpcServer:
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
-        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
-        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
-        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.new_token_ratio = min(
+            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
@@ -256,8 +269,13 @@ class ModelRpcServer:
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.image_size = recv_req.image_size
-            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
-                req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
+            req.origin_input_ids, req.image_offset = (
+                self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
+                )
             )
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
@@ -275,11 +293,11 @@ class ModelRpcServer:
         )
 
         # Truncate prompts that are too long
-        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
+        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
             req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.input_ids),
-            self.max_total_num_token - 128 - len(req.input_ids),
+            self.model_config.context_len - 1 - len(req.origin_input_ids),
+            self.max_total_num_token - 128 - len(req.origin_input_ids),
         )
         self.forward_queue.append(req)
 
@@ -292,6 +310,10 @@ class ModelRpcServer:
 
         # Compute matched prefix length
         for req in self.forward_queue:
+            assert (
+                len(req.output_ids) == 0
+            ), "The output ids should be empty when prefilling"
+            req.input_ids = req.origin_input_ids + req.prev_output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -319,7 +341,7 @@ class ModelRpcServer:
         )
 
         for req in self.forward_queue:
-            if req.return_logprob:
+            if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
                     delta = 2 - req.extend_input_len
@@ -441,8 +463,10 @@ class ModelRpcServer:
             req.check_finished()
 
             if req.return_logprob:
-                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+                if req.normalized_prompt_logprob is None:
+                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
 
+                if req.prefill_token_logprobs is None:
                     # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
                     req.prefill_token_logprobs = list(
                         zip(
@@ -454,15 +478,38 @@ class ModelRpcServer:
                     req.prefill_token_logprobs = [
                         (None, req.input_ids[0])
                     ] + req.prefill_token_logprobs
-                req.decode_token_logprobs = [
-                    (last_token_logprobs[i], next_token_ids[i])
-                ]
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_token_logprobs.extend(
+                        list(
+                            zip(
+                                prefill_token_logprobs[
+                                    pt
+                                    + req.extend_input_len
+                                    - req.last_update_decode_tokens : pt
+                                    + req.extend_input_len
+                                    - 1
+                                ],
+                                req.input_ids[-req.last_update_decode_tokens + 1 :],
+                            )
+                        )
+                    )
+
+                req.decode_token_logprobs.append(
+                    (last_token_logprobs[i], next_token_ids[i])
+                )
 
             if req.top_logprobs_num > 0:
+                if req.prefill_top_logprobs is None:
                     req.prefill_top_logprobs = prefill_top_logprobs[i]
                     if req.logprob_start_len == 0:
                         req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-                req.decode_top_logprobs = [decode_top_logprobs[i]]
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_top_logprobs.extend(
+                        prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                    )
+                req.decode_top_logprobs.append(decode_top_logprobs[i])
 
             pt += req.extend_input_len
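This is where the last_update_decode_tokens bookkeeping from the Req changes pays off: after a jump-forward or retraction, the tail of the re-prefilled extend region corresponds to tokens that were originally decoded, so their logprobs are appended to the decode-side lists instead of being reported again as prompt logprobs. A simplified sketch of the slicing, assuming the same flat prefill_token_logprobs list indexed with the running offset pt (req is any object with the fields used below):

def splice_recovered_decode_logprobs(req, prefill_token_logprobs, pt):
    # The last `n` positions of this request's extend region were decode tokens
    # before the retract/jump-forward, so move their logprobs to the decode list.
    n = req.last_update_decode_tokens
    if n == 0:
        return
    start = pt + req.extend_input_len - n
    end = pt + req.extend_input_len - 1  # the final position is appended separately
    req.decode_token_logprobs.extend(
        zip(prefill_token_logprobs[start:end], req.input_ids[-n + 1 :])
    )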
@@ -484,7 +531,7 @@ class ModelRpcServer:
         # check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
+            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
 
             retracted_reqs = batch.retract_decode()
             logger.info(
@@ -495,26 +542,13 @@ class ModelRpcServer:
             self.forward_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_step[0],
+                self.new_token_ratio - self.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )
 
         if not self.disable_regex_jump_forward:
             # check for jump-forward
-            jump_forward_reqs = batch.check_for_jump_forward()
-
-            # check for image jump-forward
-            for req in jump_forward_reqs:
-                if req.pixel_values is not None:
-                    (
-                        req.input_ids,
-                        req.image_offset,
-                    ) = self.model_runner.model.pad_input_ids(
-                        req.input_ids,
-                        req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
-                    )
+            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
 
             self.forward_queue.extend(jump_forward_reqs)
 
         if batch.is_empty():
@@ -557,8 +591,8 @@ class ModelRpcServer:
     def handle_finished_requests(self, batch: Batch):
         output_rids = []
+        prev_output_strs = []
         output_tokens = []
-        output_and_jump_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
@@ -582,8 +616,8 @@ class ModelRpcServer:
                 )
             ):
                 output_rids.append(req.rid)
+                prev_output_strs.append(req.prev_output_str)
                 output_tokens.append(req.output_ids)
-                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                 output_hit_stop_str.append(req.hit_stop_str)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
@@ -593,10 +627,8 @@ class ModelRpcServer:
                 )
 
                 meta_info = {
-                    "prompt_tokens": req.prompt_tokens,
-                    "completion_tokens": len(req.input_ids)
-                    + len(req.output_ids)
-                    - req.prompt_tokens,
+                    "prompt_tokens": len(req.origin_input_ids),
+                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                     "finish_reason": FinishReason.to_str(req.finish_reason),
                     "hit_stop_str": req.hit_stop_str,
@@ -623,8 +655,8 @@ class ModelRpcServer:
                 self.out_pyobjs.append(
                     BatchTokenIDOut(
                         output_rids,
+                        prev_output_strs,
                         output_tokens,
-                        output_and_jump_forward_strs,
                         output_hit_stop_str,
                         output_skip_special_tokens,
                         output_spaces_between_special_tokens,
......