Unverified Commit 963175d5 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][grpc] Support streaming for v1/chat/completions (#11179)

parent 0618ad6d
......@@ -578,7 +578,7 @@ class GrpcRequestManager:
batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
),
"finish_reason": (
str(batch_out.finished_reasons[i])
batch_out.finished_reasons[i]
if batch_out.finished_reasons[i]
else None
),
......
......@@ -112,7 +112,6 @@ def _launch_scheduler_process_only(
pp_rank,
None,
writer,
None,
),
)
......@@ -583,6 +582,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
index=output.get("index", 0),
),
)
......@@ -640,6 +640,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
cached_tokens=meta_info.get("cached_tokens", 0),
output_logprobs=output_logprobs_proto,
input_logprobs=input_logprobs_proto,
index=output.get("index", 0),
**matched_stop_kwargs,
),
)
......
......@@ -179,6 +179,9 @@ message GenerateStreamChunk {
// Input logprobs (if requested) - only in first chunk
InputLogProbs input_logprobs = 7;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 8;
}
message GenerateComplete {
......@@ -207,6 +210,9 @@ message GenerateComplete {
// Input logprobs if requested (for prompt tokens)
InputLogProbs input_logprobs = 10;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 11;
}
message GenerateError {
......
......@@ -29,7 +29,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__
from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe1\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokens\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe2\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x0e\n\x06stream\x18\x11 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\x86\x02\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12<\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\"\x8c\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12<\n\x0einput_logprobs\x18\n \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbsB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"u\n\x0eOutputLogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"\x9e\x01\n\rInputLogProbs\x12@\n\x0etoken_logprobs\x18\x01 \x03(\x0b\x32(.sglang.grpc.scheduler.InputTokenLogProb\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"1\n\x11InputTokenLogProb\x12\x12\n\x05value\x18\x01 \x01(\x02H\x00\x88\x01\x01\x42\x08\n\x06_value\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto\"\xe1\x05\n\x0eSamplingParams\x12\x13\n\x0btemperature\x18\x01 \x01(\x02\x12\r\n\x05top_p\x18\x02 \x01(\x02\x12\r\n\x05top_k\x18\x03 \x01(\x05\x12\r\n\x05min_p\x18\x04 \x01(\x02\x12\x19\n\x11\x66requency_penalty\x18\x05 \x01(\x02\x12\x18\n\x10presence_penalty\x18\x06 \x01(\x02\x12\x1a\n\x12repetition_penalty\x18\x07 \x01(\x02\x12\x1b\n\x0emax_new_tokens\x18\x08 \x01(\x05H\x01\x88\x01\x01\x12\x0c\n\x04stop\x18\t \x03(\t\x12\x16\n\x0estop_token_ids\x18\n \x03(\r\x12\x1b\n\x13skip_special_tokens\x18\x0b \x01(\x08\x12%\n\x1dspaces_between_special_tokens\x18\x0c \x01(\x08\x12\x0f\n\x05regex\x18\r \x01(\tH\x00\x12\x15\n\x0bjson_schema\x18\x0e \x01(\tH\x00\x12\x16\n\x0c\x65\x62nf_grammar\x18\x0f \x01(\tH\x00\x12\x18\n\x0estructural_tag\x18\x10 \x01(\tH\x00\x12\x11\n\tlora_path\x18\x11 \x01(\t\x12\t\n\x01n\x18\x12 \x01(\x05\x12\x15\n\rtoken_healing\x18\x13 \x01(\x08\x12\x16\n\x0emin_new_tokens\x18\x14 \x01(\x05\x12\x12\n\nignore_eos\x18\x15 \x01(\x08\x12\x14\n\x0cno_stop_trim\x18\x16 \x01(\x08\x12\x17\n\x0fstream_interval\x18\x17 \x01(\x05\x12H\n\nlogit_bias\x18\x18 \x03(\x0b\x32\x34.sglang.grpc.scheduler.SamplingParams.LogitBiasEntry\x12.\n\rcustom_params\x18\x19 \x01(\x0b\x32\x17.google.protobuf.Struct\x1a\x30\n\x0eLogitBiasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02:\x02\x38\x01\x42\x0c\n\nconstraintB\x11\n\x0f_max_new_tokens\"]\n\x13\x44isaggregatedParams\x12\x16\n\x0e\x62ootstrap_host\x18\x01 \x01(\t\x12\x16\n\x0e\x62ootstrap_port\x18\x02 \x01(\x05\x12\x16\n\x0e\x62ootstrap_room\x18\x03 \x01(\x05\"\xe2\x04\n\x0fGenerateRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x04 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x16\n\x0ereturn_logprob\x18\x05 \x01(\x08\x12\x19\n\x11logprob_start_len\x18\x06 \x01(\x05\x12\x18\n\x10top_logprobs_num\x18\x07 \x01(\x05\x12\x19\n\x11token_ids_logprob\x18\x08 \x03(\r\x12\x1c\n\x14return_hidden_states\x18\t \x01(\x08\x12H\n\x14\x64isaggregated_params\x18\n \x01(\x0b\x32*.sglang.grpc.scheduler.DisaggregatedParams\x12\x1e\n\x16\x63ustom_logit_processor\x18\x0b \x01(\t\x12-\n\ttimestamp\x18\x0c \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x13\n\x0blog_metrics\x18\r \x01(\x08\x12\x14\n\x0cinput_embeds\x18\x0e \x03(\x02\x12\x0f\n\x07lora_id\x18\x0f \x01(\t\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x10 \x01(\x05\x12\x0e\n\x06stream\x18\x11 \x01(\x08\":\n\x0eTokenizedInput\x12\x15\n\roriginal_text\x18\x01 \x01(\t\x12\x11\n\tinput_ids\x18\x02 \x03(\r\"\xd3\x01\n\x10MultimodalInputs\x12\x12\n\nimage_urls\x18\x01 \x03(\t\x12\x12\n\nvideo_urls\x18\x02 \x03(\t\x12\x12\n\naudio_urls\x18\x03 \x03(\t\x12\x33\n\x12processed_features\x18\x04 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\x12\n\nimage_data\x18\x05 \x03(\x0c\x12\x12\n\nvideo_data\x18\x06 \x03(\x0c\x12\x12\n\naudio_data\x18\x07 \x03(\x0c\x12\x12\n\nmodalities\x18\x08 \x03(\t\"\xe3\x01\n\x10GenerateResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12;\n\x05\x63hunk\x18\x02 \x01(\x0b\x32*.sglang.grpc.scheduler.GenerateStreamChunkH\x00\x12;\n\x08\x63omplete\x18\x03 \x01(\x0b\x32\'.sglang.grpc.scheduler.GenerateCompleteH\x00\x12\x35\n\x05\x65rror\x18\x04 \x01(\x0b\x32$.sglang.grpc.scheduler.GenerateErrorH\x00\x42\n\n\x08response\"\x95\x02\n\x13GenerateStreamChunk\x12\x11\n\ttoken_ids\x18\x01 \x03(\r\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x03 \x01(\x05\x12\x15\n\rcached_tokens\x18\x04 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12\x15\n\rhidden_states\x18\x06 \x03(\x02\x12<\n\x0einput_logprobs\x18\x07 \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x08 \x01(\r\"\x9b\x03\n\x10GenerateComplete\x12\x12\n\noutput_ids\x18\x01 \x03(\r\x12\x15\n\rfinish_reason\x18\x02 \x01(\t\x12\x15\n\rprompt_tokens\x18\x03 \x01(\x05\x12\x19\n\x11\x63ompletion_tokens\x18\x04 \x01(\x05\x12\x15\n\rcached_tokens\x18\x05 \x01(\x05\x12>\n\x0foutput_logprobs\x18\x06 \x01(\x0b\x32%.sglang.grpc.scheduler.OutputLogProbs\x12>\n\x11\x61ll_hidden_states\x18\x07 \x03(\x0b\x32#.sglang.grpc.scheduler.HiddenStates\x12\x1a\n\x10matched_token_id\x18\x08 \x01(\rH\x00\x12\x1a\n\x10matched_stop_str\x18\t \x01(\tH\x00\x12<\n\x0einput_logprobs\x18\n \x01(\x0b\x32$.sglang.grpc.scheduler.InputLogProbs\x12\r\n\x05index\x18\x0b \x01(\rB\x0e\n\x0cmatched_stop\"K\n\rGenerateError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x18\n\x10http_status_code\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"u\n\x0eOutputLogProbs\x12\x16\n\x0etoken_logprobs\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"\x9e\x01\n\rInputLogProbs\x12@\n\x0etoken_logprobs\x18\x01 \x03(\x0b\x32(.sglang.grpc.scheduler.InputTokenLogProb\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\x12\x38\n\x0ctop_logprobs\x18\x03 \x03(\x0b\x32\".sglang.grpc.scheduler.TopLogProbs\"1\n\x11InputTokenLogProb\x12\x12\n\x05value\x18\x01 \x01(\x02H\x00\x88\x01\x01\x42\x08\n\x06_value\"0\n\x0bTopLogProbs\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\x11\n\ttoken_ids\x18\x02 \x03(\x05\"?\n\x0cHiddenStates\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05layer\x18\x02 \x01(\x05\x12\x10\n\x08position\x18\x03 \x01(\x05\"\xca\x02\n\x0c\x45mbedRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\ttokenized\x18\x02 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\x12:\n\tmm_inputs\x18\x04 \x01(\x0b\x32\'.sglang.grpc.scheduler.MultimodalInputs\x12>\n\x0fsampling_params\x18\x05 \x01(\x0b\x32%.sglang.grpc.scheduler.SamplingParams\x12\x13\n\x0blog_metrics\x18\x06 \x01(\x08\x12\x16\n\x0etoken_type_ids\x18\x07 \x03(\x05\x12\x1a\n\x12\x64\x61ta_parallel_rank\x18\x08 \x01(\x05\x12\x18\n\x10is_cross_encoder\x18\t \x01(\x08\x12\r\n\x05texts\x18\n \x03(\t\"\x9d\x01\n\rEmbedResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x38\n\x08\x63omplete\x18\x02 \x01(\x0b\x32$.sglang.grpc.scheduler.EmbedCompleteH\x00\x12\x32\n\x05\x65rror\x18\x03 \x01(\x0b\x32!.sglang.grpc.scheduler.EmbedErrorH\x00\x42\n\n\x08response\"\xa3\x01\n\rEmbedComplete\x12\x11\n\tembedding\x18\x01 \x03(\x02\x12\x15\n\rprompt_tokens\x18\x02 \x01(\x05\x12\x15\n\rcached_tokens\x18\x03 \x01(\x05\x12\x15\n\rembedding_dim\x18\x04 \x01(\x05\x12:\n\x10\x62\x61tch_embeddings\x18\x05 \x03(\x0b\x32 .sglang.grpc.scheduler.Embedding\"*\n\tEmbedding\x12\x0e\n\x06values\x18\x01 \x03(\x02\x12\r\n\x05index\x18\x02 \x01(\x05\"<\n\nEmbedError\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x0c\n\x04\x63ode\x18\x02 \x01(\t\x12\x0f\n\x07\x64\x65tails\x18\x03 \x01(\t\"N\n\x12HealthCheckRequest\x12\x38\n\ttokenized\x18\x01 \x01(\x0b\x32%.sglang.grpc.scheduler.TokenizedInput\"7\n\x13HealthCheckResponse\x12\x0f\n\x07healthy\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0c\x41\x62ortRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12\x0e\n\x06reason\x18\x02 \x01(\t\"1\n\rAbortResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"I\n\x0fLoadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\x12\x14\n\x0c\x61\x64\x61pter_path\x18\x02 \x01(\t\x12\x0c\n\x04rank\x18\x03 \x01(\x05\"H\n\x10LoadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x12\n\nadapter_id\x18\x02 \x01(\t\x12\x0f\n\x07message\x18\x03 \x01(\t\"\'\n\x11UnloadLoRARequest\x12\x12\n\nadapter_id\x18\x01 \x01(\t\"6\n\x12UnloadLoRAResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"w\n\x14UpdateWeightsRequest\x12\x13\n\tdisk_path\x18\x01 \x01(\tH\x00\x12\x15\n\x0btensor_data\x18\x02 \x01(\x0cH\x00\x12\x14\n\nremote_url\x18\x03 \x01(\tH\x00\x12\x13\n\x0bweight_name\x18\x04 \x01(\tB\x08\n\x06source\"9\n\x15UpdateWeightsResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\"-\n\x17GetInternalStateRequest\x12\x12\n\nstate_keys\x18\x01 \x03(\t\"B\n\x18GetInternalStateResponse\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"A\n\x17SetInternalStateRequest\x12&\n\x05state\x18\x01 \x01(\x0b\x32\x17.google.protobuf.Struct\"<\n\x18SetInternalStateResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t2\xfe\x02\n\x0fSglangScheduler\x12]\n\x08Generate\x12&.sglang.grpc.scheduler.GenerateRequest\x1a\'.sglang.grpc.scheduler.GenerateResponse0\x01\x12R\n\x05\x45mbed\x12#.sglang.grpc.scheduler.EmbedRequest\x1a$.sglang.grpc.scheduler.EmbedResponse\x12\x64\n\x0bHealthCheck\x12).sglang.grpc.scheduler.HealthCheckRequest\x1a*.sglang.grpc.scheduler.HealthCheckResponse\x12R\n\x05\x41\x62ort\x12#.sglang.grpc.scheduler.AbortRequest\x1a$.sglang.grpc.scheduler.AbortResponseb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
......@@ -53,59 +53,59 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals['_GENERATERESPONSE']._serialized_start=1835
_globals['_GENERATERESPONSE']._serialized_end=2062
_globals['_GENERATESTREAMCHUNK']._serialized_start=2065
_globals['_GENERATESTREAMCHUNK']._serialized_end=2327
_globals['_GENERATECOMPLETE']._serialized_start=2330
_globals['_GENERATECOMPLETE']._serialized_end=2726
_globals['_GENERATEERROR']._serialized_start=2728
_globals['_GENERATEERROR']._serialized_end=2803
_globals['_OUTPUTLOGPROBS']._serialized_start=2805
_globals['_OUTPUTLOGPROBS']._serialized_end=2922
_globals['_INPUTLOGPROBS']._serialized_start=2925
_globals['_INPUTLOGPROBS']._serialized_end=3083
_globals['_INPUTTOKENLOGPROB']._serialized_start=3085
_globals['_INPUTTOKENLOGPROB']._serialized_end=3134
_globals['_TOPLOGPROBS']._serialized_start=3136
_globals['_TOPLOGPROBS']._serialized_end=3184
_globals['_HIDDENSTATES']._serialized_start=3186
_globals['_HIDDENSTATES']._serialized_end=3249
_globals['_EMBEDREQUEST']._serialized_start=3252
_globals['_EMBEDREQUEST']._serialized_end=3582
_globals['_EMBEDRESPONSE']._serialized_start=3585
_globals['_EMBEDRESPONSE']._serialized_end=3742
_globals['_EMBEDCOMPLETE']._serialized_start=3745
_globals['_EMBEDCOMPLETE']._serialized_end=3908
_globals['_EMBEDDING']._serialized_start=3910
_globals['_EMBEDDING']._serialized_end=3952
_globals['_EMBEDERROR']._serialized_start=3954
_globals['_EMBEDERROR']._serialized_end=4014
_globals['_HEALTHCHECKREQUEST']._serialized_start=4016
_globals['_HEALTHCHECKREQUEST']._serialized_end=4094
_globals['_HEALTHCHECKRESPONSE']._serialized_start=4096
_globals['_HEALTHCHECKRESPONSE']._serialized_end=4151
_globals['_ABORTREQUEST']._serialized_start=4153
_globals['_ABORTREQUEST']._serialized_end=4203
_globals['_ABORTRESPONSE']._serialized_start=4205
_globals['_ABORTRESPONSE']._serialized_end=4254
_globals['_LOADLORAREQUEST']._serialized_start=4256
_globals['_LOADLORAREQUEST']._serialized_end=4329
_globals['_LOADLORARESPONSE']._serialized_start=4331
_globals['_LOADLORARESPONSE']._serialized_end=4403
_globals['_UNLOADLORAREQUEST']._serialized_start=4405
_globals['_UNLOADLORAREQUEST']._serialized_end=4444
_globals['_UNLOADLORARESPONSE']._serialized_start=4446
_globals['_UNLOADLORARESPONSE']._serialized_end=4500
_globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4502
_globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4621
_globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4623
_globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4680
_globals['_GETINTERNALSTATEREQUEST']._serialized_start=4682
_globals['_GETINTERNALSTATEREQUEST']._serialized_end=4727
_globals['_GETINTERNALSTATERESPONSE']._serialized_start=4729
_globals['_GETINTERNALSTATERESPONSE']._serialized_end=4795
_globals['_SETINTERNALSTATEREQUEST']._serialized_start=4797
_globals['_SETINTERNALSTATEREQUEST']._serialized_end=4862
_globals['_SETINTERNALSTATERESPONSE']._serialized_start=4864
_globals['_SETINTERNALSTATERESPONSE']._serialized_end=4924
_globals['_SGLANGSCHEDULER']._serialized_start=4927
_globals['_SGLANGSCHEDULER']._serialized_end=5309
_globals['_GENERATESTREAMCHUNK']._serialized_end=2342
_globals['_GENERATECOMPLETE']._serialized_start=2345
_globals['_GENERATECOMPLETE']._serialized_end=2756
_globals['_GENERATEERROR']._serialized_start=2758
_globals['_GENERATEERROR']._serialized_end=2833
_globals['_OUTPUTLOGPROBS']._serialized_start=2835
_globals['_OUTPUTLOGPROBS']._serialized_end=2952
_globals['_INPUTLOGPROBS']._serialized_start=2955
_globals['_INPUTLOGPROBS']._serialized_end=3113
_globals['_INPUTTOKENLOGPROB']._serialized_start=3115
_globals['_INPUTTOKENLOGPROB']._serialized_end=3164
_globals['_TOPLOGPROBS']._serialized_start=3166
_globals['_TOPLOGPROBS']._serialized_end=3214
_globals['_HIDDENSTATES']._serialized_start=3216
_globals['_HIDDENSTATES']._serialized_end=3279
_globals['_EMBEDREQUEST']._serialized_start=3282
_globals['_EMBEDREQUEST']._serialized_end=3612
_globals['_EMBEDRESPONSE']._serialized_start=3615
_globals['_EMBEDRESPONSE']._serialized_end=3772
_globals['_EMBEDCOMPLETE']._serialized_start=3775
_globals['_EMBEDCOMPLETE']._serialized_end=3938
_globals['_EMBEDDING']._serialized_start=3940
_globals['_EMBEDDING']._serialized_end=3982
_globals['_EMBEDERROR']._serialized_start=3984
_globals['_EMBEDERROR']._serialized_end=4044
_globals['_HEALTHCHECKREQUEST']._serialized_start=4046
_globals['_HEALTHCHECKREQUEST']._serialized_end=4124
_globals['_HEALTHCHECKRESPONSE']._serialized_start=4126
_globals['_HEALTHCHECKRESPONSE']._serialized_end=4181
_globals['_ABORTREQUEST']._serialized_start=4183
_globals['_ABORTREQUEST']._serialized_end=4233
_globals['_ABORTRESPONSE']._serialized_start=4235
_globals['_ABORTRESPONSE']._serialized_end=4284
_globals['_LOADLORAREQUEST']._serialized_start=4286
_globals['_LOADLORAREQUEST']._serialized_end=4359
_globals['_LOADLORARESPONSE']._serialized_start=4361
_globals['_LOADLORARESPONSE']._serialized_end=4433
_globals['_UNLOADLORAREQUEST']._serialized_start=4435
_globals['_UNLOADLORAREQUEST']._serialized_end=4474
_globals['_UNLOADLORARESPONSE']._serialized_start=4476
_globals['_UNLOADLORARESPONSE']._serialized_end=4530
_globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4532
_globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4651
_globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4653
_globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4710
_globals['_GETINTERNALSTATEREQUEST']._serialized_start=4712
_globals['_GETINTERNALSTATEREQUEST']._serialized_end=4757
_globals['_GETINTERNALSTATERESPONSE']._serialized_start=4759
_globals['_GETINTERNALSTATERESPONSE']._serialized_end=4825
_globals['_SETINTERNALSTATEREQUEST']._serialized_start=4827
_globals['_SETINTERNALSTATEREQUEST']._serialized_end=4892
_globals['_SETINTERNALSTATERESPONSE']._serialized_start=4894
_globals['_SETINTERNALSTATERESPONSE']._serialized_end=4954
_globals['_SGLANGSCHEDULER']._serialized_start=4957
_globals['_SGLANGSCHEDULER']._serialized_end=5339
# @@protoc_insertion_point(module_scope)
......@@ -160,7 +160,7 @@ class GenerateResponse(_message.Message):
def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...
class GenerateStreamChunk(_message.Message):
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
__slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index")
TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
......@@ -168,6 +168,7 @@ class GenerateStreamChunk(_message.Message):
OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
INDEX_FIELD_NUMBER: _ClassVar[int]
token_ids: _containers.RepeatedScalarFieldContainer[int]
prompt_tokens: int
completion_tokens: int
......@@ -175,10 +176,11 @@ class GenerateStreamChunk(_message.Message):
output_logprobs: OutputLogProbs
hidden_states: _containers.RepeatedScalarFieldContainer[float]
input_logprobs: InputLogProbs
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
index: int
def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
class GenerateComplete(_message.Message):
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
__slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index")
OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
......@@ -189,6 +191,7 @@ class GenerateComplete(_message.Message):
MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
INDEX_FIELD_NUMBER: _ClassVar[int]
output_ids: _containers.RepeatedScalarFieldContainer[int]
finish_reason: str
prompt_tokens: int
......@@ -199,7 +202,8 @@ class GenerateComplete(_message.Message):
matched_token_id: int
matched_stop_str: str
input_logprobs: InputLogProbs
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
index: int
def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...
class GenerateError(_message.Message):
__slots__ = ("message", "http_status_code", "details")
......
......@@ -192,7 +192,6 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)),
name: None,
tool_calls: None,
function_call: None,
reasoning_content: None,
});
}
......
......@@ -179,6 +179,9 @@ message GenerateStreamChunk {
// Input logprobs (if requested) - only in first chunk
InputLogProbs input_logprobs = 7;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 8;
}
message GenerateComplete {
......@@ -207,6 +210,9 @@ message GenerateComplete {
// Input logprobs if requested (for prompt tokens)
InputLogProbs input_logprobs = 10;
// Index for ordering when n>1 (for parallel request multiplexing)
uint32 index = 11;
}
message GenerateError {
......
......@@ -72,8 +72,6 @@ pub enum ChatMessage {
name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<FunctionCallResponse>,
/// Reasoning content for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_content: Option<String>,
......@@ -140,8 +138,6 @@ pub struct ChatMessageDelta {
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCallDelta>,
/// Reasoning content delta for O1-style models (SGLang extension)
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_content: Option<String>,
......@@ -473,6 +469,8 @@ pub struct ChatStreamChoice {
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub matched_stop: Option<Value>,
}
// Completions API request types (v1/completions) - DEPRECATED but still supported
......
......@@ -44,7 +44,7 @@ graph TB
end
subgraph Factory Layer
MID --> PF[ParserFactory]
MID --> PF[ReasoningParserFactory]
PF --> REG[ParserRegistry]
REG --> PM[Pattern Matching]
PM --> PP[Parser Pool]
......@@ -93,7 +93,7 @@ graph TB
```mermaid
sequenceDiagram
participant C as Client
participant F as ParserFactory
participant F as ReasoningParserFactory
participant R as Registry
participant P as Parser Pool
participant BP as BaseParser
......@@ -206,7 +206,7 @@ classDiagram
+new() Self
}
class ParserFactory {
class ReasoningParserFactory {
-registry: ParserRegistry
+new() Self
+get_pooled(model_id: &str) PooledParser
......@@ -240,7 +240,7 @@ classDiagram
Step3Parser o-- BaseReasoningParser
BaseReasoningParser o-- ParserConfig
ParserFactory o-- ParserRegistry
ReasoningParserFactory o-- ParserRegistry
ParserRegistry o-- ReasoningParser
```
......@@ -302,7 +302,7 @@ classDiagram
- Delegate to get_pooled_parser
- Case-insensitive comparison
**ParserFactory Methods**:
**ReasoningParserFactory Methods**:
1. **`new()`**:
- Register all built-in parsers
......@@ -437,7 +437,7 @@ impl ReasoningParser for MyModelParser {
**Step 2: Register in Factory**
```rust
// In factory.rs ParserFactory::new()
// In factory.rs ReasoningParserFactory::new()
registry.register_parser("mymodel", || {
Box::new(MyModelParser::new())
});
......
......@@ -128,11 +128,11 @@ impl Default for ParserRegistry {
/// Factory for creating reasoning parsers based on model type.
#[derive(Clone)]
pub struct ParserFactory {
pub struct ReasoningParserFactory {
registry: ParserRegistry,
}
impl ParserFactory {
impl ReasoningParserFactory {
/// Create a new factory with default parsers registered.
pub fn new() -> Self {
let registry = ParserRegistry::new();
......@@ -237,7 +237,7 @@ impl ParserFactory {
}
}
impl Default for ParserFactory {
impl Default for ReasoningParserFactory {
fn default() -> Self {
Self::new()
}
......@@ -249,35 +249,35 @@ mod tests {
#[test]
fn test_factory_creates_deepseek_r1() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser = factory.create("deepseek-r1-distill").unwrap();
assert_eq!(parser.model_type(), "deepseek_r1");
}
#[test]
fn test_factory_creates_qwen3() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser = factory.create("qwen3-7b").unwrap();
assert_eq!(parser.model_type(), "qwen3");
}
#[test]
fn test_factory_creates_kimi() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser = factory.create("kimi-chat").unwrap();
assert_eq!(parser.model_type(), "kimi");
}
#[test]
fn test_factory_fallback_to_passthrough() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser = factory.create("unknown-model").unwrap();
assert_eq!(parser.model_type(), "passthrough");
}
#[test]
fn test_case_insensitive_matching() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser1 = factory.create("DeepSeek-R1").unwrap();
let parser2 = factory.create("QWEN3").unwrap();
let parser3 = factory.create("Kimi").unwrap();
......@@ -289,21 +289,21 @@ mod tests {
#[test]
fn test_step3_model() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let step3 = factory.create("step3-model").unwrap();
assert_eq!(step3.model_type(), "step3");
}
#[test]
fn test_glm45_model() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let glm45 = factory.create("glm45-v2").unwrap();
assert_eq!(glm45.model_type(), "glm45");
}
#[test]
fn test_pooled_parser_reuse() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
// Get the same parser twice - should be the same instance
let parser1 = factory.get_pooled("deepseek-r1");
......@@ -321,7 +321,7 @@ mod tests {
fn test_pooled_parser_concurrent_access() {
use std::thread;
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let parser = factory.get_pooled("deepseek-r1");
// Spawn multiple threads that use the same parser
......@@ -347,7 +347,7 @@ mod tests {
#[test]
fn test_pool_clearing() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
// Get a pooled parser
let parser1 = factory.get_pooled("deepseek-r1");
......@@ -364,7 +364,7 @@ mod tests {
#[test]
fn test_passthrough_parser_pooling() {
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
// Unknown models should get passthrough parser
let parser1 = factory.get_pooled("unknown-model-1");
......@@ -383,7 +383,7 @@ mod tests {
use std::thread;
use std::time::Instant;
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let num_threads = 100;
let requests_per_thread = 50;
let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
......@@ -527,7 +527,7 @@ mod tests {
fn test_concurrent_pool_modifications() {
use std::thread;
let factory = ParserFactory::new();
let factory = ReasoningParserFactory::new();
let mut handles = vec![];
// Thread 1: Continuously get parsers
......
......@@ -2,7 +2,7 @@ pub mod factory;
pub mod parsers;
pub mod traits;
pub use factory::{ParserFactory, ParserRegistry, PooledParser};
pub use factory::{ParserRegistry, PooledParser, ReasoningParserFactory};
pub use parsers::{
BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
QwenThinkingParser, Step3Parser,
......
......@@ -4,7 +4,7 @@ use crate::config::types::RetryConfig;
use crate::core::{WorkerRegistry, WorkerType};
use crate::metrics::RouterMetrics;
use crate::policies::PolicyRegistry;
use crate::reasoning_parser::ParserFactory;
use crate::reasoning_parser::ReasoningParserFactory;
use crate::routers::RouterTrait;
use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ToolParserFactory;
......@@ -24,7 +24,7 @@ pub struct GrpcPDRouter {
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory,
reasoning_parser_factory: ReasoningParserFactory,
tool_parser_factory: ToolParserFactory,
dp_aware: bool,
......
......@@ -7,10 +7,14 @@ use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
Json,
};
use bytes::Bytes;
use std::io;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
use crate::config::types::RetryConfig;
......@@ -21,11 +25,12 @@ use crate::policies::PolicyRegistry;
use crate::protocols::spec::ChatMessage;
use crate::protocols::spec::{
ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
CompletionRequest, EmbeddingRequest, FunctionCallResponse, GenerateRequest, RerankRequest,
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolChoice,
ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, CompletionRequest,
EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, GenerateRequest, RerankRequest,
ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice,
ToolChoiceValue, Usage,
};
use crate::reasoning_parser::ParserFactory;
use crate::reasoning_parser::{ParserResult, ReasoningParserFactory};
use crate::routers::RouterTrait;
use crate::server::AppContext;
use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
......@@ -34,7 +39,7 @@ use crate::tokenizer::stop::{
};
use crate::tokenizer::traits::Tokenizer;
use crate::tokenizer::HuggingFaceTokenizer;
use crate::tool_parser::ToolParserFactory;
use crate::tool_parser::{StreamingParseResult, ToolParserFactory};
use proto::generate_response::Response::{Chunk, Complete, Error};
use serde_json::{json, Map, Value};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
......@@ -50,12 +55,13 @@ pub struct ProcessedMessages {
}
/// gRPC router implementation for SGLang
#[derive(Clone)]
#[allow(dead_code)]
pub struct GrpcRouter {
worker_registry: Arc<WorkerRegistry>,
policy_registry: Arc<PolicyRegistry>,
tokenizer: Arc<dyn Tokenizer>,
reasoning_parser_factory: ParserFactory,
reasoning_parser_factory: ReasoningParserFactory,
tool_parser_factory: ToolParserFactory,
dp_aware: bool,
api_key: Option<String>,
......@@ -776,10 +782,11 @@ impl GrpcRouter {
}
/// Parse tool calls using model-specific parser
async fn parse_with_model_parser(
async fn parse_tool_calls(
&self,
processed_text: &str,
model: &str,
history_tool_calls_count: usize,
) -> (Option<Vec<ToolCall>>, String) {
// Get pooled parser for this model
let pooled_parser = self.tool_parser_factory.get_pooled(model);
......@@ -810,16 +817,26 @@ impl GrpcRouter {
let spec_tool_calls = parsed_tool_calls
.into_iter()
.map(|tc| ToolCall {
id: tc.id,
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: tc.function.name,
arguments: Some(
serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
),
},
.enumerate()
.map(|(index, tc)| {
// Generate ID for this tool call
let id = Self::generate_tool_call_id(
model,
&tc.function.name,
index,
history_tool_calls_count,
);
ToolCall {
id,
tool_type: "function".to_string(),
function: FunctionCallResponse {
name: tc.function.name,
arguments: Some(
serde_json::to_string(&tc.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
),
},
}
})
.collect();
(Some(spec_tool_calls), normal_text)
......@@ -920,6 +937,47 @@ impl GrpcRouter {
builder.build()
}
/// Count the number of tool calls in the request message history
/// This is used for KimiK2 format which needs globally unique indices
fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize {
request
.messages
.iter()
.filter_map(|msg| {
if let ChatMessage::Assistant { tool_calls, .. } = msg {
tool_calls.as_ref().map(|calls| calls.len())
} else {
None
}
})
.sum()
}
/// Generate a tool call ID based on model format
///
/// # Arguments
/// * `model` - Model name to determine ID format
/// * `tool_name` - Name of the tool being called
/// * `tool_index` - Index of this tool call within the current message
/// * `history_count` - Number of tool calls in previous messages
///
/// # Returns
/// A unique ID string. KimiK2 uses `functions.{name}:{global_index}`, others use `call_{uuid}`
fn generate_tool_call_id(
model: &str,
tool_name: &str,
tool_index: usize,
history_count: usize,
) -> String {
if model.to_lowercase().contains("kimi") {
// KimiK2 format: functions.{name}:{global_index}
format!("functions.{}:{}", tool_name, history_count + tool_index)
} else {
// Standard OpenAI format: call_{24-char-uuid}
format!("call_{}", &Uuid::new_v4().simple().to_string()[..24])
}
}
/// Process a chunk of tokens through the stop decoder
fn process_chunk_tokens(
stop_decoder: &mut StopSequenceDecoder,
......@@ -953,6 +1011,230 @@ impl GrpcRouter {
(chunk_text, false) // Return text and continue processing
}
/// Helper: Process reasoning content in streaming mode
/// Returns (modified_delta, optional_reasoning_chunk)
fn process_reasoning_stream(
&self,
delta: &str,
index: u32,
reasoning_parsers: &mut HashMap<
u32,
Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>,
>,
request_id: &str,
model: &str,
created: u64,
) -> (String, Option<ChatCompletionStreamResponse>) {
// Get or create parser for this index
reasoning_parsers
.entry(index)
.or_insert_with(|| self.reasoning_parser_factory.get_pooled(model));
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
let parse_result = {
let mut parser = pooled_parser.lock().unwrap();
parser.parse_reasoning_streaming_incremental(delta)
};
match parse_result {
Ok(ParserResult {
reasoning_text,
normal_text,
}) => {
let chunk = if !reasoning_text.is_empty() {
Some(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: Some(reasoning_text),
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
})
} else {
None
};
return (normal_text, chunk);
}
Err(e) => {
warn!("Reasoning parsing error: {}", e);
}
}
}
(delta.to_string(), None)
}
/// Helper: Process tool calls in streaming mode
/// Returns (should_skip_content, chunks_to_emit)
#[allow(clippy::too_many_arguments)]
async fn process_tool_calls_stream(
&self,
delta: &str,
index: u32,
tool_parsers: &mut HashMap<
u32,
Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>,
>,
has_tool_calls: &mut HashMap<u32, bool>,
tools: &[crate::protocols::spec::Tool],
request_id: &str,
model: &str,
created: u64,
history_tool_calls_count: usize,
) -> (bool, Vec<ChatCompletionStreamResponse>) {
let mut chunks = Vec::new();
// Get or create parser for this index
tool_parsers
.entry(index)
.or_insert_with(|| self.tool_parser_factory.get_pooled(model));
if let Some(pooled_parser) = tool_parsers.get(&index) {
let mut parser = pooled_parser.lock().await;
match parser.parse_incremental(delta, tools).await {
Ok(StreamingParseResult { normal_text, calls }) => {
// Emit normal text if present
if !normal_text.is_empty() {
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(normal_text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// Emit tool call chunks
for tool_call_item in calls {
has_tool_calls.insert(index, true);
let tool_call_id = if let Some(ref name) = tool_call_item.name {
Some(Self::generate_tool_call_id(
model,
name,
tool_call_item.tool_index,
history_tool_calls_count,
))
} else {
None
};
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: tool_call_id,
tool_type: if tool_call_item.name.is_some() {
Some("function".to_string())
} else {
None
},
function: Some(FunctionCallDelta {
name: tool_call_item.name,
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// If we emitted chunks, skip regular content
return (!chunks.is_empty(), chunks);
}
Err(e) => {
warn!("Tool call parsing error: {}", e);
}
}
}
(false, chunks)
}
/// Helper: Create content chunk
fn create_content_chunk(
content: String,
index: u32,
request_id: &str,
model: &str,
created: u64,
logprobs: Option<crate::protocols::spec::ChatLogProbs>,
) -> ChatCompletionStreamResponse {
ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(content),
tool_calls: None,
reasoning_content: None,
},
logprobs,
finish_reason: None,
matched_stop: None,
}],
usage: None,
}
}
/// Helper: Format response as SSE chunk
fn format_sse_chunk(response: &ChatCompletionStreamResponse) -> String {
format!(
"data: {}\n\n",
serde_json::to_string(response).unwrap_or_default()
)
}
/// Submit request and handle streaming response for chat completions route
async fn handle_streaming_chat(
&self,
......@@ -960,14 +1242,13 @@ impl GrpcRouter {
request: proto::GenerateRequest,
original_request: &ChatCompletionRequest,
) -> Response {
let mut stop_decoder = self.create_stop_decoder(
original_request.stop.as_ref(),
original_request.stop_token_ids.as_ref(),
original_request.skip_special_tokens,
original_request.no_stop_trim,
);
let request_id = request.request_id.clone();
let model = original_request.model.clone();
// Create channel for SSE streaming
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
// Process streaming tokens
// Start the gRPC stream
let mut grpc_stream = match client.generate(request).await {
Ok(stream) => stream,
Err(e) => {
......@@ -980,49 +1261,414 @@ impl GrpcRouter {
}
};
let mut decoded_text = String::new();
let stop_params = (
original_request.stop.clone(),
original_request.stop_token_ids.clone(),
original_request.skip_special_tokens,
original_request.no_stop_trim,
);
// Spawn processing task
let self_clone = self.clone();
let original_request_clone = original_request.clone();
tokio::spawn(async move {
let result = Self::process_streaming_chunks(
&self_clone,
&mut grpc_stream,
request_id,
model,
stop_params,
original_request_clone,
&tx,
)
.await;
if let Err(e) = result {
let error_chunk = format!(
"data: {}\n\n",
json!({
"error": {
"message": e,
"type": "internal_error"
}
})
);
let _ = tx.send(Ok(Bytes::from(error_chunk)));
}
// Send DONE marker
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
});
// Create response with SSE headers
let stream = UnboundedReceiverStream::new(rx);
let mut response = Response::new(Body::from_stream(stream));
*response.status_mut() = StatusCode::OK;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
.headers_mut()
.insert("Cache-Control", HeaderValue::from_static("no-cache"));
response
.headers_mut()
.insert("Connection", HeaderValue::from_static("keep-alive"));
response
}
/// Process streaming chunks and send SSE events
async fn process_streaming_chunks(
router: &GrpcRouter,
grpc_stream: &mut (impl tokio_stream::Stream<Item = Result<proto::GenerateResponse, tonic::Status>>
+ Unpin),
request_id: String,
model: String,
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
original_request: ChatCompletionRequest,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<(), String> {
// Extract request parameters
let separate_reasoning = original_request.separate_reasoning;
let tool_choice = &original_request.tool_choice;
let tools = &original_request.tools;
let history_tool_calls_count = Self::get_history_tool_calls_count(&original_request);
let stream_options = &original_request.stream_options;
// Phase 1: Initialize state tracking (per-index for n>1 support)
let mut is_firsts: HashMap<u32, bool> = HashMap::new();
let mut stream_buffers: HashMap<u32, String> = HashMap::new();
let mut finish_reasons: HashMap<u32, String> = HashMap::new();
let mut matched_stops: HashMap<u32, Option<Value>> = HashMap::new();
let mut prompt_tokens: HashMap<u32, u32> = HashMap::new();
let mut completion_tokens: HashMap<u32, u32> = HashMap::new();
let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
// Parser state (lazy initialization per index)
type PooledReasoningParser =
Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>;
let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>;
let mut tool_parsers: HashMap<u32, PooledToolParser> = HashMap::new();
let mut has_tool_calls: HashMap<u32, bool> = HashMap::new();
// Create stop decoder
let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params;
let mut stop_decoder = router.create_stop_decoder(
stop.as_ref(),
stop_token_ids.as_ref(),
skip_special_tokens,
no_stop_trim,
);
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
// Phase 2: Main streaming loop
while let Some(response) = grpc_stream.next().await {
let gen_response = match response {
Ok(resp) => resp,
Err(e) => {
error!("Stream error: {}", e);
break;
}
};
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
match gen_response.response {
Some(Chunk(chunk)) => {
// Process tokens and check if we should stop
let (chunk_text, should_stop) =
let index = chunk.index;
// Process tokens through stop decoder
let (chunk_text, _should_stop) =
Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
decoded_text.push_str(&chunk_text);
if should_stop {
break;
if chunk_text.is_empty() {
continue;
}
// Process logprobs if present
let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs {
match router.convert_proto_to_openai_logprobs(proto_logprobs) {
Ok(logprobs) => Some(logprobs),
Err(e) => {
warn!("Failed to process logprobs: {}", e);
None
}
}
} else {
None
};
// Initialize stream buffer if first time
let stream_buffer = stream_buffers.entry(index).or_default();
// Send first chunk with role
if is_firsts.get(&index).copied().unwrap_or(true) {
let first_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk))))
.map_err(|_| "Failed to send first chunk".to_string())?;
is_firsts.insert(index, false);
}
// Calculate delta
let mut delta = chunk_text;
stream_buffer.push_str(&delta);
// Reasoning content handling
if separate_reasoning {
let (normal_text, reasoning_chunk) = router.process_reasoning_stream(
&delta,
index,
&mut reasoning_parsers,
&request_id,
&model,
created,
);
if let Some(chunk) = reasoning_chunk {
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send reasoning chunk".to_string())?;
}
delta = normal_text;
}
// Tool call handling
let tool_choice_enabled =
!matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None)));
if tool_choice_enabled && tools.is_some() {
let (should_skip, tool_chunks) = router
.process_tool_calls_stream(
&delta,
index,
&mut tool_parsers,
&mut has_tool_calls,
tools.as_ref().unwrap(),
&request_id,
&model,
created,
history_tool_calls_count,
)
.await;
for chunk in tool_chunks {
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send tool call chunk".to_string())?;
}
if should_skip {
continue;
}
}
// Regular content emission
if !delta.is_empty() {
let content_chunk = Self::create_content_chunk(
delta,
index,
&request_id,
&model,
created,
choice_logprobs,
);
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk))))
.map_err(|_| "Failed to send content chunk".to_string())?;
}
continue;
}
Some(Complete(_complete)) => {
Some(Complete(complete)) => {
// Flush any remaining text
if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
if !text.is_empty() {
decoded_text.push_str(&text);
debug!("Flushed text: {}", text);
let index = complete.index;
let stream_buffer = stream_buffers.entry(index).or_default();
stream_buffer.push_str(&text);
let content_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&content_chunk)
.map_err(|e| format!("Failed to serialize content chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send flushed content".to_string())?;
}
}
// Store metadata
let index = complete.index;
prompt_tokens.insert(index, complete.prompt_tokens as u32);
completion_tokens.insert(index, complete.completion_tokens as u32);
cached_tokens.insert(index, complete.cached_tokens as u32);
finish_reasons.insert(index, complete.finish_reason.clone());
// Extract matched_stop
let matched_stop_value = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
matched_stops.insert(index, matched_stop_value);
break;
}
Some(Error(error)) => {
error!("Generation error: {}", error.message);
break;
return Err(error.message);
}
None => continue,
}
}
// TODO: Replace with proper SSE streaming response
// For now, return the complete decoded text
(StatusCode::OK, format!("Decoded text: {}", decoded_text)).into_response()
// Phase 3: Check unstreamed tool args
// Check if parsers have any remaining arguments that haven't been streamed yet
for (index, parser) in &tool_parsers {
let parser_guard = parser.lock().await;
if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() {
for tool_call_item in unstreamed_items {
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: None,
tool_type: None, // No type for argument deltas
function: Some(FunctionCallDelta {
name: None, // No name for argument deltas
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
let tool_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&tool_chunk)
.map_err(|e| format!("Failed to serialize tool chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send unstreamed tool args".to_string())?;
}
}
}
// Phase 4: Finish reason chunks
for (index, finish_reason) in finish_reasons.iter() {
let final_finish_reason =
if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" {
"tool_calls".to_string()
} else {
finish_reason.clone()
};
let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone());
let finish_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: Some(final_finish_reason),
matched_stop: matched_stop_value,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&finish_chunk)
.map_err(|e| format!("Failed to serialize finish chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send finish chunk".to_string())?;
}
// Phase 5: Usage chunk
if let Some(stream_opts) = stream_options {
if stream_opts.include_usage.unwrap_or(false) {
let total_prompt: u32 = prompt_tokens.values().sum();
let total_completion: u32 = completion_tokens.values().sum();
let usage_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![],
usage: Some(Usage {
prompt_tokens: total_prompt,
completion_tokens: total_completion,
total_tokens: total_prompt + total_completion,
completion_tokens_details: None,
}),
};
let sse_chunk = serde_json::to_string(&usage_chunk)
.map_err(|e| format!("Failed to serialize usage chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send usage chunk".to_string())?;
}
}
Ok(())
}
/// Submit request and handle non-streaming response for chat completions route
......@@ -1082,10 +1728,17 @@ impl GrpcRouter {
}
// Process each response into a ChatChoice
let history_tool_calls_count = Self::get_history_tool_calls_count(original_request);
let mut choices = Vec::new();
for (index, complete) in all_responses.iter().enumerate() {
match self
.process_single_choice(complete, index, original_request, &mut stop_decoder)
.process_single_choice(
complete,
index,
original_request,
&mut stop_decoder,
history_tool_calls_count,
)
.await
{
Ok(choice) => choices.push(choice),
......@@ -1216,11 +1869,12 @@ impl GrpcRouter {
decoded_text.push_str(&t);
}
let output_ids = complete.output_ids.clone();
let output_ids = std::mem::take(&mut complete.output_ids);
let finish_reason = std::mem::take(&mut complete.finish_reason);
// Build base meta_info using json! macro
let mut meta_info = json!({
"finish_reason": complete.finish_reason.clone(),
"finish_reason": finish_reason,
"prompt_tokens": complete.prompt_tokens,
"completion_tokens": complete.completion_tokens,
"cached_tokens": complete.cached_tokens,
......@@ -1269,9 +1923,13 @@ impl GrpcRouter {
})
.collect();
// Build ChatLogProbsContent for each token
for (i, &logprob) in proto_logprobs.token_logprobs.iter().enumerate() {
let token_text = token_texts.get(i).cloned().unwrap_or_default();
// Build ChatLogProbsContent for each token (consume iterator to avoid clones)
for (i, (&logprob, token_text)) in proto_logprobs
.token_logprobs
.iter()
.zip(token_texts.into_iter())
.enumerate()
{
let bytes = Some(token_text.as_bytes().to_vec());
// Build top_logprobs for this position
......@@ -1324,6 +1982,7 @@ impl GrpcRouter {
index: usize,
original_request: &ChatCompletionRequest,
stop_decoder: &mut StopSequenceDecoder,
history_tool_calls_count: usize,
) -> Result<ChatChoice, String> {
stop_decoder.reset();
// Decode tokens
......@@ -1401,7 +2060,11 @@ impl GrpcRouter {
self.parse_json_schema_response(&processed_text, &original_request.tool_choice);
} else {
(tool_calls, processed_text) = self
.parse_with_model_parser(&processed_text, &original_request.model)
.parse_tool_calls(
&processed_text,
&original_request.model,
history_tool_calls_count,
)
.await;
}
}
......@@ -1686,7 +2349,6 @@ mod tests {
content: Some("Assistant response".to_string()),
name: None,
tool_calls: None,
function_call: None,
reasoning_content: None,
}];
......
......@@ -15,7 +15,7 @@ use crate::{
},
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
},
reasoning_parser::ParserFactory,
reasoning_parser::ReasoningParserFactory,
routers::{router_manager::RouterManager, RouterTrait},
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
......@@ -45,7 +45,7 @@ pub struct AppContext {
pub router_config: RouterConfig,
pub rate_limiter: Arc<TokenBucket>,
pub tokenizer: Option<Arc<dyn Tokenizer>>,
pub reasoning_parser_factory: Option<ParserFactory>,
pub reasoning_parser_factory: Option<ReasoningParserFactory>,
pub tool_parser_factory: Option<ToolParserFactory>,
pub worker_registry: Arc<WorkerRegistry>,
pub policy_registry: Arc<PolicyRegistry>,
......@@ -79,7 +79,7 @@ impl AppContext {
tokenizer_factory::create_tokenizer(&tokenizer_path)
.map_err(|e| format!("Failed to create tokenizer: {e}"))?,
);
let reasoning_parser_factory = Some(ParserFactory::new());
let reasoning_parser_factory = Some(ReasoningParserFactory::new());
let tool_parser_factory = Some(ToolParserFactory::new());
(tokenizer, reasoning_parser_factory, tool_parser_factory)
......
......@@ -123,12 +123,7 @@ impl DeepSeekParser {
let arguments = serde_json::to_string(&args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());
Ok(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name.to_string(),
arguments,
......@@ -320,4 +315,8 @@ impl ToolParser for DeepSeekParser {
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text)
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}
......@@ -129,12 +129,7 @@ impl Glm4MoeParser {
let arguments_str = serde_json::to_string(&arguments)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate ID
let id = format!("glm4_call_{}", uuid::Uuid::new_v4());
Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name.to_string(),
arguments: arguments_str,
......@@ -321,4 +316,8 @@ impl ToolParser for Glm4MoeParser {
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text)
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}
......@@ -113,12 +113,7 @@ impl ToolParser for GptOssParser {
}
};
// Generate unique ID
let id = format!("gpt_oss_call_{}", uuid::Uuid::new_v4());
tools.push(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: function_name,
arguments,
......
......@@ -14,6 +14,48 @@ pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
.collect()
}
/// Get unstreamed tool call arguments
/// Returns tool call items for arguments that have been parsed but not yet streamed
/// This ensures tool calls are properly completed even if the model generates final arguments in the last chunk
pub fn get_unstreamed_args(
prev_tool_call_arr: &[Value],
streamed_args_for_tool: &[String],
) -> Option<Vec<ToolCallItem>> {
// Check if we have tool calls being tracked
if prev_tool_call_arr.is_empty() || streamed_args_for_tool.is_empty() {
return None;
}
// Get the last tool call that was being processed
let tool_index = prev_tool_call_arr.len() - 1;
if tool_index >= streamed_args_for_tool.len() {
return None;
}
// Get expected vs actual arguments
let expected_args = prev_tool_call_arr[tool_index].get("arguments")?;
let expected_str = serde_json::to_string(expected_args).ok()?;
let actual_str = &streamed_args_for_tool[tool_index];
// Check if there are remaining arguments to send
let remaining = if expected_str.starts_with(actual_str) {
&expected_str[actual_str.len()..]
} else {
return None;
};
if remaining.is_empty() {
return None;
}
// Return the remaining arguments as a ToolCallItem
Some(vec![ToolCallItem {
tool_index,
name: None, // No name for argument deltas
parameters: remaining.to_string(),
}])
}
/// Check if a buffer ends with a partial occurrence of a token
/// Returns Some(length) if there's a partial match, None otherwise
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
......
......@@ -8,7 +8,7 @@ use crate::tool_parser::{
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall},
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
/// JSON format parser for tool calls
......@@ -136,16 +136,7 @@ impl JsonParser {
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
// Generate a unique ID if not provided
let id = obj
.get("id")
.and_then(|v| v.as_str())
.map(String::from)
.unwrap_or_else(|| format!("call_{}", uuid::Uuid::new_v4()));
Ok(Some(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: name.to_string(),
arguments,
......@@ -274,4 +265,8 @@ impl ToolParser for JsonParser {
let trimmed = text.trim();
(trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#)
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}
......@@ -131,12 +131,7 @@ impl ToolParser for KimiK2Parser {
// Try to parse JSON arguments
match serde_json::from_str::<serde_json::Value>(function_args) {
Ok(_) => {
// Generate unique ID
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
tools.push(ToolCall {
id,
r#type: "function".to_string(),
function: FunctionCall {
name: func_name,
arguments: function_args.to_string(),
......@@ -339,4 +334,8 @@ impl ToolParser for KimiK2Parser {
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text) || text.contains("<|tool_call_begin|>")
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment