Commit 963175d5 authored by Chang Su, committed by GitHub

[router][grpc] Support streaming for v1/chat/completions (#11179)

parent 0618ad6d
@@ -578,7 +578,7 @@ class GrpcRequestManager:
                     batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
                 ),
                 "finish_reason": (
-                    str(batch_out.finished_reasons[i])
+                    batch_out.finished_reasons[i]
                     if batch_out.finished_reasons[i]
                     else None
                 ),
...
@@ -112,7 +112,6 @@ def _launch_scheduler_process_only(
             pp_rank,
             None,
             writer,
-            None,
         ),
     )
@@ -583,6 +582,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                 cached_tokens=meta_info.get("cached_tokens", 0),
                 output_logprobs=output_logprobs_proto,
                 input_logprobs=input_logprobs_proto,
+                index=output.get("index", 0),
             ),
         )
@@ -640,6 +640,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                 cached_tokens=meta_info.get("cached_tokens", 0),
                 output_logprobs=output_logprobs_proto,
                 input_logprobs=input_logprobs_proto,
+                index=output.get("index", 0),
                 **matched_stop_kwargs,
             ),
         )
...
@@ -179,6 +179,9 @@ message GenerateStreamChunk {
   // Input logprobs (if requested) - only in first chunk
   InputLogProbs input_logprobs = 7;
+
+  // Index for ordering when n>1 (for parallel request multiplexing)
+  uint32 index = 8;
 }

 message GenerateComplete {
@@ -207,6 +210,9 @@ message GenerateComplete {
   // Input logprobs if requested (for prompt tokens)
   InputLogProbs input_logprobs = 10;
+
+  // Index for ordering when n>1 (for parallel request multiplexing)
+  uint32 index = 11;
 }

 message GenerateError {
...
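The two `index` fields above are what let a single gRPC stream interleave all `n` completions of one chat request. Below is a minimal sketch of the consumer-side demultiplexing this enables; `StreamChunk` here is a trimmed stand-in for the generated `GenerateStreamChunk` binding, not the real proto type:

```rust
use std::collections::HashMap;

// Trimmed stand-in for the generated GenerateStreamChunk type (assumption:
// only the two fields needed for demultiplexing are modeled here).
struct StreamChunk {
    index: u32,          // which of the n parallel completions this chunk belongs to
    token_ids: Vec<u32>, // incremental tokens for that completion
}

/// Fan interleaved chunks out per choice index so each completion
/// can be detokenized independently, in arrival order.
fn demux(chunks: impl IntoIterator<Item = StreamChunk>) -> HashMap<u32, Vec<u32>> {
    let mut per_choice: HashMap<u32, Vec<u32>> = HashMap::new();
    for chunk in chunks {
        per_choice.entry(chunk.index).or_default().extend(chunk.token_ids);
    }
    per_choice
}

fn main() {
    // With n=2, chunks for both choices arrive interleaved on one stream.
    let interleaved = vec![
        StreamChunk { index: 0, token_ids: vec![11, 12] },
        StreamChunk { index: 1, token_ids: vec![21] },
        StreamChunk { index: 0, token_ids: vec![13] },
    ];
    let per_choice = demux(interleaved);
    assert_eq!(per_choice[&0], vec![11, 12, 13]);
    assert_eq!(per_choice[&1], vec![21]);
}
```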
@@ -29,7 +29,7 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__
 from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2


-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto ... \x86\x02\n\x13GenerateStreamChunk ... \x8c\x03\n\x10GenerateComplete ... b\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16sglang_scheduler.proto\x12\x15sglang.grpc.scheduler\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1cgoogle/protobuf/struct.proto ... \x95\x02\n\x13GenerateStreamChunk ... \x12\r\n\x05index\x18\x08 \x01(\r ... \x9b\x03\n\x10GenerateComplete ... \x12\r\n\x05index\x18\x0b \x01(\r ... b\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -53,59 +53,59 @@ if not _descriptor._USE_C_DESCRIPTORS:
   _globals['_GENERATERESPONSE']._serialized_start=1835
   _globals['_GENERATERESPONSE']._serialized_end=2062
   _globals['_GENERATESTREAMCHUNK']._serialized_start=2065
-  _globals['_GENERATESTREAMCHUNK']._serialized_end=2327
+  _globals['_GENERATESTREAMCHUNK']._serialized_end=2342
-  _globals['_GENERATECOMPLETE']._serialized_start=2330
+  _globals['_GENERATECOMPLETE']._serialized_start=2345
-  _globals['_GENERATECOMPLETE']._serialized_end=2726
+  _globals['_GENERATECOMPLETE']._serialized_end=2756
-  _globals['_GENERATEERROR']._serialized_start=2728
+  _globals['_GENERATEERROR']._serialized_start=2758
-  _globals['_GENERATEERROR']._serialized_end=2803
+  _globals['_GENERATEERROR']._serialized_end=2833
-  _globals['_OUTPUTLOGPROBS']._serialized_start=2805
+  _globals['_OUTPUTLOGPROBS']._serialized_start=2835
-  _globals['_OUTPUTLOGPROBS']._serialized_end=2922
+  _globals['_OUTPUTLOGPROBS']._serialized_end=2952
-  _globals['_INPUTLOGPROBS']._serialized_start=2925
+  _globals['_INPUTLOGPROBS']._serialized_start=2955
-  _globals['_INPUTLOGPROBS']._serialized_end=3083
+  _globals['_INPUTLOGPROBS']._serialized_end=3113
-  _globals['_INPUTTOKENLOGPROB']._serialized_start=3085
+  _globals['_INPUTTOKENLOGPROB']._serialized_start=3115
-  _globals['_INPUTTOKENLOGPROB']._serialized_end=3134
+  _globals['_INPUTTOKENLOGPROB']._serialized_end=3164
-  _globals['_TOPLOGPROBS']._serialized_start=3136
+  _globals['_TOPLOGPROBS']._serialized_start=3166
-  _globals['_TOPLOGPROBS']._serialized_end=3184
+  _globals['_TOPLOGPROBS']._serialized_end=3214
-  _globals['_HIDDENSTATES']._serialized_start=3186
+  _globals['_HIDDENSTATES']._serialized_start=3216
-  _globals['_HIDDENSTATES']._serialized_end=3249
+  _globals['_HIDDENSTATES']._serialized_end=3279
-  _globals['_EMBEDREQUEST']._serialized_start=3252
+  _globals['_EMBEDREQUEST']._serialized_start=3282
-  _globals['_EMBEDREQUEST']._serialized_end=3582
+  _globals['_EMBEDREQUEST']._serialized_end=3612
-  _globals['_EMBEDRESPONSE']._serialized_start=3585
+  _globals['_EMBEDRESPONSE']._serialized_start=3615
-  _globals['_EMBEDRESPONSE']._serialized_end=3742
+  _globals['_EMBEDRESPONSE']._serialized_end=3772
-  _globals['_EMBEDCOMPLETE']._serialized_start=3745
+  _globals['_EMBEDCOMPLETE']._serialized_start=3775
-  _globals['_EMBEDCOMPLETE']._serialized_end=3908
+  _globals['_EMBEDCOMPLETE']._serialized_end=3938
-  _globals['_EMBEDDING']._serialized_start=3910
+  _globals['_EMBEDDING']._serialized_start=3940
-  _globals['_EMBEDDING']._serialized_end=3952
+  _globals['_EMBEDDING']._serialized_end=3982
-  _globals['_EMBEDERROR']._serialized_start=3954
+  _globals['_EMBEDERROR']._serialized_start=3984
-  _globals['_EMBEDERROR']._serialized_end=4014
+  _globals['_EMBEDERROR']._serialized_end=4044
-  _globals['_HEALTHCHECKREQUEST']._serialized_start=4016
+  _globals['_HEALTHCHECKREQUEST']._serialized_start=4046
-  _globals['_HEALTHCHECKREQUEST']._serialized_end=4094
+  _globals['_HEALTHCHECKREQUEST']._serialized_end=4124
-  _globals['_HEALTHCHECKRESPONSE']._serialized_start=4096
+  _globals['_HEALTHCHECKRESPONSE']._serialized_start=4126
-  _globals['_HEALTHCHECKRESPONSE']._serialized_end=4151
+  _globals['_HEALTHCHECKRESPONSE']._serialized_end=4181
-  _globals['_ABORTREQUEST']._serialized_start=4153
+  _globals['_ABORTREQUEST']._serialized_start=4183
-  _globals['_ABORTREQUEST']._serialized_end=4203
+  _globals['_ABORTREQUEST']._serialized_end=4233
-  _globals['_ABORTRESPONSE']._serialized_start=4205
+  _globals['_ABORTRESPONSE']._serialized_start=4235
-  _globals['_ABORTRESPONSE']._serialized_end=4254
+  _globals['_ABORTRESPONSE']._serialized_end=4284
-  _globals['_LOADLORAREQUEST']._serialized_start=4256
+  _globals['_LOADLORAREQUEST']._serialized_start=4286
-  _globals['_LOADLORAREQUEST']._serialized_end=4329
+  _globals['_LOADLORAREQUEST']._serialized_end=4359
-  _globals['_LOADLORARESPONSE']._serialized_start=4331
+  _globals['_LOADLORARESPONSE']._serialized_start=4361
-  _globals['_LOADLORARESPONSE']._serialized_end=4403
+  _globals['_LOADLORARESPONSE']._serialized_end=4433
-  _globals['_UNLOADLORAREQUEST']._serialized_start=4405
+  _globals['_UNLOADLORAREQUEST']._serialized_start=4435
-  _globals['_UNLOADLORAREQUEST']._serialized_end=4444
+  _globals['_UNLOADLORAREQUEST']._serialized_end=4474
-  _globals['_UNLOADLORARESPONSE']._serialized_start=4446
+  _globals['_UNLOADLORARESPONSE']._serialized_start=4476
-  _globals['_UNLOADLORARESPONSE']._serialized_end=4500
+  _globals['_UNLOADLORARESPONSE']._serialized_end=4530
-  _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4502
+  _globals['_UPDATEWEIGHTSREQUEST']._serialized_start=4532
-  _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4621
+  _globals['_UPDATEWEIGHTSREQUEST']._serialized_end=4651
-  _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4623
+  _globals['_UPDATEWEIGHTSRESPONSE']._serialized_start=4653
-  _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4680
+  _globals['_UPDATEWEIGHTSRESPONSE']._serialized_end=4710
-  _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4682
+  _globals['_GETINTERNALSTATEREQUEST']._serialized_start=4712
-  _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4727
+  _globals['_GETINTERNALSTATEREQUEST']._serialized_end=4757
-  _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4729
+  _globals['_GETINTERNALSTATERESPONSE']._serialized_start=4759
-  _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4795
+  _globals['_GETINTERNALSTATERESPONSE']._serialized_end=4825
-  _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4797
+  _globals['_SETINTERNALSTATEREQUEST']._serialized_start=4827
-  _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4862
+  _globals['_SETINTERNALSTATEREQUEST']._serialized_end=4892
-  _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4864
+  _globals['_SETINTERNALSTATERESPONSE']._serialized_start=4894
-  _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4924
+  _globals['_SETINTERNALSTATERESPONSE']._serialized_end=4954
-  _globals['_SGLANGSCHEDULER']._serialized_start=4927
+  _globals['_SGLANGSCHEDULER']._serialized_start=4957
-  _globals['_SGLANGSCHEDULER']._serialized_end=5309
+  _globals['_SGLANGSCHEDULER']._serialized_end=5339
 # @@protoc_insertion_point(module_scope)
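The uniform shifts above double as a sanity check on the regenerated descriptor: each new `uint32 index` field adds a 15-byte entry to the serialized file descriptor (a 2-byte tag/length prefix plus 13 bytes of payload), so `GenerateStreamChunk` grows from 2327 to 2342, `GenerateComplete` gains another 15 bytes, and every message defined after both shifts by exactly 30.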
...
@@ -160,7 +160,7 @@ class GenerateResponse(_message.Message):
     def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...

 class GenerateStreamChunk(_message.Message):
-    __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs")
+    __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index")
     TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
     COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
@@ -168,6 +168,7 @@ class GenerateStreamChunk(_message.Message):
     OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
     INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    INDEX_FIELD_NUMBER: _ClassVar[int]
     token_ids: _containers.RepeatedScalarFieldContainer[int]
     prompt_tokens: int
     completion_tokens: int
@@ -175,10 +176,11 @@ class GenerateStreamChunk(_message.Message):
     output_logprobs: OutputLogProbs
     hidden_states: _containers.RepeatedScalarFieldContainer[float]
     input_logprobs: InputLogProbs
+    index: int
-    def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
+    def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...

 class GenerateComplete(_message.Message):
-    __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs")
+    __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index")
     OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
     FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
@@ -189,6 +191,7 @@ class GenerateComplete(_message.Message):
     MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
     MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
     INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    INDEX_FIELD_NUMBER: _ClassVar[int]
     output_ids: _containers.RepeatedScalarFieldContainer[int]
     finish_reason: str
     prompt_tokens: int
@@ -199,7 +202,8 @@ class GenerateComplete(_message.Message):
     matched_token_id: int
     matched_stop_str: str
     input_logprobs: InputLogProbs
+    index: int
-    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ...) -> None: ...
+    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...

 class GenerateError(_message.Message):
     __slots__ = ("message", "http_status_code", "details")
...
@@ -192,7 +192,6 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
             content: Some(format!("Answer {}: This is a detailed response about topic {} that covers multiple aspects and provides comprehensive analysis of the interconnected systems you mentioned.", i, i)),
             name: None,
             tool_calls: None,
-            function_call: None,
             reasoning_content: None,
         });
     }
...
@@ -179,6 +179,9 @@ message GenerateStreamChunk {
   // Input logprobs (if requested) - only in first chunk
   InputLogProbs input_logprobs = 7;
+
+  // Index for ordering when n>1 (for parallel request multiplexing)
+  uint32 index = 8;
 }

 message GenerateComplete {
@@ -207,6 +210,9 @@ message GenerateComplete {
   // Input logprobs if requested (for prompt tokens)
   InputLogProbs input_logprobs = 10;
+
+  // Index for ordering when n>1 (for parallel request multiplexing)
+  uint32 index = 11;
 }

 message GenerateError {
...
@@ -72,8 +72,6 @@ pub enum ChatMessage {
         name: Option<String>,
         #[serde(skip_serializing_if = "Option::is_none")]
         tool_calls: Option<Vec<ToolCall>>,
-        #[serde(skip_serializing_if = "Option::is_none")]
-        function_call: Option<FunctionCallResponse>,
         /// Reasoning content for O1-style models (SGLang extension)
         #[serde(skip_serializing_if = "Option::is_none")]
         reasoning_content: Option<String>,
@@ -140,8 +138,6 @@ pub struct ChatMessageDelta {
     pub content: Option<String>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub tool_calls: Option<Vec<ToolCallDelta>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub function_call: Option<FunctionCallDelta>,
     /// Reasoning content delta for O1-style models (SGLang extension)
     #[serde(skip_serializing_if = "Option::is_none")]
     pub reasoning_content: Option<String>,
@@ -473,6 +469,8 @@ pub struct ChatStreamChoice {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub logprobs: Option<ChatLogProbs>,
     pub finish_reason: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub matched_stop: Option<Value>,
 }

 // Completions API request types (v1/completions) - DEPRECATED but still supported
...
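Because the new `matched_stop` field carries `skip_serializing_if = "Option::is_none"`, the key is simply absent from mid-stream SSE chunks and only appears on the final chunk for a choice. A self-contained sketch of that serialization behavior, using a trimmed stand-in struct rather than the full `ChatStreamChoice`:

```rust
use serde::Serialize;
use serde_json::{json, Value};

// Trimmed stand-in for ChatStreamChoice: just the two fields that matter here.
#[derive(Serialize)]
struct Choice {
    finish_reason: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    matched_stop: Option<Value>,
}

fn main() {
    // Mid-stream chunk: matched_stop is None, so the key is omitted entirely.
    let mid = Choice { finish_reason: None, matched_stop: None };
    assert_eq!(serde_json::to_string(&mid).unwrap(), r#"{"finish_reason":null}"#);

    // Final chunk: matched_stop may be a stop token id or a stop string, hence Value.
    let last = Choice {
        finish_reason: Some("stop".to_string()),
        matched_stop: Some(json!(128009)),
    };
    assert_eq!(
        serde_json::to_string(&last).unwrap(),
        r#"{"finish_reason":"stop","matched_stop":128009}"#
    );
}
```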
@@ -44,7 +44,7 @@ graph TB
     end

     subgraph Factory Layer
-        MID --> PF[ParserFactory]
+        MID --> PF[ReasoningParserFactory]
         PF --> REG[ParserRegistry]
         REG --> PM[Pattern Matching]
         PM --> PP[Parser Pool]
@@ -93,7 +93,7 @@ graph TB
 ```mermaid
 sequenceDiagram
     participant C as Client
-    participant F as ParserFactory
+    participant F as ReasoningParserFactory
     participant R as Registry
     participant P as Parser Pool
     participant BP as BaseParser
@@ -206,7 +206,7 @@ classDiagram
         +new() Self
     }

-    class ParserFactory {
+    class ReasoningParserFactory {
         -registry: ParserRegistry
         +new() Self
         +get_pooled(model_id: &str) PooledParser
@@ -240,7 +240,7 @@ classDiagram
     Step3Parser o-- BaseReasoningParser
     BaseReasoningParser o-- ParserConfig

-    ParserFactory o-- ParserRegistry
+    ReasoningParserFactory o-- ParserRegistry
     ParserRegistry o-- ReasoningParser
 ```
@@ -302,7 +302,7 @@ classDiagram
    - Delegate to get_pooled_parser
    - Case-insensitive comparison

-**ParserFactory Methods**:
+**ReasoningParserFactory Methods**:

 1. **`new()`**:
    - Register all built-in parsers
@@ -437,7 +437,7 @@ impl ReasoningParser for MyModelParser {
 **Step 2: Register in Factory**

 ```rust
-// In factory.rs ParserFactory::new()
+// In factory.rs ReasoningParserFactory::new()
 registry.register_parser("mymodel", || {
     Box::new(MyModelParser::new())
 });
...
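After the rename, downstream code constructs the factory under its new name. A small usage sketch consistent with the factory tests in the next file; the import path is assumed for illustration (inside the router crate this is `crate::reasoning_parser::ReasoningParserFactory`), and `PooledParser` is treated as an `Arc`-based handle, as the pool-reuse test implies:

```rust
// Import path assumed for illustration only.
use sglang_router_rs::reasoning_parser::ReasoningParserFactory;

fn main() {
    let factory = ReasoningParserFactory::new();

    // One-off parser, chosen by case-insensitive model-id pattern matching.
    let parser = factory.create("deepseek-r1-distill").unwrap();
    assert_eq!(parser.model_type(), "deepseek_r1");

    // Pooled lookups for the same model share one instance
    // (see test_pooled_parser_reuse below).
    let a = factory.get_pooled("qwen3-7b");
    let b = factory.get_pooled("qwen3-7b");
    assert!(std::sync::Arc::ptr_eq(&a, &b));
}
```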
@@ -128,11 +128,11 @@ impl Default for ParserRegistry {

 /// Factory for creating reasoning parsers based on model type.
 #[derive(Clone)]
-pub struct ParserFactory {
+pub struct ReasoningParserFactory {
     registry: ParserRegistry,
 }

-impl ParserFactory {
+impl ReasoningParserFactory {
     /// Create a new factory with default parsers registered.
     pub fn new() -> Self {
         let registry = ParserRegistry::new();
@@ -237,7 +237,7 @@ impl ParserFactory {
     }
 }

-impl Default for ParserFactory {
+impl Default for ReasoningParserFactory {
     fn default() -> Self {
         Self::new()
     }
@@ -249,35 +249,35 @@ mod tests {

     #[test]
     fn test_factory_creates_deepseek_r1() {
-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();
         let parser = factory.create("deepseek-r1-distill").unwrap();
         assert_eq!(parser.model_type(), "deepseek_r1");
     }

     #[test]
     fn test_factory_creates_qwen3() {
-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();
         let parser = factory.create("qwen3-7b").unwrap();
         assert_eq!(parser.model_type(), "qwen3");
     }

     #[test]
     fn test_factory_creates_kimi() {
-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();
         let parser = factory.create("kimi-chat").unwrap();
         assert_eq!(parser.model_type(), "kimi");
     }

     #[test]
     fn test_factory_fallback_to_passthrough() {
-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();
         let parser = factory.create("unknown-model").unwrap();
         assert_eq!(parser.model_type(), "passthrough");
     }

     #[test]
     fn test_case_insensitive_matching() {
-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();
         let parser1 = factory.create("DeepSeek-R1").unwrap();
         let parser2 = factory.create("QWEN3").unwrap();
         let parser3 = factory.create("Kimi").unwrap();
@@ -289,21 +289,21 @@ mod tests {

     #[test]
     fn test_step3_model() {
-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();
         let step3 = factory.create("step3-model").unwrap();
         assert_eq!(step3.model_type(), "step3");
     }

     #[test]
     fn test_glm45_model() {
-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();
         let glm45 = factory.create("glm45-v2").unwrap();
         assert_eq!(glm45.model_type(), "glm45");
     }

     #[test]
     fn test_pooled_parser_reuse() {
-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();

         // Get the same parser twice - should be the same instance
         let parser1 = factory.get_pooled("deepseek-r1");
@@ -321,7 +321,7 @@ mod tests {
     fn test_pooled_parser_concurrent_access() {
         use std::thread;

-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();
         let parser = factory.get_pooled("deepseek-r1");

         // Spawn multiple threads that use the same parser
@@ -347,7 +347,7 @@ mod tests {
     #[test]
     fn test_pool_clearing() {
-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();

         // Get a pooled parser
         let parser1 = factory.get_pooled("deepseek-r1");
@@ -364,7 +364,7 @@ mod tests {
     #[test]
     fn test_passthrough_parser_pooling() {
-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();

         // Unknown models should get passthrough parser
         let parser1 = factory.get_pooled("unknown-model-1");
@@ -383,7 +383,7 @@ mod tests {
         use std::thread;
         use std::time::Instant;

-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();
         let num_threads = 100;
         let requests_per_thread = 50;
         let models = vec!["deepseek-r1", "qwen3", "kimi", "qwen3-thinking"];
@@ -527,7 +527,7 @@ mod tests {
     fn test_concurrent_pool_modifications() {
         use std::thread;

-        let factory = ParserFactory::new();
+        let factory = ReasoningParserFactory::new();
         let mut handles = vec![];

         // Thread 1: Continuously get parsers
...
@@ -2,7 +2,7 @@ pub mod factory;
 pub mod parsers;
 pub mod traits;

-pub use factory::{ParserFactory, ParserRegistry, PooledParser};
+pub use factory::{ParserRegistry, PooledParser, ReasoningParserFactory};
 pub use parsers::{
     BaseReasoningParser, DeepSeekR1Parser, Glm45Parser, KimiParser, Qwen3Parser,
     QwenThinkingParser, Step3Parser,
...
@@ -4,7 +4,7 @@ use crate::config::types::RetryConfig;
 use crate::core::{WorkerRegistry, WorkerType};
 use crate::metrics::RouterMetrics;
 use crate::policies::PolicyRegistry;
-use crate::reasoning_parser::ParserFactory;
+use crate::reasoning_parser::ReasoningParserFactory;
 use crate::routers::RouterTrait;
 use crate::tokenizer::traits::Tokenizer;
 use crate::tool_parser::ToolParserFactory;
@@ -24,7 +24,7 @@ pub struct GrpcPDRouter {
     worker_registry: Arc<WorkerRegistry>,
     policy_registry: Arc<PolicyRegistry>,
     tokenizer: Arc<dyn Tokenizer>,
-    reasoning_parser_factory: ParserFactory,
+    reasoning_parser_factory: ReasoningParserFactory,
     tool_parser_factory: ToolParserFactory,
     dp_aware: bool,
...
@@ -7,10 +7,14 @@ use async_trait::async_trait;
 use axum::{
     body::Body,
     extract::Request,
-    http::{HeaderMap, StatusCode},
+    http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
     response::{IntoResponse, Response},
     Json,
 };
+use bytes::Bytes;
+use std::io;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::UnboundedReceiverStream;
 use tracing::{debug, error, info, warn};

 use crate::config::types::RetryConfig;
@@ -21,11 +25,12 @@ use crate::policies::PolicyRegistry;
 use crate::protocols::spec::ChatMessage;
 use crate::protocols::spec::{
     ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse,
-    CompletionRequest, EmbeddingRequest, FunctionCallResponse, GenerateRequest, RerankRequest,
-    ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolChoice,
+    ChatCompletionStreamResponse, ChatMessageDelta, ChatStreamChoice, CompletionRequest,
+    EmbeddingRequest, FunctionCallDelta, FunctionCallResponse, GenerateRequest, RerankRequest,
+    ResponsesGetParams, ResponsesRequest, StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice,
     ToolChoiceValue, Usage,
 };
-use crate::reasoning_parser::ParserFactory;
+use crate::reasoning_parser::{ParserResult, ReasoningParserFactory};
 use crate::routers::RouterTrait;
 use crate::server::AppContext;
 use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
@@ -34,7 +39,7 @@ use crate::tokenizer::stop::{
 };
 use crate::tokenizer::traits::Tokenizer;
 use crate::tokenizer::HuggingFaceTokenizer;
-use crate::tool_parser::ToolParserFactory;
+use crate::tool_parser::{StreamingParseResult, ToolParserFactory};
 use proto::generate_response::Response::{Chunk, Complete, Error};
 use serde_json::{json, Map, Value};
 use std::time::{Instant, SystemTime, UNIX_EPOCH};
@@ -50,12 +55,13 @@ pub struct ProcessedMessages {
 }

 /// gRPC router implementation for SGLang
+#[derive(Clone)]
 #[allow(dead_code)]
 pub struct GrpcRouter {
     worker_registry: Arc<WorkerRegistry>,
     policy_registry: Arc<PolicyRegistry>,
     tokenizer: Arc<dyn Tokenizer>,
-    reasoning_parser_factory: ParserFactory,
+    reasoning_parser_factory: ReasoningParserFactory,
     tool_parser_factory: ToolParserFactory,
     dp_aware: bool,
     api_key: Option<String>,
@@ -776,10 +782,11 @@ impl GrpcRouter {
     }

     /// Parse tool calls using model-specific parser
-    async fn parse_with_model_parser(
+    async fn parse_tool_calls(
         &self,
         processed_text: &str,
         model: &str,
+        history_tool_calls_count: usize,
     ) -> (Option<Vec<ToolCall>>, String) {
         // Get pooled parser for this model
         let pooled_parser = self.tool_parser_factory.get_pooled(model);
@@ -810,16 +817,26 @@ impl GrpcRouter {
         let spec_tool_calls = parsed_tool_calls
             .into_iter()
-            .map(|tc| ToolCall {
-                id: tc.id,
-                tool_type: "function".to_string(),
-                function: FunctionCallResponse {
-                    name: tc.function.name,
-                    arguments: Some(
-                        serde_json::to_string(&tc.function.arguments)
-                            .unwrap_or_else(|_| "{}".to_string()),
-                    ),
-                },
+            .enumerate()
+            .map(|(index, tc)| {
+                // Generate ID for this tool call
+                let id = Self::generate_tool_call_id(
+                    model,
+                    &tc.function.name,
+                    index,
+                    history_tool_calls_count,
+                );
+                ToolCall {
+                    id,
+                    tool_type: "function".to_string(),
+                    function: FunctionCallResponse {
+                        name: tc.function.name,
+                        arguments: Some(
+                            serde_json::to_string(&tc.function.arguments)
+                                .unwrap_or_else(|_| "{}".to_string()),
+                        ),
+                    },
+                }
             })
             .collect();
         (Some(spec_tool_calls), normal_text)
@@ -920,6 +937,47 @@ impl GrpcRouter {
         builder.build()
     }

+    /// Count the number of tool calls in the request message history
+    /// This is used for KimiK2 format which needs globally unique indices
+    fn get_history_tool_calls_count(request: &ChatCompletionRequest) -> usize {
+        request
+            .messages
+            .iter()
+            .filter_map(|msg| {
+                if let ChatMessage::Assistant { tool_calls, .. } = msg {
+                    tool_calls.as_ref().map(|calls| calls.len())
+                } else {
+                    None
+                }
+            })
+            .sum()
+    }
+
+    /// Generate a tool call ID based on model format
+    ///
+    /// # Arguments
+    /// * `model` - Model name to determine ID format
+    /// * `tool_name` - Name of the tool being called
+    /// * `tool_index` - Index of this tool call within the current message
+    /// * `history_count` - Number of tool calls in previous messages
+    ///
+    /// # Returns
+    /// A unique ID string. KimiK2 uses `functions.{name}:{global_index}`, others use `call_{uuid}`
+    fn generate_tool_call_id(
+        model: &str,
+        tool_name: &str,
+        tool_index: usize,
+        history_count: usize,
+    ) -> String {
+        if model.to_lowercase().contains("kimi") {
+            // KimiK2 format: functions.{name}:{global_index}
+            format!("functions.{}:{}", tool_name, history_count + tool_index)
+        } else {
+            // Standard OpenAI format: call_{24-char-uuid}
+            format!("call_{}", &Uuid::new_v4().simple().to_string()[..24])
+        }
+    }
+
     /// Process a chunk of tokens through the stop decoder
     fn process_chunk_tokens(
         stop_decoder: &mut StopSequenceDecoder,
@@ -953,6 +1011,230 @@ impl GrpcRouter {
         (chunk_text, false) // Return text and continue processing
     }
/// Helper: Process reasoning content in streaming mode
/// Returns (modified_delta, optional_reasoning_chunk)
fn process_reasoning_stream(
&self,
delta: &str,
index: u32,
reasoning_parsers: &mut HashMap<
u32,
Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>,
>,
request_id: &str,
model: &str,
created: u64,
) -> (String, Option<ChatCompletionStreamResponse>) {
// Get or create parser for this index
reasoning_parsers
.entry(index)
.or_insert_with(|| self.reasoning_parser_factory.get_pooled(model));
if let Some(pooled_parser) = reasoning_parsers.get(&index) {
let parse_result = {
let mut parser = pooled_parser.lock().unwrap();
parser.parse_reasoning_streaming_incremental(delta)
};
match parse_result {
Ok(ParserResult {
reasoning_text,
normal_text,
}) => {
let chunk = if !reasoning_text.is_empty() {
Some(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: Some(reasoning_text),
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
})
} else {
None
};
return (normal_text, chunk);
}
Err(e) => {
warn!("Reasoning parsing error: {}", e);
}
}
}
(delta.to_string(), None)
}
/// Helper: Process tool calls in streaming mode
/// Returns (should_skip_content, chunks_to_emit)
#[allow(clippy::too_many_arguments)]
async fn process_tool_calls_stream(
&self,
delta: &str,
index: u32,
tool_parsers: &mut HashMap<
u32,
Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>,
>,
has_tool_calls: &mut HashMap<u32, bool>,
tools: &[crate::protocols::spec::Tool],
request_id: &str,
model: &str,
created: u64,
history_tool_calls_count: usize,
) -> (bool, Vec<ChatCompletionStreamResponse>) {
let mut chunks = Vec::new();
// Get or create parser for this index
tool_parsers
.entry(index)
.or_insert_with(|| self.tool_parser_factory.get_pooled(model));
if let Some(pooled_parser) = tool_parsers.get(&index) {
let mut parser = pooled_parser.lock().await;
match parser.parse_incremental(delta, tools).await {
Ok(StreamingParseResult { normal_text, calls }) => {
// Emit normal text if present
if !normal_text.is_empty() {
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(normal_text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// Emit tool call chunks
for tool_call_item in calls {
has_tool_calls.insert(index, true);
let tool_call_id = if let Some(ref name) = tool_call_item.name {
Some(Self::generate_tool_call_id(
model,
name,
tool_call_item.tool_index,
history_tool_calls_count,
))
} else {
None
};
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: tool_call_id,
tool_type: if tool_call_item.name.is_some() {
Some("function".to_string())
} else {
None
},
function: Some(FunctionCallDelta {
name: tool_call_item.name,
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
chunks.push(ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
});
}
// If we emitted chunks, skip regular content
return (!chunks.is_empty(), chunks);
}
Err(e) => {
warn!("Tool call parsing error: {}", e);
}
}
}
(false, chunks)
}
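For a hypothetical get_weather call, the deltas produced above would serialize roughly as follows (abridged; id and argument values illustrative, field names per the OpenAI chunk schema). The first delta carries the id, type, and name; later deltas carry argument fragments only:
data: {"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"id":"call_…","type":"function","function":{"name":"get_weather"}}]}}]}
data: {"choices":[{"index":0,"delta":{"role":"assistant","tool_calls":[{"index":0,"function":{"arguments":"{\"city\":\"Paris\"}"}}]}}]}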
/// Helper: Create content chunk
fn create_content_chunk(
content: String,
index: u32,
request_id: &str,
model: &str,
created: u64,
logprobs: Option<crate::protocols::spec::ChatLogProbs>,
) -> ChatCompletionStreamResponse {
ChatCompletionStreamResponse {
id: request_id.to_string(),
object: "chat.completion.chunk".to_string(),
created,
model: model.to_string(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(content),
tool_calls: None,
reasoning_content: None,
},
logprobs,
finish_reason: None,
matched_stop: None,
}],
usage: None,
}
}
/// Helper: Format response as SSE chunk
fn format_sse_chunk(response: &ChatCompletionStreamResponse) -> String {
format!(
"data: {}\n\n",
serde_json::to_string(response).unwrap_or_default()
)
}
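Each response becomes one SSE event: a single `data:` line terminated by a blank line. Roughly (exact field rendering depends on the serde attributes of ChatCompletionStreamResponse; values illustrative):
data: {"id":"…","object":"chat.completion.chunk","created":1700000000,"model":"…","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]}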
/// Submit request and handle streaming response for chat completions route
async fn handle_streaming_chat(
&self,
@@ -960,14 +1242,13 @@ impl GrpcRouter {
request: proto::GenerateRequest,
original_request: &ChatCompletionRequest,
) -> Response {
let request_id = request.request_id.clone();
let model = original_request.model.clone();
// Create channel for SSE streaming
let (tx, rx) = mpsc::unbounded_channel::<Result<Bytes, io::Error>>();
// Start the gRPC stream
let mut grpc_stream = match client.generate(request).await {
Ok(stream) => stream,
Err(e) => {
@@ -980,49 +1261,414 @@ impl GrpcRouter {
}
};
let stop_params = (
original_request.stop.clone(),
original_request.stop_token_ids.clone(),
original_request.skip_special_tokens,
original_request.no_stop_trim,
);
// Spawn processing task
let self_clone = self.clone();
let original_request_clone = original_request.clone();
tokio::spawn(async move {
let result = Self::process_streaming_chunks(
&self_clone,
&mut grpc_stream,
request_id,
model,
stop_params,
original_request_clone,
&tx,
)
.await;
if let Err(e) = result {
let error_chunk = format!(
"data: {}\n\n",
json!({
"error": {
"message": e,
"type": "internal_error"
}
})
);
let _ = tx.send(Ok(Bytes::from(error_chunk)));
}
// Send DONE marker
let _ = tx.send(Ok(Bytes::from("data: [DONE]\n\n")));
});
// Create response with SSE headers
let stream = UnboundedReceiverStream::new(rx);
let mut response = Response::new(Body::from_stream(stream));
*response.status_mut() = StatusCode::OK;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
.headers_mut()
.insert("Cache-Control", HeaderValue::from_static("no-cache"));
response
.headers_mut()
.insert("Connection", HeaderValue::from_static("keep-alive"));
response
}
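Error-path note: if process_streaming_chunks returns Err, the spawned task above still terminates the stream cleanly, so a client always sees a well-formed error event followed by the DONE marker (shape per the json! above; message illustrative):
data: {"error":{"message":"…","type":"internal_error"}}

data: [DONE]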
/// Process streaming chunks and send SSE events
async fn process_streaming_chunks(
router: &GrpcRouter,
grpc_stream: &mut (impl tokio_stream::Stream<Item = Result<proto::GenerateResponse, tonic::Status>>
+ Unpin),
request_id: String,
model: String,
stop_params: (Option<StringOrArray>, Option<Vec<u32>>, bool, bool),
original_request: ChatCompletionRequest,
tx: &mpsc::UnboundedSender<Result<Bytes, io::Error>>,
) -> Result<(), String> {
// Extract request parameters
let separate_reasoning = original_request.separate_reasoning;
let tool_choice = &original_request.tool_choice;
let tools = &original_request.tools;
let history_tool_calls_count = Self::get_history_tool_calls_count(&original_request);
let stream_options = &original_request.stream_options;
// Phase 1: Initialize state tracking (per-index for n>1 support)
let mut is_firsts: HashMap<u32, bool> = HashMap::new();
let mut stream_buffers: HashMap<u32, String> = HashMap::new();
let mut finish_reasons: HashMap<u32, String> = HashMap::new();
let mut matched_stops: HashMap<u32, Option<Value>> = HashMap::new();
let mut prompt_tokens: HashMap<u32, u32> = HashMap::new();
let mut completion_tokens: HashMap<u32, u32> = HashMap::new();
let mut cached_tokens: HashMap<u32, u32> = HashMap::new();
// Parser state (lazy initialization per index)
type PooledReasoningParser =
Arc<std::sync::Mutex<Box<dyn crate::reasoning_parser::ReasoningParser>>>;
let mut reasoning_parsers: HashMap<u32, PooledReasoningParser> = HashMap::new();
type PooledToolParser = Arc<tokio::sync::Mutex<Box<dyn crate::tool_parser::ToolParser>>>;
let mut tool_parsers: HashMap<u32, PooledToolParser> = HashMap::new();
let mut has_tool_calls: HashMap<u32, bool> = HashMap::new();
// Create stop decoder
let (stop, stop_token_ids, skip_special_tokens, no_stop_trim) = stop_params;
let mut stop_decoder = router.create_stop_decoder(
stop.as_ref(),
stop_token_ids.as_ref(),
skip_special_tokens,
no_stop_trim,
);
let created = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
// Phase 2: Main streaming loop
while let Some(response) = grpc_stream.next().await {
let gen_response = response.map_err(|e| format!("Stream error: {}", e))?;
match gen_response.response {
Some(Chunk(chunk)) => {
let index = chunk.index;
// Process tokens through stop decoder
let (chunk_text, _should_stop) =
Self::process_chunk_tokens(&mut stop_decoder, &chunk.token_ids);
if chunk_text.is_empty() {
continue;
}
// Process logprobs if present
let choice_logprobs = if let Some(ref proto_logprobs) = chunk.output_logprobs {
match router.convert_proto_to_openai_logprobs(proto_logprobs) {
Ok(logprobs) => Some(logprobs),
Err(e) => {
warn!("Failed to process logprobs: {}", e);
None
}
}
} else {
None
};
// Initialize stream buffer if first time
let stream_buffer = stream_buffers.entry(index).or_default();
// Send first chunk with role
if is_firsts.get(&index).copied().unwrap_or(true) {
let first_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&first_chunk))))
.map_err(|_| "Failed to send first chunk".to_string())?;
is_firsts.insert(index, false);
}
// Calculate delta
let mut delta = chunk_text;
stream_buffer.push_str(&delta);
// Reasoning content handling
if separate_reasoning {
let (normal_text, reasoning_chunk) = router.process_reasoning_stream(
&delta,
index,
&mut reasoning_parsers,
&request_id,
&model,
created,
);
if let Some(chunk) = reasoning_chunk {
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send reasoning chunk".to_string())?;
}
delta = normal_text;
}
// Tool call handling
let tool_choice_enabled =
!matches!(tool_choice, Some(ToolChoice::Value(ToolChoiceValue::None)));
if tool_choice_enabled && tools.is_some() {
let (should_skip, tool_chunks) = router
.process_tool_calls_stream(
&delta,
index,
&mut tool_parsers,
&mut has_tool_calls,
tools.as_ref().unwrap(),
&request_id,
&model,
created,
history_tool_calls_count,
)
.await;
for chunk in tool_chunks {
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&chunk))))
.map_err(|_| "Failed to send tool call chunk".to_string())?;
}
if should_skip {
continue;
}
}
// Regular content emission
if !delta.is_empty() {
let content_chunk = Self::create_content_chunk(
delta,
index,
&request_id,
&model,
created,
choice_logprobs,
);
tx.send(Ok(Bytes::from(Self::format_sse_chunk(&content_chunk))))
.map_err(|_| "Failed to send content chunk".to_string())?;
}
continue;
}
Some(Complete(complete)) => {
// Flush any remaining text
if let SequenceDecoderOutput::Text(text) = stop_decoder.flush() {
if !text.is_empty() {
let index = complete.index;
let stream_buffer = stream_buffers.entry(index).or_default();
stream_buffer.push_str(&text);
let content_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: Some(text),
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&content_chunk)
.map_err(|e| format!("Failed to serialize content chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send flushed content".to_string())?;
} }
} }
// Store metadata
let index = complete.index;
prompt_tokens.insert(index, complete.prompt_tokens as u32);
completion_tokens.insert(index, complete.completion_tokens as u32);
cached_tokens.insert(index, complete.cached_tokens as u32);
finish_reasons.insert(index, complete.finish_reason.clone());
// Extract matched_stop
let matched_stop_value = match &complete.matched_stop {
Some(proto::generate_complete::MatchedStop::MatchedTokenId(token_id)) => {
Some(Value::Number(serde_json::Number::from(*token_id)))
}
Some(proto::generate_complete::MatchedStop::MatchedStopStr(stop_str)) => {
Some(Value::String(stop_str.clone()))
}
None => None,
};
matched_stops.insert(index, matched_stop_value);
break;
}
Some(Error(error)) => {
return Err(error.message);
}
None => continue,
}
}
// Phase 3: Check unstreamed tool args
// Check if parsers have any remaining arguments that haven't been streamed yet
for (index, parser) in &tool_parsers {
let parser_guard = parser.lock().await;
if let Some(unstreamed_items) = parser_guard.get_unstreamed_tool_args() {
for tool_call_item in unstreamed_items {
let tool_call_delta = ToolCallDelta {
index: tool_call_item.tool_index as u32,
id: None,
tool_type: None, // No type for argument deltas
function: Some(FunctionCallDelta {
name: None, // No name for argument deltas
arguments: if !tool_call_item.parameters.is_empty() {
Some(tool_call_item.parameters)
} else {
None
},
}),
};
let tool_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: Some(vec![tool_call_delta]),
reasoning_content: None,
},
logprobs: None,
finish_reason: None,
matched_stop: None,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&tool_chunk)
.map_err(|e| format!("Failed to serialize tool chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send unstreamed tool args".to_string())?;
}
}
}
// Phase 4: Finish reason chunks
for (index, finish_reason) in finish_reasons.iter() {
let final_finish_reason =
if has_tool_calls.get(index).copied().unwrap_or(false) && finish_reason == "stop" {
"tool_calls".to_string()
} else {
finish_reason.clone()
};
let matched_stop_value = matched_stops.get(index).and_then(|v| v.clone());
let finish_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![ChatStreamChoice {
index: *index,
delta: ChatMessageDelta {
role: Some("assistant".to_string()),
content: None,
tool_calls: None,
reasoning_content: None,
},
logprobs: None,
finish_reason: Some(final_finish_reason),
matched_stop: matched_stop_value,
}],
usage: None,
};
let sse_chunk = serde_json::to_string(&finish_chunk)
.map_err(|e| format!("Failed to serialize finish chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send finish chunk".to_string())?;
}
// Phase 5: Usage chunk
if let Some(stream_opts) = stream_options {
if stream_opts.include_usage.unwrap_or(false) {
let total_prompt: u32 = prompt_tokens.values().sum();
let total_completion: u32 = completion_tokens.values().sum();
let usage_chunk = ChatCompletionStreamResponse {
id: request_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
system_fingerprint: None,
choices: vec![],
usage: Some(Usage {
prompt_tokens: total_prompt,
completion_tokens: total_completion,
total_tokens: total_prompt + total_completion,
completion_tokens_details: None,
}),
};
let sse_chunk = serde_json::to_string(&usage_chunk)
.map_err(|e| format!("Failed to serialize usage chunk: {}", e))?;
tx.send(Ok(Bytes::from(format!("data: {}\n\n", sse_chunk))))
.map_err(|_| "Failed to send usage chunk".to_string())?;
}
}
Ok(())
}
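Putting the phases together, a single-choice request with include_usage set would produce an event sequence along these lines (abridged and illustrative; as coded above, every delta also carries role "assistant"):
data: {… "delta":{"role":"assistant"} …}                                    ← first chunk (Phase 2)
data: {… "delta":{"role":"assistant","content":"Hel"} …}                    ← content deltas
data: {… "delta":{"role":"assistant","content":"lo"} …}
data: {… "finish_reason":"stop" …}                                          ← Phase 4 ("tool_calls" if any tool calls were emitted)
data: {… "usage":{"prompt_tokens":12,"completion_tokens":2,"total_tokens":14} …}   ← Phase 5
data: [DONE]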
/// Submit request and handle non-streaming response for chat completions route
@@ -1082,10 +1728,17 @@ impl GrpcRouter {
}
// Process each response into a ChatChoice
let history_tool_calls_count = Self::get_history_tool_calls_count(original_request);
let mut choices = Vec::new();
for (index, complete) in all_responses.iter().enumerate() {
match self
.process_single_choice(
complete,
index,
original_request,
&mut stop_decoder,
history_tool_calls_count,
)
.await
{
Ok(choice) => choices.push(choice),
@@ -1216,11 +1869,12 @@ impl GrpcRouter {
decoded_text.push_str(&t);
}
let output_ids = std::mem::take(&mut complete.output_ids);
let finish_reason = std::mem::take(&mut complete.finish_reason);
// Build base meta_info using json! macro
let mut meta_info = json!({
"finish_reason": finish_reason,
"prompt_tokens": complete.prompt_tokens,
"completion_tokens": complete.completion_tokens,
"cached_tokens": complete.cached_tokens,
@@ -1269,9 +1923,13 @@ impl GrpcRouter {
})
.collect();
// Build ChatLogProbsContent for each token (consume iterator to avoid clones)
for (i, (&logprob, token_text)) in proto_logprobs
.token_logprobs
.iter()
.zip(token_texts.into_iter())
.enumerate()
{
let bytes = Some(token_text.as_bytes().to_vec());
// Build top_logprobs for this position
@@ -1324,6 +1982,7 @@ impl GrpcRouter {
index: usize,
original_request: &ChatCompletionRequest,
stop_decoder: &mut StopSequenceDecoder,
history_tool_calls_count: usize,
) -> Result<ChatChoice, String> {
stop_decoder.reset();
// Decode tokens
@@ -1401,7 +2060,11 @@ impl GrpcRouter {
self.parse_json_schema_response(&processed_text, &original_request.tool_choice);
} else {
(tool_calls, processed_text) = self
.parse_tool_calls(
&processed_text,
&original_request.model,
history_tool_calls_count,
)
.await;
}
}
@@ -1686,7 +2349,6 @@ mod tests {
content: Some("Assistant response".to_string()),
name: None,
tool_calls: None,
reasoning_content: None,
}];
...
@@ -15,7 +15,7 @@ use crate::{
},
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
},
reasoning_parser::ReasoningParserFactory,
routers::{router_manager::RouterManager, RouterTrait},
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
@@ -45,7 +45,7 @@ pub struct AppContext {
pub router_config: RouterConfig,
pub rate_limiter: Arc<TokenBucket>,
pub tokenizer: Option<Arc<dyn Tokenizer>>,
pub reasoning_parser_factory: Option<ReasoningParserFactory>,
pub tool_parser_factory: Option<ToolParserFactory>,
pub worker_registry: Arc<WorkerRegistry>,
pub policy_registry: Arc<PolicyRegistry>,
@@ -79,7 +79,7 @@ impl AppContext {
tokenizer_factory::create_tokenizer(&tokenizer_path)
.map_err(|e| format!("Failed to create tokenizer: {e}"))?,
);
let reasoning_parser_factory = Some(ReasoningParserFactory::new());
let tool_parser_factory = Some(ToolParserFactory::new());
(tokenizer, reasoning_parser_factory, tool_parser_factory)
...
@@ -123,12 +123,7 @@ impl DeepSeekParser {
let arguments = serde_json::to_string(&args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
Ok(ToolCall {
function: FunctionCall {
name: func_name.to_string(),
arguments,
@@ -320,4 +315,8 @@ impl ToolParser for DeepSeekParser {
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text)
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}
@@ -129,12 +129,7 @@ impl Glm4MoeParser {
let arguments_str = serde_json::to_string(&arguments)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall {
function: FunctionCall {
name: func_name.to_string(),
arguments: arguments_str,
@@ -321,4 +316,8 @@ impl ToolParser for Glm4MoeParser {
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text)
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}
@@ -113,12 +113,7 @@ impl ToolParser for GptOssParser {
}
};
tools.push(ToolCall {
function: FunctionCall {
name: function_name,
arguments,
...
@@ -14,6 +14,48 @@ pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
.collect()
}
/// Get unstreamed tool call arguments
/// Returns tool call items for arguments that have been parsed but not yet streamed
/// This ensures tool calls are properly completed even if the model generates final arguments in the last chunk
pub fn get_unstreamed_args(
prev_tool_call_arr: &[Value],
streamed_args_for_tool: &[String],
) -> Option<Vec<ToolCallItem>> {
// Check if we have tool calls being tracked
if prev_tool_call_arr.is_empty() || streamed_args_for_tool.is_empty() {
return None;
}
// Get the last tool call that was being processed
let tool_index = prev_tool_call_arr.len() - 1;
if tool_index >= streamed_args_for_tool.len() {
return None;
}
// Get expected vs actual arguments
let expected_args = prev_tool_call_arr[tool_index].get("arguments")?;
let expected_str = serde_json::to_string(expected_args).ok()?;
let actual_str = &streamed_args_for_tool[tool_index];
// Check if there are remaining arguments to send
let remaining = if expected_str.starts_with(actual_str) {
&expected_str[actual_str.len()..]
} else {
return None;
};
if remaining.is_empty() {
return None;
}
// Return the remaining arguments as a ToolCallItem
Some(vec![ToolCallItem {
tool_index,
name: None, // No name for argument deltas
parameters: remaining.to_string(),
}])
}
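A minimal sketch of this helper's contract, assuming the module's types are in scope (names and values hypothetical):
use serde_json::json;

// The parser saw the complete arguments, but only a prefix was streamed as deltas.
let prev_tool_call_arr = vec![json!({"name": "get_weather", "arguments": {"city": "Paris"}})];
let streamed_args_for_tool = vec![r#"{"city":"Pa"#.to_string()];

// serialized expected args: {"city":"Paris"}; streamed prefix: {"city":"Pa
// → one ToolCallItem { tool_index: 0, name: None, parameters: "ris\"}" }
let items = get_unstreamed_args(&prev_tool_call_arr, &streamed_args_for_tool);
assert_eq!(items.unwrap()[0].parameters, "ris\"}");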
/// Check if a buffer ends with a partial occurrence of a token
/// Returns Some(length) if there's a partial match, None otherwise
pub fn ends_with_partial_token(buffer: &str, token: &str) -> Option<usize> {
...
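Assuming the documented contract (the body is collapsed above), for example:
// buffer = "…some text<tool", token = "<tool_call>"
// → Some(5): the trailing "<tool" could be the start of the marker, so a
//   streaming parser would hold those bytes back instead of emitting them.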
@@ -8,7 +8,7 @@ use crate::tool_parser::{
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
};
/// JSON format parser for tool calls
@@ -136,16 +136,7 @@ impl JsonParser {
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall {
function: FunctionCall {
name: name.to_string(),
arguments,
@@ -274,4 +265,8 @@ impl ToolParser for JsonParser {
let trimmed = text.trim();
(trimmed.starts_with('[') || trimmed.starts_with('{')) && trimmed.contains(r#""name""#)
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}
@@ -131,12 +131,7 @@ impl ToolParser for KimiK2Parser {
// Try to parse JSON arguments
match serde_json::from_str::<serde_json::Value>(function_args) {
Ok(_) => {
tools.push(ToolCall {
function: FunctionCall {
name: func_name,
arguments: function_args.to_string(),
@@ -339,4 +334,8 @@ impl ToolParser for KimiK2Parser {
fn detect_format(&self, text: &str) -> bool {
self.has_tool_markers(text) || text.contains("<|tool_call_begin|>")
}
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
}