Unverified commit fc16a79b, authored by ishandhanani and committed by GitHub
Browse files

feat: support batch `/completions` (#1626)


Co-authored-by: Ryan McCormick <rmccormick@nvidia.com>
parent 3e1a5534
...@@ -35,22 +35,62 @@ ARG ARCH_ALT=x86_64 ...@@ -35,22 +35,62 @@ ARG ARCH_ALT=x86_64
WORKDIR /sgl-workspace WORKDIR /sgl-workspace
# Install UCX dependencies
RUN apt-get update -y && \
apt-get install -y --no-install-recommends \
--reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev \
libnuma-dev librdmacm-dev ibverbs-providers \
autoconf libtool
# Build UCX from source
ARG NIXL_UCX_REF=v1.19.x
RUN rm -rf /opt/hpcx/ucx && \
rm -rf /usr/local/ucx && \
cd /usr/local/src && \
git clone https://github.com/openucx/ucx.git && \
cd ucx && \
git checkout $NIXL_UCX_REF && \
./autogen.sh && ./configure \
--prefix=/usr/local/ucx \
--enable-shared \
--disable-static \
--disable-doxygen-doc \
--enable-optimizations \
--enable-cma \
--enable-devel-headers \
--with-cuda=/usr/local/cuda \
--with-verbs \
--with-efa \
--with-dm \
--with-gdrcopy=/usr/local \
--enable-mt && \
make -j && \
make -j install-strip && \
ldconfig
ENV LD_LIBRARY_PATH=/usr/lib:/usr/local/ucx/lib:$LD_LIBRARY_PATH
# Pinning to NIXL 0.2.1 right now # Pinning to NIXL 0.2.1 right now
# TODO: investigate pip install failure with 0.3.0 release # TODO: investigate pip install failure with 0.3.0 release
ARG NIXL_COMMIT="5e4c179ee850d482a83cb2a211e0947e46281060" ARG NIXL_COMMIT="5e4c179ee850d482a83cb2a211e0947e46281060"
RUN git clone https://github.com/ai-dynamo/nixl.git && cd nixl && git checkout ${NIXL_COMMIT} &&pip install --break-system-packages . --config-settings=setup-args="-Ducx_path=/opt/hpcx/ucx" RUN git clone https://github.com/ai-dynamo/nixl.git && cd nixl && git checkout ${NIXL_COMMIT} && pip install --break-system-packages . --config-settings=setup-args="-Ducx_path=/usr/local/ucx"
WORKDIR /sgl-workspace WORKDIR /sgl-workspace
RUN pip uninstall --break-system-packages -y sglang RUN pip uninstall --break-system-packages -y sglang
RUN rm -rf sglang RUN rm -rf sglang
# 0.4.7 # 0.4.8 has a bug with CUDA graphs and decode worker
RUN pip install --break-system-packages "sglang==0.4.7" # https://github.com/sgl-project/sglang/issues/7511
RUN pip install --break-system-packages "sglang==0.4.7.post1"
# Allow forceful shutdown of inflight requests
ENV SGL_FORCE_SHUTDOWN=1
WORKDIR /sgl-workspace WORKDIR /sgl-workspace
# https://github.com/ai-dynamo/dynamo/pull/1510 # https://github.com/ai-dynamo/dynamo/pull/1510
ARG DYNAMO_COMMIT="382e3aedc421b3b3abc338062b332b54b5aa8529" ARG DYNAMO_COMMIT="382e3aedc421b3b3abc338062b332b54b5aa8529"
RUN git clone https://github.com/ai-dynamo/dynamo.git && cd dynamo && git checkout ${DYNAMO_COMMIT} ARG DYNAMO_BRANCH="ishan/cmpl-token-id"
RUN git clone https://github.com/ai-dynamo/dynamo.git && cd dynamo && git checkout ${DYNAMO_BRANCH}
# install dynamo in editable mode # install dynamo in editable mode
WORKDIR /sgl-workspace/dynamo WORKDIR /sgl-workspace/dynamo
......
...@@ -106,12 +106,12 @@ Dynamo supports SGLang's implementation of wide expert parallelism and large sca ...@@ -106,12 +106,12 @@ Dynamo supports SGLang's implementation of wide expert parallelism and large sca
Steps to run: Steps to run:
1. Build the SGLang DeepEP container 1. Build the SGLang DeepEP container.
```bash ```bash
git clone https://github.com/sgl-project/sglang.git git clone -b v0.4.8 https://github.com/sgl-project/sglang.git
cd sglang/docker cd sglang/docker
docker build -f Dockerfile.deepep -t deepep . docker build -f Dockerfile -t deepep .
``` ```
You will now have a `deepep:latest` image You will now have a `deepep:latest` image
......
...@@ -45,7 +45,9 @@ class SGLangDecodeWorker: ...@@ -45,7 +45,9 @@ class SGLangDecodeWorker:
@endpoint() @endpoint()
async def generate(self, req: DisaggPreprocessedRequest): async def generate(self, req: DisaggPreprocessedRequest):
g = await self.engine.async_generate( g = await self.engine.async_generate(
input_ids=req.request.token_ids, input_ids=req.request.token_ids
if req.request.batch_token_ids is None
else req.request.batch_token_ids,
sampling_params=req.sampling_params, sampling_params=req.sampling_params,
stream=True, stream=True,
bootstrap_host=req.bootstrap_host, bootstrap_host=req.bootstrap_host,
......
...@@ -28,6 +28,7 @@ import asyncio ...@@ -28,6 +28,7 @@ import asyncio
import logging import logging
import random import random
import socket import socket
from typing import Dict, Union
import sglang as sgl import sglang as sgl
from components.decode_worker import SGLangDecodeWorker from components.decode_worker import SGLangDecodeWorker
...@@ -112,63 +113,123 @@ class SGLangWorker: ...@@ -112,63 +113,123 @@ class SGLangWorker:
sampling_params["ignore_eos"] = request.stop_conditions.ignore_eos sampling_params["ignore_eos"] = request.stop_conditions.ignore_eos
return sampling_params return sampling_params
def _get_request_batch_size(self, request: PreprocessedRequest):
"""Get batch size from request, returns None for single requests"""
if request.batch_token_ids is not None:
return len(request.batch_token_ids)
return None
def _is_batch_request(self, request: PreprocessedRequest):
"""Check if request is in batch mode"""
return request.batch_token_ids is not None
@endpoint() @endpoint()
async def generate(self, request: PreprocessedRequest): async def generate(self, request: PreprocessedRequest):
# Check if we're in batch mode at the start
is_batch = self._is_batch_request(request)
batch_size = self._get_request_batch_size(request)
# TODO: maintain a mapping from SGLang's Output struct to LLMEngineOutput # TODO: maintain a mapping from SGLang's Output struct to LLMEngineOutput
sampling_params = self._build_sampling_params(request) sampling_params = self._build_sampling_params(request)
if self.engine_args.disaggregation_mode != "null": if self.engine_args.disaggregation_mode != "null":
bootstrap_room = self._generate_bootstrap_room() if is_batch:
bootstrap_room = [
self._generate_bootstrap_room() for _ in range(batch_size)
]
bootstrap_host = [self.bootstrap_host] * batch_size
bootstrap_port = [self.bootstrap_port] * batch_size
else:
bootstrap_host = self.bootstrap_host
bootstrap_port = self.bootstrap_port
bootstrap_room = self._generate_bootstrap_room()
# decode worker request # decode worker request
disagg_request = DisaggPreprocessedRequest( disagg_request = DisaggPreprocessedRequest(
request=request, request=request,
sampling_params=sampling_params, sampling_params=sampling_params,
bootstrap_host=self.bootstrap_host, bootstrap_host=bootstrap_host,
bootstrap_port=self.bootstrap_port, bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room, bootstrap_room=bootstrap_room,
) )
# prefill response is not used # prefill response is not used
prefill = await self.engine.async_generate( prefill = await self.engine.async_generate(
input_ids=request.token_ids, input_ids=request.token_ids
if not is_batch
else request.batch_token_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=True, stream=True,
bootstrap_host=self.bootstrap_host, bootstrap_host=bootstrap_host,
bootstrap_port=self.bootstrap_port, bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room, bootstrap_room=bootstrap_room,
) )
prefill_task = asyncio.create_task(self._prefill_generator(prefill)) prefill_task = asyncio.create_task(self._prefill_generator(prefill))
decode = await self.decode_client.generate(disagg_request.model_dump_json()) decode = await self.decode_client.generate(disagg_request.model_dump_json())
async for out in self._process_stream(decode, unpack=True): async for out in self._process_stream(
decode, unpack=True, is_batch=is_batch
):
yield out yield out
await prefill_task await prefill_task
else: else:
g = await self.engine.async_generate( g = await self.engine.async_generate(
input_ids=request.token_ids, input_ids=request.token_ids
if not is_batch
else request.batch_token_ids,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=True, stream=True,
) )
async for out in self._process_stream(g, unpack=False): async for out in self._process_stream(g, unpack=False, is_batch=is_batch):
yield out yield out
async def _process_stream(self, stream_source, unpack: bool): async def _process_stream(self, stream_source, unpack: bool, is_batch: bool):
num_output_tokens_so_far = 0 # Initialize based on batch mode
num_output_tokens_so_far: Union[Dict[int, int], int]
if is_batch:
num_output_tokens_so_far = {}
else:
num_output_tokens_so_far = 0
async for res in stream_source: async for res in stream_source:
data = res.data() if unpack else res data = res.data() if unpack else res
finish_reason = data["meta_info"]["finish_reason"] finish_reason = data["meta_info"]["finish_reason"]
if finish_reason:
# Don't forward the stop token if is_batch:
out = {"token_ids": [], "finish_reason": finish_reason["type"]} # Handle batch response
assert isinstance(num_output_tokens_so_far, dict)
index = data.get("index", 0)
if index not in num_output_tokens_so_far:
num_output_tokens_so_far[index] = 0
if finish_reason:
out = {
"token_ids": [],
"finish_reason": finish_reason["type"],
"index": index,
}
else:
next_total_toks = len(data["output_ids"])
new_tokens = data["output_ids"][num_output_tokens_so_far[index] :]
out = {
"token_ids": new_tokens,
"index": index,
}
num_output_tokens_so_far[index] = next_total_toks
else: else:
next_total_toks = len(data["output_ids"]) # Handle single response
out = {"token_ids": data["output_ids"][num_output_tokens_so_far:]} assert isinstance(num_output_tokens_so_far, int)
if finish_reason:
out = {"token_ids": [], "finish_reason": finish_reason["type"]}
else:
next_total_toks = len(data["output_ids"])
out = {"token_ids": data["output_ids"][num_output_tokens_so_far:]}
num_output_tokens_so_far = next_total_toks
yield out yield out
num_output_tokens_so_far = next_total_toks
def _generate_bootstrap_room(self): def _generate_bootstrap_room(self):
return random.randint(0, 2**63 - 1) return random.randint(0, 2**63 - 1)
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import List, Optional from typing import List, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
...@@ -47,6 +47,7 @@ class SamplingOptions(BaseModel): ...@@ -47,6 +47,7 @@ class SamplingOptions(BaseModel):
class PreprocessedRequest(BaseModel): class PreprocessedRequest(BaseModel):
token_ids: List[TokenIdType] token_ids: List[TokenIdType]
batch_token_ids: Optional[List[List[TokenIdType]]] = None
stop_conditions: StopConditions stop_conditions: StopConditions
sampling_options: SamplingOptions sampling_options: SamplingOptions
eos_token_ids: List[TokenIdType] = Field(default_factory=list) eos_token_ids: List[TokenIdType] = Field(default_factory=list)
...@@ -57,7 +58,7 @@ class PreprocessedRequest(BaseModel): ...@@ -57,7 +58,7 @@ class PreprocessedRequest(BaseModel):
class DisaggPreprocessedRequest(BaseModel): class DisaggPreprocessedRequest(BaseModel):
request: PreprocessedRequest request: PreprocessedRequest
sampling_params: dict sampling_params: dict
bootstrap_host: str bootstrap_host: Union[str, List[str]]
bootstrap_port: int bootstrap_port: Union[int, List[int]]
bootstrap_room: int bootstrap_room: Union[int, List[int]]
data_parallel_rank: Optional[int] = None data_parallel_rank: Optional[int] = None
...@@ -60,22 +60,71 @@ class RequestHandler: ...@@ -60,22 +60,71 @@ class RequestHandler:
# sglang defaults this to 128 # sglang defaults this to 128
"max_new_tokens": request["stop_conditions"]["max_tokens"], "max_new_tokens": request["stop_conditions"]["max_tokens"],
} }
num_output_tokens_so_far = 0
gen = await self.engine_client.async_generate( # Check if this is a batch request
input_ids=request["token_ids"], sampling_params=sampling_params, stream=True is_batch = "batch_token_ids" in request and request["batch_token_ids"]
)
if is_batch:
# Track tokens separately for each batch item
num_output_tokens_so_far = {}
logging.debug("received batch token ids")
gen = await self.engine_client.async_generate(
input_ids=request["batch_token_ids"],
sampling_params=sampling_params,
stream=True,
)
else:
num_output_tokens_so_far = 0
logging.debug("received token ids")
gen = await self.engine_client.async_generate(
input_ids=request["token_ids"],
sampling_params=sampling_params,
stream=True,
)
async for res in gen: async for res in gen:
# res is a dict # res is a dict
logging.debug(f"res: {res}")
finish_reason = res["meta_info"]["finish_reason"] finish_reason = res["meta_info"]["finish_reason"]
if finish_reason:
# Don't forward the stop token if is_batch:
out = {"token_ids": [], "finish_reason": finish_reason["type"]} # Handle batch response - get index from SGLang response
index = res.get("index", 0)
if index not in num_output_tokens_so_far:
num_output_tokens_so_far[index] = 0
if finish_reason:
logging.warning(f"finish_reason: {finish_reason}")
# Final response for this batch item
out = {
"token_ids": [],
"finish_reason": finish_reason["type"],
"index": index,
}
else:
# Streaming response for this batch item
next_total_toks = len(res["output_ids"])
new_tokens = res["output_ids"][num_output_tokens_so_far[index] :]
out = {
"token_ids": new_tokens,
"index": index,
}
num_output_tokens_so_far[index] = next_total_toks
else: else:
next_total_toks = len(res["output_ids"]) if finish_reason:
out = {"token_ids": res["output_ids"][num_output_tokens_so_far:]} out = {
"token_ids": [],
"finish_reason": finish_reason["type"],
}
else:
next_total_toks = len(res["output_ids"])
new_tokens = res["output_ids"][num_output_tokens_so_far:]
out = {
"token_ids": new_tokens,
}
num_output_tokens_so_far = next_total_toks
yield out yield out
num_output_tokens_so_far = next_total_toks
class EmbeddingRequestHandler(RequestHandler): class EmbeddingRequestHandler(RequestHandler):
......
...@@ -269,6 +269,7 @@ fn run_request( ...@@ -269,6 +269,7 @@ fn run_request(
cum_log_probs: None, // TODO output.cumulative_logprob.map(|v| v as f64), cum_log_probs: None, // TODO output.cumulative_logprob.map(|v| v as f64),
log_probs: None, // TODO output.logprobs log_probs: None, // TODO output.logprobs
finish_reason: None, finish_reason: None,
index: None,
}; };
work_request work_request
.response_channel .response_channel
......
...@@ -224,6 +224,7 @@ impl ...@@ -224,6 +224,7 @@ impl
log_probs: data.log_probs, log_probs: data.log_probs,
finish_reason: data.finish_reason, finish_reason: data.finish_reason,
//mdcsum: mdcsum.clone(), //mdcsum: mdcsum.clone(),
index: data.index,
}) })
}) })
}); });
......
...@@ -115,6 +115,7 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> { ...@@ -115,6 +115,7 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
finish_reason: None, finish_reason: None,
index: None,
}; };
Annotated::from_data(delta) Annotated::from_data(delta)
} }
......
...@@ -53,7 +53,7 @@ use crate::protocols::{ ...@@ -53,7 +53,7 @@ use crate::protocols::{
}; };
use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer}; use crate::tokenizers::{traits::Tokenizer, HuggingFaceTokenizer};
use crate::preprocessor::prompt::PromptFormatter; use crate::preprocessor::prompt::{PromptFormatter, PromptInput, TextInput, TokenInput};
pub use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest}; pub use crate::protocols::common::llm_backend::{BackendOutput, PreprocessedRequest};
...@@ -160,33 +160,79 @@ impl OpenAIPreprocessor { ...@@ -160,33 +160,79 @@ impl OpenAIPreprocessor {
let mut annotations = HashMap::new(); let mut annotations = HashMap::new();
let mut builder = PreprocessedRequest::builder(); let mut builder = PreprocessedRequest::builder();
let use_raw_prompt = request // match request type before any conversion/processing
.nvext() match request.prompt_input_type() {
.is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false)); PromptInput::Tokens(_) => {
if let Some(token_input) = request.extract_tokens() {
let formatted_prompt = if use_raw_prompt { match token_input {
match request.raw_prompt() { TokenInput::Single(tokens) => {
Some(prompt) => prompt, builder.token_ids(tokens);
None => { }
tracing::warn!("Raw prompt requested but not available"); TokenInput::Batch(token_batches) => {
self.formatter.render(request)? if token_batches.len() == 1 {
builder.token_ids(token_batches[0].clone());
} else {
builder.batch_token_ids(Some(token_batches));
builder.token_ids(vec![]);
}
}
}
} }
} }
} else { PromptInput::Text(_) => {
self.formatter.render(request)? if let Some(text_input) = request.extract_text() {
}; match text_input {
TextInput::Single(_) => {
let encoding = tokio::task::block_in_place(|| self.tokenizer.encode(&formatted_prompt))?; let use_raw_prompt = request
.nvext()
.is_some_and(|ext| ext.use_raw_prompt.unwrap_or(false));
let formatted_prompt = if use_raw_prompt {
match request.raw_prompt() {
Some(prompt) => prompt,
None => {
tracing::warn!("Raw prompt requested but not available");
self.formatter.render(request)?
}
}
} else {
self.formatter.render(request)?
};
let encoding = tokio::task::block_in_place(|| {
self.tokenizer.encode(&formatted_prompt)
})?;
if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) {
annotations.insert(
ANNOTATION_FORMATTED_PROMPT.to_string(),
formatted_prompt,
);
}
if request.has_annotation(ANNOTATION_FORMATTED_PROMPT) { if request.has_annotation(ANNOTATION_TOKEN_IDS) {
annotations.insert(ANNOTATION_FORMATTED_PROMPT.to_string(), formatted_prompt); annotations.insert(
} ANNOTATION_TOKEN_IDS.to_string(),
serde_json::to_string(&encoding.token_ids)?,
);
}
if request.has_annotation(ANNOTATION_TOKEN_IDS) { builder.token_ids(encoding.token_ids);
annotations.insert( }
ANNOTATION_TOKEN_IDS.to_string(), TextInput::Batch(texts) => {
serde_json::to_string(&encoding.token_ids)?, let mut token_batches = Vec::new();
); // TODO: room for optimization here
for text in texts {
let encoding =
tokio::task::block_in_place(|| self.tokenizer.encode(&text))?;
token_batches.push(encoding.token_ids);
}
builder.batch_token_ids(Some(token_batches));
builder.token_ids(vec![]);
}
}
}
}
} }
let mut stop_conditions = request.extract_stop_conditions()?; let mut stop_conditions = request.extract_stop_conditions()?;
...@@ -207,9 +253,8 @@ impl OpenAIPreprocessor { ...@@ -207,9 +253,8 @@ impl OpenAIPreprocessor {
builder.eos_token_ids(self.model_info.eos_token_ids()); builder.eos_token_ids(self.model_info.eos_token_ids());
} }
builder.token_ids(encoding.token_ids);
builder.sampling_options(request.extract_sampling_options()?);
builder.stop_conditions(stop_conditions); builder.stop_conditions(stop_conditions);
builder.sampling_options(request.extract_sampling_options()?);
builder.annotations(request.annotations().unwrap_or_default()); builder.annotations(request.annotations().unwrap_or_default());
builder.mdc_sum(Some(self.mdcsum.clone())); builder.mdc_sum(Some(self.mdcsum.clone()));
builder.estimated_prefix_hit_num_blocks(None); builder.estimated_prefix_hit_num_blocks(None);
......
...@@ -38,6 +38,24 @@ mod template; ...@@ -38,6 +38,24 @@ mod template;
pub use template::ContextMixins; pub use template::ContextMixins;
/// Prompt supplied as already-tokenized input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenInput {
    /// A single prompt given as one list of token IDs.
    Single(Vec<u32>),
    /// Several prompts, each given as its own list of token IDs.
    Batch(Vec<Vec<u32>>),
}

