Unverified commit 9a16fea0, authored by Lianmin Zheng, committed by GitHub

Return logprob for choices (#87)

parent 9e037c82
@@ -69,6 +69,8 @@ state = multi_turn_question.run(
 for m in state.messages():
     print(m["role"], ":", m["content"])
+
+print(state["answer_1"])
 ```

 ### Using Local Models

@@ -99,6 +101,8 @@ state = multi_turn_question.run(
 for m in state.messages():
     print(m["role"], ":", m["content"])
+
+print(state["answer_1"])
 ```

 ### More Examples
......
@@ -9,8 +9,8 @@ class GenerateReqInput:
     image_data: Optional[Union[List[str], str]] = None
     sampling_params: Union[List[Dict], Dict] = None
     rid: Optional[Union[List[str], str]] = None
-    return_normalized_logprob: Optional[Union[List[bool], bool]] = None
-    normalized_logprob_start_len: Optional[Union[List[int], int]] = None
+    return_logprob: Optional[Union[List[bool], bool]] = None
+    logprob_start_len: Optional[Union[List[int], int]] = None
     stream: bool = False
 ```
......
"""
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
"""
import sglang as sgl
@sgl.function
def tool_use(s, question):
s += "To answer this question: " + question + ", "
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"])
def main():
# Run one case
question = "What is 5 + 5?"
state = tool_use.run(question)
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
print('-' * 50)
# Run a batch
questions = [
"What is 5 + 6?",
"Who is Michael Jordan?",
]
states = tool_use.run_batch([{"question": q} for q in questions])
for question, state in zip(questions, states):
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
print('-' * 50)
if __name__ == "__main__":
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
main()
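A small hedged addition to the example above (not part of the commit): besides `prompt_logprob`, the meta info stored by `_execute_select` (see the interpreter change further down) also carries the length-normalized score that the backend uses to pick the winning choice, so the same pattern can be used to inspect it.

```python
# Hypothetical follow-up to the single-question case above; assumes the same
# `state` object. "normalized_prompt_logprob" holds one length-normalized score
# per choice, and select() picks the choice with the highest value.
meta_info = state.get_meta_info("tool")
print("normalized logprob of choice 1", meta_info["normalized_prompt_logprob"][0])
print("normalized logprob of choice 2", meta_info["normalized_prompt_logprob"][1])
```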
@@ -209,7 +209,7 @@ class OpenAI(BaseBackend):
             prompt_tokens.append(ret_token)

         decision = choices[np.argmax(scores)]
-        return decision, scores
+        return decision, scores, scores


 def openai_completion(client, is_chat=None, prompt=None, **kwargs):
......
@@ -150,16 +150,20 @@ class RuntimeEndpoint(BaseBackend):
         data = {
             "text": [s.text_ + c for c in choices],
             "sampling_params": {"max_new_tokens": 0},
-            "return_normalized_logprob": True,
-            "normalized_logprob_start_len": prompt_len,
+            "return_logprob": True,
+            "logprob_start_len": max(prompt_len - 2, 0),
         }
         self._add_images(s, data)
         res = http_request(self.base_url + "/generate", json=data)
         assert res.status_code == 200
-        logps = [r["meta_info"]["normalized_logprob"] for r in res.json()]
+        obj = res.json()
+        normalized_prompt_logprob = [
+            r["meta_info"]["normalized_prompt_logprob"] for r in obj
+        ]
+        prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]

-        decision = choices[np.argmax(logps)]
-        return decision, logps
+        decision = choices[np.argmax(normalized_prompt_logprob)]
+        return decision, normalized_prompt_logprob, prompt_logprob

     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
         res = http_request(
......
 from dataclasses import dataclass, field
 from enum import Enum, auto
-from typing import Callable, Dict, List, Tuple, Optional
+from typing import Callable, Dict, List, Optional, Tuple


 class ChatTemplateStyle(Enum):

@@ -111,7 +111,7 @@ register_chat_template(
             "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=('<|im_end|>',)
+        stop_str=("<|im_end|>",),
     )
 )
......
@@ -80,7 +80,7 @@ def run_program_batch(
     # Run all programs
     if num_threads == "auto":
-        num_threads = max(64, multiprocessing.cpu_count() * 8)
+        num_threads = max(96, multiprocessing.cpu_count() * 16)
     num_threads = min(num_threads, len(batch_arguments))

     if num_threads == 1:

@@ -364,10 +364,16 @@ class StreamExecutor:
             self.stream_var_event[name].set()

     def _execute_select(self, expr: SglSelect):
-        decision, scores = self.backend.select(self, expr.choices, expr.temperature)
+        decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(
+            self, expr.choices, expr.temperature
+        )
         if expr.name is not None:
             name = expr.name
             self.variables[name] = decision
+            self.meta_info[name] = {
+                "normalized_prompt_logprob": normalized_prompt_logprob,
+                "prompt_logprob": prompt_logprob,
+            }
             self.variable_event[name].set()
         self.text_ += decision
......
@@ -14,7 +14,7 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()

     def forward(self, input_ids, hidden_states, weight, input_metadata):
-        if not input_metadata.return_normalized_logprob:
+        if not input_metadata.return_logprob:
             if input_metadata.forward_mode == ForwardMode.DECODE:
                 last_hidden = hidden_states
             else:

@@ -33,7 +33,7 @@ class LogitsProcessor(nn.Module):
             if self.tp_size > 1:
                 last_logits = tensor_model_parallel_all_gather(last_logits)
             last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, None
+            return last_logits, (None, None)
         else:
             assert input_metadata.forward_mode != ForwardMode.DECODE
             last_index = (

@@ -51,30 +51,23 @@ class LogitsProcessor(nn.Module):
             logits = logits[:, : self.config.vocab_size]
             all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)

-            normalized_logprobs = compute_normalized_logprobs(
-                all_logprobs,
-                input_ids,
-                input_metadata.extend_seq_lens,
-                input_metadata.extend_start_loc,
-            )
+            logprobs = all_logprobs[
+                torch.arange(all_logprobs.shape[0], device="cuda"),
+                torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+            ]
+            logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+
+            start = input_metadata.extend_start_loc.clone()
+            end = start + input_metadata.extend_seq_lens - 2
+            start.clamp_(min=0, max=logprobs.shape[0] - 1)
+            end.clamp_(min=0, max=logprobs.shape[0] - 1)
+            sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+            normalized_logprobs = sum_logp / (
+                (input_metadata.extend_seq_lens - 1).clamp(min=1)
+            )

             last_logits = logits[last_index]
-            return last_logits, normalized_logprobs
-
-
-def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
-    logprobs = all_logprobs[
-        torch.arange(all_logprobs.shape[0], device="cuda"),
-        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-    ]
-    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
-
-    start = start_loc.clone()
-    end = start + seq_lens - 2
-    start.clamp_(min=0, max=logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
-    return sum_logp / ((seq_lens - 1).clamp(min=1))
+            return last_logits, (logprobs, normalized_logprobs)


 if __name__ == "__main__":
......
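To make the inlined computation in `LogitsProcessor.forward` easier to follow, here is a minimal CPU-only sketch (toy shapes and random values, not part of the commit) of the cumulative-sum trick that turns the flattened per-token logprobs of an extend batch into one normalized prompt logprob per request.

```python
import torch

# Toy illustration: two requests of length 4 and 3 packed back to back in one
# flattened extend batch, with a made-up vocabulary of 32 tokens.
all_logprobs = torch.log_softmax(torch.randn(7, 32), dim=-1)  # [num_tokens, vocab]
input_ids = torch.randint(0, 32, (7,))
extend_seq_lens = torch.tensor([4, 3])
extend_start_loc = torch.tensor([0, 4])

# Logprob that position i assigns to the *next* input token; the last position
# gets a dummy target id 0, mirroring the padding in the real code.
logprobs = all_logprobs[
    torch.arange(all_logprobs.shape[0]),
    torch.cat([input_ids[1:], torch.tensor([0])]),
]
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)

# Per-request window [start, end]: the cumsum lets each request's sum be read
# with two lookups instead of a Python loop over requests.
start = extend_start_loc.clone()
end = start + extend_seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1)

# Sum of the (len - 1) next-token logprobs per request, then length-normalize.
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
normalized_logprobs = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(normalized_logprobs)  # one normalized prompt logprob per request
```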
@@ -11,8 +11,8 @@ class GenerateReqInput:
     image_data: Optional[Union[List[str], str]] = None
     sampling_params: Union[List[Dict], Dict] = None
     rid: Optional[Union[List[str], str]] = None
-    return_normalized_logprob: Optional[Union[List[bool], bool]] = None
-    normalized_logprob_start_len: Optional[Union[List[int], int]] = None
+    return_logprob: Optional[Union[List[bool], bool]] = None
+    logprob_start_len: Optional[Union[List[int], int]] = None
     stream: bool = False

     def post_init(self):

@@ -23,10 +23,10 @@ class GenerateReqInput:
                 self.sampling_params = {}
             if self.rid is None:
                 self.rid = uuid.uuid4().hex
-            if self.return_normalized_logprob is None:
-                self.return_normalized_logprob = False
-            if self.normalized_logprob_start_len is None:
-                self.normalized_logprob_start_len = 0
+            if self.return_logprob is None:
+                self.return_logprob = False
+            if self.logprob_start_len is None:
+                self.logprob_start_len = 0
         else:
             num = len(self.text)

@@ -45,17 +45,15 @@ class GenerateReqInput:
             else:
                 assert isinstance(self.rid, list)

-            if self.return_normalized_logprob is None:
-                self.return_normalized_logprob = [False] * num
-            elif not isinstance(self.return_normalized_logprob, list):
-                self.return_normalized_logprob = [self.return_normalized_logprob] * num
+            if self.return_logprob is None:
+                self.return_logprob = [False] * num
+            elif not isinstance(self.return_logprob, list):
+                self.return_logprob = [self.return_logprob] * num

-            if self.normalized_logprob_start_len is None:
-                self.normalized_logprob_start_len = [0] * num
-            elif not isinstance(self.normalized_logprob_start_len, list):
-                self.normalized_logprob_start_len = [
-                    self.normalized_logprob_start_len
-                ] * num
+            if self.logprob_start_len is None:
+                self.logprob_start_len = [0] * num
+            elif not isinstance(self.logprob_start_len, list):
+                self.logprob_start_len = [self.logprob_start_len] * num


 @dataclass

@@ -65,8 +63,8 @@ class TokenizedGenerateReqInput:
     pixel_values: List[float]
     image_hash: int
     sampling_params: SamplingParams
-    return_normalized_logprob: bool
-    normalized_logprob_start_len: int
+    return_logprob: bool
+    logprob_start_len: int
     stream: bool
......
@@ -28,8 +28,8 @@ class Req:
         self.pixel_values = None
         self.image_offset = 0
         self.sampling_params = None
-        self.return_normalized_logprob = False
-        self.normalized_logprob_start_len = 0
+        self.return_logprob = False
+        self.logprob_start_len = 0
         self.stream = False
         self.tokenizer = None

@@ -37,10 +37,11 @@ class Req:
         self.finish_reason = None
         self.hit_stop_str = None

-        self.adjust_input_len = 0
+        self.extend_input_len = 0
         self.prefix_indices = []
         self.last_node = None
+        self.logprob = None
         self.normalized_logprob = None

         # for constrained decoding

@@ -99,7 +100,7 @@ class Batch:
     out_cache_loc: torch.Tensor = None
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
-    return_normalized_logprob: bool = False
+    return_logprob: bool = False

     # for multimodal
     pixel_values: List[torch.Tensor] = None

@@ -119,14 +120,14 @@ class Batch:
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
-        return_normalized_logprob = any(req.return_normalized_logprob for req in reqs)
+        return_logprob = any(req.return_logprob for req in reqs)

         return cls(
             reqs=reqs,
             req_to_token_pool=req_to_token_pool,
             token_to_kv_pool=token_to_kv_pool,
             tree_cache=tree_cache,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )

     def is_empty(self):

@@ -257,7 +258,7 @@ class Batch:
             self.tree_cache.dec_ref_counter(req.last_node)
             req.prefix_indices = None
             req.last_node = None
-            req.adjust_input_len = 0
+            req.extend_input_len = 0
             req.output_ids = []

             # TODO: apply more fine-grained retraction

@@ -310,9 +311,7 @@ class Batch:
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.return_normalized_logprob = any(
-            req.return_normalized_logprob for req in self.reqs
-        )
+        self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
             "temperatures",

@@ -336,9 +335,7 @@ class Batch:
             [self.position_ids_offsets, other.position_ids_offsets]
         )
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
-        self.return_normalized_logprob = any(
-            req.return_normalized_logprob for req in self.reqs
-        )
+        self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
             "temperatures",
......
@@ -214,8 +214,8 @@ class ModelRpcServer(rpyc.Service):
                 req.input_ids, pad_value
             )
         req.sampling_params = recv_req.sampling_params
-        req.return_normalized_logprob = recv_req.return_normalized_logprob
-        req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len
+        req.return_logprob = recv_req.return_logprob
+        req.logprob_start_len = recv_req.logprob_start_len
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer

@@ -240,9 +240,9 @@ class ModelRpcServer(rpyc.Service):
         for req in self.forward_queue:
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
-            if req.return_normalized_logprob:
-                prefix_indices = prefix_indices[: req.normalized_logprob_start_len]
-            req.adjust_input_len = len(req.input_ids) - len(prefix_indices)
+            if req.return_logprob:
+                prefix_indices = prefix_indices[: req.logprob_start_len]
+            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
             req.prefix_indices = prefix_indices
             req.last_node = last_node

@@ -267,32 +267,32 @@ class ModelRpcServer(rpyc.Service):
         )

         for req in self.forward_queue:
-            if req.return_normalized_logprob:
+            if req.return_logprob:
                 # Need at least two tokens to compute normalized logprob
-                if req.adjust_input_len < 2:
-                    delta = 2 - req.adjust_input_len
-                    req.adjust_input_len += delta
+                if req.extend_input_len < 2:
+                    delta = 2 - req.extend_input_len
+                    req.extend_input_len += delta
                     req.prefix_indices = req.prefix_indices[:-delta]
                     if req.image_offset is not None:
                         req.image_offset += delta
-            if req.adjust_input_len == 0 and req.max_new_tokens() > 0:
+            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
                 # Need at least one token to compute logits
-                req.adjust_input_len = 1
+                req.extend_input_len = 1
                 req.prefix_indices = req.prefix_indices[:-1]
                 if req.image_offset is not None:
                     req.image_offset += 1

             if (
-                req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
+                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and req.adjust_input_len + new_batch_input_tokens
+                and req.extend_input_len + new_batch_input_tokens
                 < self.max_prefill_num_token
             ):
                 delta = self.tree_cache.inc_ref_counter(req.last_node)
                 available_size += delta

                 if not (
-                    req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
+                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
                     delta = self.tree_cache.dec_ref_counter(req.last_node)

@@ -301,9 +301,9 @@ class ModelRpcServer(rpyc.Service):
                     self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
-                        req.adjust_input_len + req.max_new_tokens()
+                        req.extend_input_len + req.max_new_tokens()
                     )
-                    new_batch_input_tokens += req.adjust_input_len
+                    new_batch_input_tokens += req.extend_input_len

         if len(can_run_list) == 0:
             return None

@@ -339,27 +339,31 @@ class ModelRpcServer(rpyc.Service):
         if batch.extend_num_tokens != 0:
             # Forward
-            logits, normalized_logprobs = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_normalized_logprob
+            logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
+                batch, ForwardMode.EXTEND, batch.return_logprob
             )
             # print("extend logits", logits)
-            if normalized_logprobs is not None:
+            if logprobs is not None:
+                logprobs = logprobs.cpu().tolist()
                 normalized_logprobs = normalized_logprobs.cpu().tolist()

             next_token_ids, next_token_probs = batch.sample(logits)
             next_token_ids = next_token_ids.cpu().tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            normalized_logprobs = None
+            logprobs = normalized_logprobs = None

         # Check finish condition
         reqs = batch.reqs
-        for i in range(len(reqs)):
-            reqs[i].output_ids = [next_token_ids[i]]
-            reqs[i].check_finished()
+        pt = 0
+        for i, req in enumerate(reqs):
+            req.output_ids = [next_token_ids[i]]
+            req.check_finished()

-            if normalized_logprobs is not None:
-                reqs[i].normalized_logprob = normalized_logprobs[i]
+            if logprobs is not None:
+                req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
+                req.normalized_logprob = normalized_logprobs[i]
+                pt += req.extend_input_len

         self.handle_finished_requests(batch)

@@ -427,8 +431,9 @@ class ModelRpcServer(rpyc.Service):
                     "prompt_tokens": len(req.input_ids),
                     "completion_tokens": len(req.output_ids),
                 }
-                if req.return_normalized_logprob:
-                    meta_info["normalized_logprob"] = req.normalized_logprob
+                if req.return_logprob:
+                    meta_info["prompt_logprob"] = req.logprob
+                    meta_info["normalized_prompt_logprob"] = req.normalized_logprob
                 output_meta_info.append(meta_info)
                 output_finished.append(req.finished)
......
@@ -45,7 +45,7 @@ class InputMetadata:
     out_cache_cont_end: torch.Tensor = None

     other_kv_index: torch.Tensor = None
-    return_normalized_logprob: bool = False
+    return_logprob: bool = False

     # for flashinfer
     use_flashinfer: bool = False

@@ -127,7 +127,7 @@ class InputMetadata:
         out_cache_loc,
         out_cache_cont_start=None,
         out_cache_cont_end=None,
-        return_normalized_logprob=False,
+        return_logprob=False,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")

@@ -175,7 +175,7 @@ class InputMetadata:
             out_cache_loc=out_cache_loc,
             out_cache_cont_start=out_cache_cont_start,
             out_cache_cont_end=out_cache_cont_end,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
             other_kv_index=other_kv_index,
         )

@@ -337,7 +337,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,

@@ -348,7 +348,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(input_ids, input_metadata.positions, input_metadata)

@@ -361,7 +361,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,

@@ -372,7 +372,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(input_ids, input_metadata.positions, input_metadata)

@@ -415,7 +415,7 @@ class ModelRunner:
         prefix_lens,
         position_ids_offsets,
         out_cache_loc,
-        return_normalized_logprob,
+        return_logprob,
     ):
         input_metadata = InputMetadata.create(
             self,

@@ -426,7 +426,7 @@ class ModelRunner:
             prefix_lens=prefix_lens,
             position_ids_offsets=position_ids_offsets,
             out_cache_loc=out_cache_loc,
-            return_normalized_logprob=return_normalized_logprob,
+            return_logprob=return_logprob,
         )
         return self.model.forward(
             input_ids,

@@ -436,9 +436,7 @@ class ModelRunner:
             image_offsets,
         )

-    def forward(
-        self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False
-    ):
+    def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             kwargs = {
                 "input_ids": batch.input_ids,

@@ -450,7 +448,7 @@ class ModelRunner:
                 "position_ids_offsets": batch.position_ids_offsets,
                 "out_cache_loc": batch.out_cache_loc,
             }
-            kwargs["return_normalized_logprob"] = return_normalized_logprob
+            kwargs["return_logprob"] = return_logprob
             return self.forward_extend_multi_modal(**kwargs)
         else:
             kwargs = {

@@ -467,10 +465,10 @@ class ModelRunner:
                 kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
                 return self.forward_decode(**kwargs)
             elif forward_mode == ForwardMode.EXTEND:
-                kwargs["return_normalized_logprob"] = return_normalized_logprob
+                kwargs["return_logprob"] = return_logprob
                 return self.forward_extend(**kwargs)
             elif forward_mode == ForwardMode.PREFILL:
-                kwargs["return_normalized_logprob"] = return_normalized_logprob
+                kwargs["return_logprob"] = return_logprob
                 return self.forward_prefill(**kwargs)
             else:
                 raise ValueError(f"Invaid forward mode: {forward_mode}")
@@ -132,8 +132,8 @@ class TokenizerManager:
             pixel_values=pixel_values,
             image_hash=image_hash,
             sampling_params=sampling_params,
-            return_normalized_logprob=obj.return_normalized_logprob,
-            normalized_logprob_start_len=obj.normalized_logprob_start_len,
+            return_logprob=obj.return_logprob,
+            logprob_start_len=obj.logprob_start_len,
             stream=obj.stream,
         )
         self.send_to_router.send_pyobj(tokenized_obj)

@@ -173,8 +173,8 @@ class TokenizerManager:
                 pixel_values=pixel_values,
                 image_hash=image_hash,
                 sampling_params=sampling_params,
-                return_normalized_logprob=obj.return_normalized_logprob[i],
-                normalized_logprob_start_len=obj.normalized_logprob_start_len[i],
+                return_logprob=obj.return_logprob[i],
+                logprob_start_len=obj.logprob_start_len[i],
                 stream=obj.stream,
             )
             self.send_to_router.send_pyobj(tokenized_obj)
......
@@ -26,6 +26,8 @@ if __name__ == "__main__":
             "temperature": 0,
             "max_new_tokens": 32,
         },
+        # "return_logprob": True,
+        # "logprob_start_len": 0,
     },
 )
 print(response.json())
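For completeness, a hedged usage sketch (not part of the commit) of the two flags that are commented out above, assuming a server launched on localhost:30000 as in the example script and a single, non-batched request so the response is one object rather than a list:

```python
import requests

# Enable the new flags on a plain /generate request; the prompt text is arbitrary.
response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 8},
        "return_logprob": True,
        "logprob_start_len": 0,
    },
)
meta_info = response.json()["meta_info"]
print(meta_info["prompt_logprob"])             # per-prompt-token logprobs
print(meta_info["normalized_prompt_logprob"])  # length-normalized prompt logprob
```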