/// Prompt supplied as raw text that still has to be tokenized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TextInput {
    /// A single prompt string.
    Single(String),
    /// Several prompt strings that form one batch request.
    Batch(Vec<String>),
}

/// The kind of prompt a request carries: pre-tokenized or text, single or batch.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptInput {
    /// The prompt is made of token IDs.
    Tokens(TokenInput),
    /// The prompt is made of text.
    Text(TextInput),
}
/// Trait that defines a request that can map to an OpenAI-like request. /// Trait that defines a request that can map to an OpenAI-like request.
pub trait OAIChatLikeRequest { pub trait OAIChatLikeRequest {
fn messages(&self) -> Value; fn messages(&self) -> Value;
...@@ -49,6 +67,20 @@ pub trait OAIChatLikeRequest { ...@@ -49,6 +67,20 @@ pub trait OAIChatLikeRequest {
} }
fn should_add_generation_prompt(&self) -> bool; fn should_add_generation_prompt(&self) -> bool;
/// Reports which kind of prompt the request carries.
///
/// The default classifies the request as a single text prompt and wraps an
/// empty placeholder string rather than the real prompt.
fn prompt_input_type(&self) -> PromptInput {
    let placeholder = TextInput::Single(String::new());
    PromptInput::Text(placeholder)
}
/// Extract tokens if the input is pre-tokenized.
///
/// The default implementation returns `None`; request types whose prompt can be
/// given as token IDs override it.
fn extract_tokens(&self) -> Option<TokenInput> {
None
}
/// Extract the prompt text if the input is textual.
///
/// The default implementation returns `None`; request types that carry a text
/// prompt override it.
fn extract_text(&self) -> Option<TextInput> {
None
}
} }
pub trait OAIPromptFormatter: Send + Sync + 'static { pub trait OAIPromptFormatter: Send + Sync + 'static {
......
...@@ -22,6 +22,8 @@ use crate::protocols::openai::{ ...@@ -22,6 +22,8 @@ use crate::protocols::openai::{
}; };
use tracing; use tracing;
use crate::preprocessor::prompt::{PromptInput, TextInput, TokenInput};
impl OAIChatLikeRequest for NvCreateChatCompletionRequest { impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
fn messages(&self) -> Value { fn messages(&self) -> Value {
Value::from_serialize(&self.inner.messages) Value::from_serialize(&self.inner.messages)
...@@ -53,6 +55,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest { ...@@ -53,6 +55,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
true true
} }
} }
/// Marks a chat completion request as a single text input.
///
/// The wrapped string is an empty placeholder, not the rendered prompt.
fn extract_text(&self) -> Option<TextInput> {
    let placeholder = String::new();
    Some(TextInput::Single(placeholder))
}
} }
impl OAIChatLikeRequest for NvCreateCompletionRequest { impl OAIChatLikeRequest for NvCreateCompletionRequest {
...@@ -72,6 +78,48 @@ impl OAIChatLikeRequest for NvCreateCompletionRequest { ...@@ -72,6 +78,48 @@ impl OAIChatLikeRequest for NvCreateCompletionRequest {
fn should_add_generation_prompt(&self) -> bool { fn should_add_generation_prompt(&self) -> bool {
true true
} }
/// Classifies the completion prompt as tokens or text, single or batch.
///
/// Only the variant is meaningful: the payload is always an empty placeholder.
/// Use `extract_tokens` / `extract_text` to obtain the actual prompt data.
fn prompt_input_type(&self) -> PromptInput {
    use async_openai::types::Prompt;

    match &self.inner.prompt {
        Prompt::IntegerArray(_) => PromptInput::Tokens(TokenInput::Single(vec![])),
        Prompt::ArrayOfIntegerArray(_) => PromptInput::Tokens(TokenInput::Batch(vec![])),
        Prompt::String(_) => PromptInput::Text(TextInput::Single(String::new())),
        Prompt::StringArray(_) => PromptInput::Text(TextInput::Batch(vec![])),
    }
}
/// Returns the prompt's token IDs when the prompt was supplied pre-tokenized.
///
/// An integer array becomes `TokenInput::Single`, an array of integer arrays
/// becomes `TokenInput::Batch`; text prompts yield `None`.
fn extract_tokens(&self) -> Option<TokenInput> {
    use async_openai::types::Prompt;

    // NOTE(review): `as u32` truncates silently if the source integer type is
    // wider than 32 bits -- confirm the element type of `Prompt`'s integer arrays.
    let input = match &self.inner.prompt {
        Prompt::IntegerArray(ids) => {
            let single: Vec<u32> = ids.iter().map(|&id| id as u32).collect();
            TokenInput::Single(single)
        }
        Prompt::ArrayOfIntegerArray(rows) => {
            let mut batches: Vec<Vec<u32>> = Vec::with_capacity(rows.len());
            for row in rows.iter() {
                batches.push(row.iter().map(|&id| id as u32).collect());
            }
            TokenInput::Batch(batches)
        }
        _ => return None,
    };
    Some(input)
}
/// Returns an owned copy of the prompt text when the prompt was supplied as text.
///
/// A string becomes `TextInput::Single`, a string array becomes
/// `TextInput::Batch`; pre-tokenized prompts yield `None`.
fn extract_text(&self) -> Option<TextInput> {
    use async_openai::types::Prompt;

    let input = match &self.inner.prompt {
        Prompt::String(single) => TextInput::Single(single.to_string()),
        Prompt::StringArray(many) => TextInput::Batch(many.to_vec()),
        _ => return None,
    };
    Some(input)
}
} }
impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter { impl OAIPromptFormatter for HfTokenizerConfigJsonFormatter {
......
...@@ -46,6 +46,9 @@ pub struct BackendOutput { ...@@ -46,6 +46,9 @@ pub struct BackendOutput {
pub finish_reason: Option<FinishReason>, pub finish_reason: Option<FinishReason>,
// Model Deployment Card checksum // Model Deployment Card checksum
//pub mdcsum: String, //pub mdcsum: String,
// Index field for batch requests to match OpenAI format
pub index: Option<u32>,
} }
/// The LLM engine and backend will manage its own state, specifically translating how a /// The LLM engine and backend will manage its own state, specifically translating how a
...@@ -77,6 +80,9 @@ pub struct LLMEngineOutput { ...@@ -77,6 +80,9 @@ pub struct LLMEngineOutput {
// TODO: Enrich this with more information as can apply our first-level postprocessing // TODO: Enrich this with more information as can apply our first-level postprocessing
// logic and return more detailed information // logic and return more detailed information
pub finish_reason: Option<FinishReason>, pub finish_reason: Option<FinishReason>,
// Index field for batch requests to match OpenAI format
pub index: Option<u32>,
} }
impl LLMEngineOutput { impl LLMEngineOutput {
...@@ -88,6 +94,7 @@ impl LLMEngineOutput { ...@@ -88,6 +94,7 @@ impl LLMEngineOutput {
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
finish_reason: Some(FinishReason::Cancelled), finish_reason: Some(FinishReason::Cancelled),
index: None,
} }
} }
...@@ -99,6 +106,7 @@ impl LLMEngineOutput { ...@@ -99,6 +106,7 @@ impl LLMEngineOutput {
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
finish_reason: Some(FinishReason::Stop), finish_reason: Some(FinishReason::Stop),
index: None,
} }
} }
...@@ -110,6 +118,7 @@ impl LLMEngineOutput { ...@@ -110,6 +118,7 @@ impl LLMEngineOutput {
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
finish_reason: Some(FinishReason::Length), finish_reason: Some(FinishReason::Length),
index: None,
} }
} }
...@@ -121,6 +130,7 @@ impl LLMEngineOutput { ...@@ -121,6 +130,7 @@ impl LLMEngineOutput {
cum_log_probs: None, cum_log_probs: None,
log_probs: None, log_probs: None,
finish_reason: Some(FinishReason::Error(err_msg)), finish_reason: Some(FinishReason::Error(err_msg)),
index: None,
} }
} }
} }
...@@ -26,6 +26,10 @@ pub struct PreprocessedRequest { ...@@ -26,6 +26,10 @@ pub struct PreprocessedRequest {
/// Type of prompt /// Type of prompt
pub token_ids: Vec<TokenIdType>, pub token_ids: Vec<TokenIdType>,
/// Batch token IDs, set for batch completion requests (e.g. the ArrayOfIntegerArray or StringArray prompt types of OpenAI /completions)
#[builder(default)]
pub batch_token_ids: Option<Vec<Vec<TokenIdType>>>,
/// StopConditions are conditions that the inference engine will use to stop generation. /// StopConditions are conditions that the inference engine will use to stop generation.
pub stop_conditions: StopConditions, pub stop_conditions: StopConditions,
......
...@@ -131,8 +131,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe ...@@ -131,8 +131,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe
}; };
// create choice // create choice
let index = 0; let index = delta.index.unwrap_or(0).into();
Ok(self.create_choice(index, delta.text, finish_reason)) let response = self.create_choice(index, delta.text.clone(), finish_reason);
Ok(response)
} }
fn get_isl(&self) -> Option<u32> { fn get_isl(&self) -> Option<u32> {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment