Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
fc16a79b
Unverified
Commit
fc16a79b
authored
Jun 25, 2025
by
ishandhanani
Committed by
GitHub
Jun 25, 2025
Browse files
feat: support batch `/completions` (#1626)
Co-authored-by:
Ryan McCormick
<
rmccormick@nvidia.com
>
parent
3e1a5534
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
364 additions
and
68 deletions
+364
-68
container/Dockerfile.sglang-deepep
container/Dockerfile.sglang-deepep
+44
-4
examples/sglang/README.md
examples/sglang/README.md
+3
-3
examples/sglang/components/decode_worker.py
examples/sglang/components/decode_worker.py
+3
-1
examples/sglang/components/worker.py
examples/sglang/components/worker.py
+78
-17
examples/sglang/utils/protocol.py
examples/sglang/utils/protocol.py
+5
-4
launch/dynamo-run/src/subprocess/sglang_inc.py
launch/dynamo-run/src/subprocess/sglang_inc.py
+60
-11
lib/engines/llamacpp/src/lib.rs
lib/engines/llamacpp/src/lib.rs
+1
-0
lib/llm/src/backend.rs
lib/llm/src/backend.rs
+1
-0
lib/llm/src/engines.rs
lib/llm/src/engines.rs
+1
-0
lib/llm/src/preprocessor.rs
lib/llm/src/preprocessor.rs
+71
-26
lib/llm/src/preprocessor/prompt.rs
lib/llm/src/preprocessor/prompt.rs
+32
-0
lib/llm/src/preprocessor/prompt/template/oai.rs
lib/llm/src/preprocessor/prompt/template/oai.rs
+48
-0
lib/llm/src/protocols/common/llm_backend.rs
lib/llm/src/protocols/common/llm_backend.rs
+10
-0
lib/llm/src/protocols/common/preprocessor.rs
lib/llm/src/protocols/common/preprocessor.rs
+4
-0
lib/llm/src/protocols/openai/completions/delta.rs
lib/llm/src/protocols/openai/completions/delta.rs
+3
-2
No files found.
container/Dockerfile.sglang-deepep
View file @
fc16a79b
...
@@ -35,22 +35,62 @@ ARG ARCH_ALT=x86_64
...
@@ -35,22 +35,62 @@ ARG ARCH_ALT=x86_64
WORKDIR /sgl-workspace
WORKDIR /sgl-workspace
# Install UCX dependencies
RUN apt-get update -y && \
apt-get install -y --no-install-recommends \
--reinstall libibverbs-dev rdma-core ibverbs-utils libibumad-dev \
libnuma-dev librdmacm-dev ibverbs-providers \
autoconf libtool
# Build UCX from source
ARG NIXL_UCX_REF=v1.19.x
RUN rm -rf /opt/hpcx/ucx && \
rm -rf /usr/local/ucx && \
cd /usr/local/src && \
git clone https://github.com/openucx/ucx.git && \
cd ucx && \
git checkout $NIXL_UCX_REF && \
./autogen.sh && ./configure \
--prefix=/usr/local/ucx \
--enable-shared \
--disable-static \
--disable-doxygen-doc \
--enable-optimizations \
--enable-cma \
--enable-devel-headers \
--with-cuda=/usr/local/cuda \
--with-verbs \
--with-efa \
--with-dm \
--with-gdrcopy=/usr/local \
--enable-mt && \
make -j && \
make -j install-strip && \
ldconfig
ENV LD_LIBRARY_PATH=/usr/lib:/usr/local/ucx/lib:$LD_LIBRARY_PATH
# Pinning to NIXL 0.2.1 right now
# Pinning to NIXL 0.2.1 right now
# TODO: investigate pip install failure with 0.3.0 release
# TODO: investigate pip install failure with 0.3.0 release
ARG NIXL_COMMIT="5e4c179ee850d482a83cb2a211e0947e46281060"
ARG NIXL_COMMIT="5e4c179ee850d482a83cb2a211e0947e46281060"
RUN git clone https://github.com/ai-dynamo/nixl.git && cd nixl && git checkout ${NIXL_COMMIT} &&pip install --break-system-packages . --config-settings=setup-args="-Ducx_path=/
opt/hpcx
/ucx"
RUN git clone https://github.com/ai-dynamo/nixl.git && cd nixl && git checkout ${NIXL_COMMIT} &&
pip install --break-system-packages . --config-settings=setup-args="-Ducx_path=/
usr/local
/ucx"
WORKDIR /sgl-workspace
WORKDIR /sgl-workspace
RUN pip uninstall --break-system-packages -y sglang
RUN pip uninstall --break-system-packages -y sglang
RUN rm -rf sglang
RUN rm -rf sglang
# 0.4.7
# 0.4.8 has a bug with CUDA graphs and decode worker
RUN pip install --break-system-packages "sglang==0.4.7"
# https://github.com/sgl-project/sglang/issues/7511
RUN pip install --break-system-packages "sglang==0.4.7.post1"
# Allow forceful shutdown of inflight requests
ENV SGL_FORCE_SHUTDOWN=1
WORKDIR /sgl-workspace
WORKDIR /sgl-workspace
# https://github.com/ai-dynamo/dynamo/pull/1510
# https://github.com/ai-dynamo/dynamo/pull/1510
ARG DYNAMO_COMMIT="382e3aedc421b3b3abc338062b332b54b5aa8529"
ARG DYNAMO_COMMIT="382e3aedc421b3b3abc338062b332b54b5aa8529"
RUN git clone https://github.com/ai-dynamo/dynamo.git && cd dynamo && git checkout ${DYNAMO_COMMIT}
ARG DYNAMO_BRANCH="ishan/cmpl-token-id"
RUN git clone https://github.com/ai-dynamo/dynamo.git && cd dynamo && git checkout ${DYNAMO_BRANCH}
# install dynamo in editable mode
# install dynamo in editable mode
WORKDIR /sgl-workspace/dynamo
WORKDIR /sgl-workspace/dynamo
...
...
examples/sglang/README.md
View file @
fc16a79b
...
@@ -106,12 +106,12 @@ Dynamo supports SGLang's implementation of wide expert parallelism and large sca
...
@@ -106,12 +106,12 @@ Dynamo supports SGLang's implementation of wide expert parallelism and large sca
Steps to run:
Steps to run:
1.
Build the SGLang DeepEP container
1.
Build the SGLang DeepEP container
.
```
bash
```
bash
git clone https://github.com/sgl-project/sglang.git
git clone
-b
v0.4.8
https://github.com/sgl-project/sglang.git
cd
sglang/docker
cd
sglang/docker
docker build
-f
Dockerfile
.deepep
-t
deepep .
docker build
-f
Dockerfile
-t
deepep .
```
```
You will now have a
`deepep:latest`
image
You will now have a
`deepep:latest`
image
...
...
examples/sglang/components/decode_worker.py
View file @
fc16a79b
...
@@ -45,7 +45,9 @@ class SGLangDecodeWorker:
...
@@ -45,7 +45,9 @@ class SGLangDecodeWorker:
@
endpoint
()
@
endpoint
()
async
def
generate
(
self
,
req
:
DisaggPreprocessedRequest
):
async
def
generate
(
self
,
req
:
DisaggPreprocessedRequest
):
g
=
await
self
.
engine
.
async_generate
(
g
=
await
self
.
engine
.
async_generate
(
input_ids
=
req
.
request
.
token_ids
,
input_ids
=
req
.
request
.
token_ids
if
req
.
request
.
batch_token_ids
is
None
else
req
.
request
.
batch_token_ids
,
sampling_params
=
req
.
sampling_params
,
sampling_params
=
req
.
sampling_params
,
stream
=
True
,
stream
=
True
,
bootstrap_host
=
req
.
bootstrap_host
,
bootstrap_host
=
req
.
bootstrap_host
,
...
...
examples/sglang/components/worker.py
View file @
fc16a79b
...
@@ -28,6 +28,7 @@ import asyncio
...
@@ -28,6 +28,7 @@ import asyncio
import
logging
import
logging
import
random
import
random
import
socket
import
socket
from
typing
import
Dict
,
Union
import
sglang
as
sgl
import
sglang
as
sgl
from
components.decode_worker
import
SGLangDecodeWorker
from
components.decode_worker
import
SGLangDecodeWorker
...
@@ -112,63 +113,123 @@ class SGLangWorker:
...
@@ -112,63 +113,123 @@ class SGLangWorker:
sampling_params
[
"ignore_eos"
]
=
request
.
stop_conditions
.
ignore_eos
sampling_params
[
"ignore_eos"
]
=
request
.
stop_conditions
.
ignore_eos
return
sampling_params
return
sampling_params
def
_get_request_batch_size
(
self
,
request
:
PreprocessedRequest
):
"""Get batch size from request, returns None for single requests"""
if
request
.
batch_token_ids
is
not
None
:
return
len
(
request
.
batch_token_ids
)
return
None
def
_is_batch_request
(
self
,
request
:
PreprocessedRequest
):
"""Check if request is in batch mode"""
return
request
.
batch_token_ids
is
not
None
@
endpoint
()
@
endpoint
()
async
def
generate
(
self
,
request
:
PreprocessedRequest
):
async
def
generate
(
self
,
request
:
PreprocessedRequest
):
# Check if we're in batch mode at the start
is_batch
=
self
.
_is_batch_request
(
request
)
batch_size
=
self
.
_get_request_batch_size
(
request
)
# TODO: maintain a mapping from SGLang's Output struct to LLMEngineOutput
# TODO: maintain a mapping from SGLang's Output struct to LLMEngineOutput
sampling_params
=
self
.
_build_sampling_params
(
request
)
sampling_params
=
self
.
_build_sampling_params
(
request
)
if
self
.
engine_args
.
disaggregation_mode
!=
"null"
:
if
self
.
engine_args
.
disaggregation_mode
!=
"null"
:
bootstrap_room
=
self
.
_generate_bootstrap_room
()
if
is_batch
:
bootstrap_room
=
[
self
.
_generate_bootstrap_room
()
for
_
in
range
(
batch_size
)
]
bootstrap_host
=
[
self
.
bootstrap_host
]
*
batch_size
bootstrap_port
=
[
self
.
bootstrap_port
]
*
batch_size
else
:
bootstrap_host
=
self
.
bootstrap_host
bootstrap_port
=
self
.
bootstrap_port
bootstrap_room
=
self
.
_generate_bootstrap_room
()
# decode worker request
# decode worker request
disagg_request
=
DisaggPreprocessedRequest
(
disagg_request
=
DisaggPreprocessedRequest
(
request
=
request
,
request
=
request
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
bootstrap_host
=
self
.
bootstrap_host
,
bootstrap_host
=
bootstrap_host
,
bootstrap_port
=
self
.
bootstrap_port
,
bootstrap_port
=
bootstrap_port
,
bootstrap_room
=
bootstrap_room
,
bootstrap_room
=
bootstrap_room
,
)
)
# prefill response is not used
# prefill response is not used
prefill
=
await
self
.
engine
.
async_generate
(
prefill
=
await
self
.
engine
.
async_generate
(
input_ids
=
request
.
token_ids
,
input_ids
=
request
.
token_ids
if
not
is_batch
else
request
.
batch_token_ids
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
stream
=
True
,
stream
=
True
,
bootstrap_host
=
self
.
bootstrap_host
,
bootstrap_host
=
bootstrap_host
,
bootstrap_port
=
self
.
bootstrap_port
,
bootstrap_port
=
bootstrap_port
,
bootstrap_room
=
bootstrap_room
,
bootstrap_room
=
bootstrap_room
,
)
)
prefill_task
=
asyncio
.
create_task
(
self
.
_prefill_generator
(
prefill
))
prefill_task
=
asyncio
.
create_task
(
self
.
_prefill_generator
(
prefill
))
decode
=
await
self
.
decode_client
.
generate
(
disagg_request
.
model_dump_json
())
decode
=
await
self
.
decode_client
.
generate
(
disagg_request
.
model_dump_json
())
async
for
out
in
self
.
_process_stream
(
decode
,
unpack
=
True
):
async
for
out
in
self
.
_process_stream
(
decode
,
unpack
=
True
,
is_batch
=
is_batch
):
yield
out
yield
out
await
prefill_task
await
prefill_task
else
:
else
:
g
=
await
self
.
engine
.
async_generate
(
g
=
await
self
.
engine
.
async_generate
(
input_ids
=
request
.
token_ids
,
input_ids
=
request
.
token_ids
if
not
is_batch
else
request
.
batch_token_ids
,
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
stream
=
True
,
stream
=
True
,
)
)
async
for
out
in
self
.
_process_stream
(
g
,
unpack
=
False
):
async
for
out
in
self
.
_process_stream
(
g
,
unpack
=
False
,
is_batch
=
is_batch
):
yield
out
yield
out
async
def
_process_stream
(
self
,
stream_source
,
unpack
:
bool
):
async
def
_process_stream
(
self
,
stream_source
,
unpack
:
bool
,
is_batch
:
bool
):
num_output_tokens_so_far
=
0
# Initialize based on batch mode
num_output_tokens_so_far
:
Union
[
Dict
[
int
,
int
],
int
]
if
is_batch
:
num_output_tokens_so_far
=
{}
else
:
num_output_tokens_so_far
=
0
async
for
res
in
stream_source
:
async
for
res
in
stream_source
:
data
=
res
.
data
()
if
unpack
else
res
data
=
res
.
data
()
if
unpack
else
res
finish_reason
=
data
[
"meta_info"
][
"finish_reason"
]
finish_reason
=
data
[
"meta_info"
][
"finish_reason"
]
if
finish_reason
:
# Don't forward the stop token
if
is_batch
:
out
=
{
"token_ids"
:
[],
"finish_reason"
:
finish_reason
[
"type"
]}
# Handle batch response
assert
isinstance
(
num_output_tokens_so_far
,
dict
)
index
=
data
.
get
(
"index"
,
0
)
if
index
not
in
num_output_tokens_so_far
:
num_output_tokens_so_far
[
index
]
=
0
if
finish_reason
:
out
=
{
"token_ids"
:
[],
"finish_reason"
:
finish_reason
[
"type"
],
"index"
:
index
,
}
else
:
next_total_toks
=
len
(
data
[
"output_ids"
])
new_tokens
=
data
[
"output_ids"
][
num_output_tokens_so_far
[
index
]
:]
out
=
{
"token_ids"
:
new_tokens
,
"index"
:
index
,
}
num_output_tokens_so_far
[
index
]
=
next_total_toks
else
:
else
:
next_total_toks
=
len
(
data
[
"output_ids"
])
# Handle single response
out
=
{
"token_ids"
:
data
[
"output_ids"
][
num_output_tokens_so_far
:]}
assert
isinstance
(
num_output_tokens_so_far
,
int
)
if
finish_reason
:
out
=
{
"token_ids"
:
[],
"finish_reason"
:
finish_reason
[
"type"
]}
else
:
next_total_toks
=
len
(
data
[
"output_ids"
])
out
=
{
"token_ids"
:
data
[
"output_ids"
][
num_output_tokens_so_far
:]}
num_output_tokens_so_far
=
next_total_toks
yield
out
yield
out
num_output_tokens_so_far
=
next_total_toks
def
_generate_bootstrap_room
(
self
):
def
_generate_bootstrap_room
(
self
):
return
random
.
randint
(
0
,
2
**
63
-
1
)
return
random
.
randint
(
0
,
2
**
63
-
1
)
...
...
examples/sglang/utils/protocol.py
View file @
fc16a79b
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Union
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
,
Field
...
@@ -47,6 +47,7 @@ class SamplingOptions(BaseModel):
...
@@ -47,6 +47,7 @@ class SamplingOptions(BaseModel):
class
PreprocessedRequest
(
BaseModel
):
class
PreprocessedRequest
(
BaseModel
):
token_ids
:
List
[
TokenIdType
]
token_ids
:
List
[
TokenIdType
]
batch_token_ids
:
Optional
[
List
[
List
[
TokenIdType
]]]
=
None
stop_conditions
:
StopConditions
stop_conditions
:
StopConditions
sampling_options
:
SamplingOptions
sampling_options
:
SamplingOptions
eos_token_ids
:
List
[
TokenIdType
]
=
Field
(
default_factory
=
list
)
eos_token_ids
:
List
[
TokenIdType
]
=
Field
(
default_factory
=
list
)
...
@@ -57,7 +58,7 @@ class PreprocessedRequest(BaseModel):
...
@@ -57,7 +58,7 @@ class PreprocessedRequest(BaseModel):
class
DisaggPreprocessedRequest
(
BaseModel
):
class
DisaggPreprocessedRequest
(
BaseModel
):
request
:
PreprocessedRequest
request
:
PreprocessedRequest
sampling_params
:
dict
sampling_params
:
dict
bootstrap_host
:
str
bootstrap_host
:
Union
[
str
,
List
[
str
]]
bootstrap_port
:
int
bootstrap_port
:
Union
[
int
,
List
[
int
]]
bootstrap_room
:
int
bootstrap_room
:
Union
[
int
,
List
[
int
]]
data_parallel_rank
:
Optional
[
int
]
=
None
data_parallel_rank
:
Optional
[
int
]
=
None
launch/dynamo-run/src/subprocess/sglang_inc.py
View file @
fc16a79b
...
@@ -60,22 +60,71 @@ class RequestHandler:
...
@@ -60,22 +60,71 @@ class RequestHandler:
# sglang defaults this to 128
# sglang defaults this to 128
"max_new_tokens"
:
request
[
"stop_conditions"
][
"max_tokens"
],
"max_new_tokens"
:
request
[
"stop_conditions"
][
"max_tokens"
],
}
}
num_output_tokens_so_far
=
0
gen
=
await
self
.
engine_client
.
async_generate
(
# Check if this is a batch request
input_ids
=
request
[
"token_ids"
],
sampling_params
=
sampling_params
,
stream
=
True
is_batch
=
"batch_token_ids"
in
request
and
request
[
"batch_token_ids"
]
)
if
is_batch
:
# Track tokens separately for each batch item
num_output_tokens_so_far
=
{}
logging
.
debug
(
"received batch token ids"
)
gen
=
await
self
.
engine_client
.
async_generate
(
input_ids
=
request
[
"batch_token_ids"
],
sampling_params
=
sampling_params
,
stream
=
True
,
)
else
:
num_output_tokens_so_far
=
0
logging
.
debug
(
"received token ids"
)
gen
=
await
self
.
engine_client
.
async_generate
(
input_ids
=
request
[
"token_ids"
],
sampling_params
=
sampling_params
,
stream
=
True
,
)
async
for
res
in
gen
:
async
for
res
in
gen
:
# res is a dict
# res is a dict
logging
.
debug
(
f
"res:
{
res
}
"
)
finish_reason
=
res
[
"meta_info"
][
"finish_reason"
]
finish_reason
=
res
[
"meta_info"
][
"finish_reason"
]
if
finish_reason
:
# Don't forward the stop token
if
is_batch
:
out
=
{
"token_ids"
:
[],
"finish_reason"
:
finish_reason
[
"type"
]}
# Handle batch response - get index from SGLang response
index
=
res
.
get
(
"index"
,
0
)
if
index
not
in
num_output_tokens_so_far
:
num_output_tokens_so_far
[
index
]
=
0
if
finish_reason
:
logging
.
warning
(
f
"finish_reason:
{
finish_reason
}
"
)
# Final response for this batch item
out
=
{
"token_ids"
:
[],
"finish_reason"
:
finish_reason
[
"type"
],
"index"
:
index
,
}
else
:
# Streaming response for this batch item
next_total_toks
=
len
(
res
[
"output_ids"
])
new_tokens
=
res
[
"output_ids"
][
num_output_tokens_so_far
[
index
]
:]
out
=
{
"token_ids"
:
new_tokens
,
"index"
:
index
,
}
num_output_tokens_so_far
[
index
]
=
next_total_toks
else
:
else
:
next_total_toks
=
len
(
res
[
"output_ids"
])
if
finish_reason
:
out
=
{
"token_ids"
:
res
[
"output_ids"
][
num_output_tokens_so_far
:]}
out
=
{
"token_ids"
:
[],
"finish_reason"
:
finish_reason
[
"type"
],
}
else
:
next_total_toks
=
len
(
res
[
"output_ids"
])
new_tokens
=
res
[
"output_ids"
][
num_output_tokens_so_far
:]
out
=
{
"token_ids"
:
new_tokens
,
}
num_output_tokens_so_far
=
next_total_toks
yield
out
yield
out
num_output_tokens_so_far
=
next_total_toks
class
EmbeddingRequestHandler
(
RequestHandler
):
class
EmbeddingRequestHandler
(
RequestHandler
):
...
...
lib/engines/llamacpp/src/lib.rs
View file @
fc16a79b
...
@@ -269,6 +269,7 @@ fn run_request(
...
@@ -269,6 +269,7 @@ fn run_request(
cum_log_probs
:
None
,
// TODO output.cumulative_logprob.map(|v| v as f64),
cum_log_probs
:
None
,
// TODO output.cumulative_logprob.map(|v| v as f64),
log_probs
:
None
,
// TODO output.logprobs
log_probs
:
None
,
// TODO output.logprobs
finish_reason
:
None
,
finish_reason
:
None
,
index
:
None
,
};
};
work_request
work_request
.response_channel
.response_channel
...
...
lib/llm/src/backend.rs
View file @
fc16a79b
...
@@ -224,6 +224,7 @@ impl
...
@@ -224,6 +224,7 @@ impl
log_probs
:
data
.log_probs
,
log_probs
:
data
.log_probs
,
finish_reason
:
data
.finish_reason
,
finish_reason
:
data
.finish_reason
,
//mdcsum: mdcsum.clone(),
//mdcsum: mdcsum.clone(),
index
:
data
.index
,
})
})
})
})
});
});
...
...
lib/llm/src/engines.rs
View file @
fc16a79b
...
@@ -115,6 +115,7 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
...
@@ -115,6 +115,7 @@ fn delta_core(tok: u32) -> Annotated<LLMEngineOutput> {
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
finish_reason
:
None
,
finish_reason
:
None
,
index
:
None
,
};
};
Annotated
::
from_data
(
delta
)
Annotated
::
from_data
(
delta
)
}
}
...
...
lib/llm/src/preprocessor.rs
View file @
fc16a79b
...
@@ -53,7 +53,7 @@ use crate::protocols::{
...
@@ -53,7 +53,7 @@ use crate::protocols::{
};
};
use
crate
::
tokenizers
::{
traits
::
Tokenizer
,
HuggingFaceTokenizer
};
use
crate
::
tokenizers
::{
traits
::
Tokenizer
,
HuggingFaceTokenizer
};
use
crate
::
preprocessor
::
prompt
::
PromptFormatter
;
use
crate
::
preprocessor
::
prompt
::
{
PromptFormatter
,
PromptInput
,
TextInput
,
TokenInput
}
;
pub
use
crate
::
protocols
::
common
::
llm_backend
::{
BackendOutput
,
PreprocessedRequest
};
pub
use
crate
::
protocols
::
common
::
llm_backend
::{
BackendOutput
,
PreprocessedRequest
};
...
@@ -160,33 +160,79 @@ impl OpenAIPreprocessor {
...
@@ -160,33 +160,79 @@ impl OpenAIPreprocessor {
let
mut
annotations
=
HashMap
::
new
();
let
mut
annotations
=
HashMap
::
new
();
let
mut
builder
=
PreprocessedRequest
::
builder
();
let
mut
builder
=
PreprocessedRequest
::
builder
();
let
use_raw_prompt
=
request
// match request type before any conversion/processing
.nvext
()
match
request
.prompt_input_type
()
{
.is_some_and
(|
ext
|
ext
.use_raw_prompt
.unwrap_or
(
false
));
PromptInput
::
Tokens
(
_
)
=>
{
if
let
Some
(
token_input
)
=
request
.extract_tokens
()
{
let
formatted_prompt
=
if
use_raw_prompt
{
match
token_input
{
match
request
.raw_prompt
()
{
TokenInput
::
Single
(
tokens
)
=>
{
Some
(
prompt
)
=>
prompt
,
builder
.token_ids
(
tokens
);
None
=>
{
}
tracing
::
warn!
(
"Raw prompt requested but not available"
);
TokenInput
::
Batch
(
token_batches
)
=>
{
self
.formatter
.render
(
request
)
?
if
token_batches
.len
()
==
1
{
builder
.token_ids
(
token_batches
[
0
]
.clone
());
}
else
{
builder
.batch_token_ids
(
Some
(
token_batches
));
builder
.token_ids
(
vec!
[]);
}
}
}
}
}
}
}
}
else
{
PromptInput
::
Text
(
_
)
=>
{
self
.formatter
.render
(
request
)
?
if
let
Some
(
text_input
)
=
request
.extract_text
()
{
};
match
text_input
{
TextInput
::
Single
(
_
)
=>
{
let
encoding
=
tokio
::
task
::
block_in_place
(||
self
.tokenizer
.encode
(
&
formatted_prompt
))
?
;
let
use_raw_prompt
=
request
.nvext
()
.is_some_and
(|
ext
|
ext
.use_raw_prompt
.unwrap_or
(
false
));
let
formatted_prompt
=
if
use_raw_prompt
{
match
request
.raw_prompt
()
{
Some
(
prompt
)
=>
prompt
,
None
=>
{
tracing
::
warn!
(
"Raw prompt requested but not available"
);
self
.formatter
.render
(
request
)
?
}
}
}
else
{
self
.formatter
.render
(
request
)
?
};
let
encoding
=
tokio
::
task
::
block_in_place
(||
{
self
.tokenizer
.encode
(
&
formatted_prompt
)
})
?
;
if
request
.has_annotation
(
ANNOTATION_FORMATTED_PROMPT
)
{
annotations
.insert
(
ANNOTATION_FORMATTED_PROMPT
.to_string
(),
formatted_prompt
,
);
}
if
request
.has_annotation
(
ANNOTATION_FORMATTED_PROMPT
)
{
if
request
.has_annotation
(
ANNOTATION_TOKEN_IDS
)
{
annotations
.insert
(
ANNOTATION_FORMATTED_PROMPT
.to_string
(),
formatted_prompt
);
annotations
.insert
(
}
ANNOTATION_TOKEN_IDS
.to_string
(),
serde_json
::
to_string
(
&
encoding
.token_ids
)
?
,
);
}
if
request
.has_annotation
(
ANNOTATION_TOKEN_IDS
)
{
builder
.token_ids
(
encoding
.token_ids
);
annotations
.insert
(
}
ANNOTATION_TOKEN_IDS
.to_string
(),
TextInput
::
Batch
(
texts
)
=>
{
serde_json
::
to_string
(
&
encoding
.token_ids
)
?
,
let
mut
token_batches
=
Vec
::
new
();
);
// TODO: room for optimization here
for
text
in
texts
{
let
encoding
=
tokio
::
task
::
block_in_place
(||
self
.tokenizer
.encode
(
&
text
))
?
;
token_batches
.push
(
encoding
.token_ids
);
}
builder
.batch_token_ids
(
Some
(
token_batches
));
builder
.token_ids
(
vec!
[]);
}
}
}
}
}
}
let
mut
stop_conditions
=
request
.extract_stop_conditions
()
?
;
let
mut
stop_conditions
=
request
.extract_stop_conditions
()
?
;
...
@@ -207,9 +253,8 @@ impl OpenAIPreprocessor {
...
@@ -207,9 +253,8 @@ impl OpenAIPreprocessor {
builder
.eos_token_ids
(
self
.model_info
.eos_token_ids
());
builder
.eos_token_ids
(
self
.model_info
.eos_token_ids
());
}
}
builder
.token_ids
(
encoding
.token_ids
);
builder
.sampling_options
(
request
.extract_sampling_options
()
?
);
builder
.stop_conditions
(
stop_conditions
);
builder
.stop_conditions
(
stop_conditions
);
builder
.sampling_options
(
request
.extract_sampling_options
()
?
);
builder
.annotations
(
request
.annotations
()
.unwrap_or_default
());
builder
.annotations
(
request
.annotations
()
.unwrap_or_default
());
builder
.mdc_sum
(
Some
(
self
.mdcsum
.clone
()));
builder
.mdc_sum
(
Some
(
self
.mdcsum
.clone
()));
builder
.estimated_prefix_hit_num_blocks
(
None
);
builder
.estimated_prefix_hit_num_blocks
(
None
);
...
...
lib/llm/src/preprocessor/prompt.rs
View file @
fc16a79b
...
@@ -38,6 +38,24 @@ mod template;
...
@@ -38,6 +38,24 @@ mod template;
pub
use
template
::
ContextMixins
;
pub
use
template
::
ContextMixins
;
/// Prompt input that has already been tokenized by the client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenInput {
    /// One prompt, given as a list of token ids.
    Single(Vec<u32>),
    /// Several prompts, each given as its own list of token ids.
    Batch(Vec<Vec<u32>>),
}

/// Prompt input supplied as text that still has to be tokenized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TextInput {
    /// One prompt string.
    Single(String),
    /// Several prompt strings.
    Batch(Vec<String>),
}

/// The form a request's prompt takes: pre-tokenized ids or text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptInput {
    /// The prompt was supplied as token ids.
    Tokens(TokenInput),
    /// The prompt was supplied as text.
    Text(TextInput),
}
/// Trait that defines a request that can map to an OpenAI-like request.
/// Trait that defines a request that can map to an OpenAI-like request.
pub
trait
OAIChatLikeRequest
{
pub
trait
OAIChatLikeRequest
{
fn
messages
(
&
self
)
->
Value
;
fn
messages
(
&
self
)
->
Value
;
...
@@ -49,6 +67,20 @@ pub trait OAIChatLikeRequest {
...
@@ -49,6 +67,20 @@ pub trait OAIChatLikeRequest {
}
}
fn
should_add_generation_prompt
(
&
self
)
->
bool
;
fn
should_add_generation_prompt
(
&
self
)
->
bool
;
/// Returns the kind of prompt input this request carries.
///
/// The default implementation reports a single text prompt. The `String`
/// inside the returned value is an empty placeholder, not the prompt itself.
// NOTE(review): the preprocessor appears to match only on the variant and to
// fetch the content via `extract_tokens` / `extract_text` — confirm no caller
// relies on the payload returned here.
fn prompt_input_type(&self) -> PromptInput {
    PromptInput::Text(TextInput::Single(String::new()))
}
/// Extracts the prompt's token ids if the input is pre-tokenized.
///
/// The default implementation returns `None`; implementors whose requests
/// can carry token ids override it.
fn extract_tokens(&self) -> Option<TokenInput> {
    None
}
/// Extracts the prompt's text if the input was supplied as text.
///
/// The default implementation returns `None`; implementors whose requests
/// can carry text prompts override it.
fn extract_text(&self) -> Option<TextInput> {
    None
}
}
}
pub
trait
OAIPromptFormatter
:
Send
+
Sync
+
'static
{
pub
trait
OAIPromptFormatter
:
Send
+
Sync
+
'static
{
...
...
lib/llm/src/preprocessor/prompt/template/oai.rs
View file @
fc16a79b
...
@@ -22,6 +22,8 @@ use crate::protocols::openai::{
...
@@ -22,6 +22,8 @@ use crate::protocols::openai::{
};
};
use
tracing
;
use
tracing
;
use
crate
::
preprocessor
::
prompt
::{
PromptInput
,
TextInput
,
TokenInput
};
impl
OAIChatLikeRequest
for
NvCreateChatCompletionRequest
{
impl
OAIChatLikeRequest
for
NvCreateChatCompletionRequest
{
fn
messages
(
&
self
)
->
Value
{
fn
messages
(
&
self
)
->
Value
{
Value
::
from_serialize
(
&
self
.inner.messages
)
Value
::
from_serialize
(
&
self
.inner.messages
)
...
@@ -53,6 +55,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
...
@@ -53,6 +55,10 @@ impl OAIChatLikeRequest for NvCreateChatCompletionRequest {
true
true
}
}
}
}
fn
extract_text
(
&
self
)
->
Option
<
TextInput
>
{
Some
(
TextInput
::
Single
(
String
::
new
()))
}
}
}
impl
OAIChatLikeRequest
for
NvCreateCompletionRequest
{
impl
OAIChatLikeRequest
for
NvCreateCompletionRequest
{
...
@@ -72,6 +78,48 @@ impl OAIChatLikeRequest for NvCreateCompletionRequest {
...
@@ -72,6 +78,48 @@ impl OAIChatLikeRequest for NvCreateCompletionRequest {
fn
should_add_generation_prompt
(
&
self
)
->
bool
{
fn
should_add_generation_prompt
(
&
self
)
->
bool
{
true
true
}
}
/// Classifies the completion request's `prompt` field.
///
/// Only the variant of the result is meaningful: token prompts map to
/// `PromptInput::Tokens`, string prompts to `PromptInput::Text`, each as
/// `Single` or `Batch`. The payload is always empty; use `extract_tokens`
/// or `extract_text` to obtain the actual content.
fn prompt_input_type(&self) -> PromptInput {
    match &self.inner.prompt {
        async_openai::types::Prompt::IntegerArray(_) => {
            PromptInput::Tokens(TokenInput::Single(Vec::new()))
        }
        async_openai::types::Prompt::ArrayOfIntegerArray(_) => {
            PromptInput::Tokens(TokenInput::Batch(Vec::new()))
        }
        async_openai::types::Prompt::String(_) => {
            PromptInput::Text(TextInput::Single(String::new()))
        }
        async_openai::types::Prompt::StringArray(_) => {
            PromptInput::Text(TextInput::Batch(Vec::new()))
        }
    }
}
/// Returns the prompt's token ids when the request was sent pre-tokenized.
///
/// An integer-array prompt yields `TokenInput::Single`; an array of integer
/// arrays yields `TokenInput::Batch`. Each id is converted with `as u32`.
/// String prompts yield `None`.
fn extract_tokens(&self) -> Option<TokenInput> {
    let token_input = match &self.inner.prompt {
        async_openai::types::Prompt::IntegerArray(ids) => {
            TokenInput::Single(ids.iter().map(|&id| id as u32).collect())
        }
        async_openai::types::Prompt::ArrayOfIntegerArray(batches) => TokenInput::Batch(
            batches
                .iter()
                .map(|ids| ids.iter().map(|&id| id as u32).collect())
                .collect(),
        ),
        // Text prompts carry no token ids.
        async_openai::types::Prompt::String(_)
        | async_openai::types::Prompt::StringArray(_) => return None,
    };
    Some(token_input)
}
/// Returns the prompt's text when the request was sent as text.
///
/// A string prompt yields `TextInput::Single`; a string-array prompt yields
/// `TextInput::Batch`. The text is copied out of the request. Token-id
/// prompts yield `None`.
fn extract_text(&self) -> Option<TextInput> {
    let text_input = match &self.inner.prompt {
        async_openai::types::Prompt::String(prompt) => TextInput::Single(prompt.to_string()),
        async_openai::types::Prompt::StringArray(prompts) => TextInput::Batch(prompts.to_vec()),
        // Pre-tokenized prompts carry no text.
        async_openai::types::Prompt::IntegerArray(_)
        | async_openai::types::Prompt::ArrayOfIntegerArray(_) => return None,
    };
    Some(text_input)
}
}
}
impl
OAIPromptFormatter
for
HfTokenizerConfigJsonFormatter
{
impl
OAIPromptFormatter
for
HfTokenizerConfigJsonFormatter
{
...
...
lib/llm/src/protocols/common/llm_backend.rs
View file @
fc16a79b
...
@@ -46,6 +46,9 @@ pub struct BackendOutput {
...
@@ -46,6 +46,9 @@ pub struct BackendOutput {
pub
finish_reason
:
Option
<
FinishReason
>
,
pub
finish_reason
:
Option
<
FinishReason
>
,
// Model Deployment Card checksum
// Model Deployment Card checksum
//pub mdcsum: String,
//pub mdcsum: String,
// Index field for batch requests to match OpenAI format
pub
index
:
Option
<
u32
>
,
}
}
/// The LLM engine and backend will manage their own state, specifically translating how a
/// The LLM engine and backend will manage their own state, specifically translating how a
...
@@ -77,6 +80,9 @@ pub struct LLMEngineOutput {
...
@@ -77,6 +80,9 @@ pub struct LLMEngineOutput {
// TODO: Enrich this with more information as can apply our first-level postprocessing
// TODO: Enrich this with more information as can apply our first-level postprocessing
// logic and return more detailed information
// logic and return more detailed information
pub
finish_reason
:
Option
<
FinishReason
>
,
pub
finish_reason
:
Option
<
FinishReason
>
,
// Index field for batch requests to match OpenAI format
pub
index
:
Option
<
u32
>
,
}
}
impl
LLMEngineOutput
{
impl
LLMEngineOutput
{
...
@@ -88,6 +94,7 @@ impl LLMEngineOutput {
...
@@ -88,6 +94,7 @@ impl LLMEngineOutput {
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Cancelled
),
finish_reason
:
Some
(
FinishReason
::
Cancelled
),
index
:
None
,
}
}
}
}
...
@@ -99,6 +106,7 @@ impl LLMEngineOutput {
...
@@ -99,6 +106,7 @@ impl LLMEngineOutput {
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Stop
),
finish_reason
:
Some
(
FinishReason
::
Stop
),
index
:
None
,
}
}
}
}
...
@@ -110,6 +118,7 @@ impl LLMEngineOutput {
...
@@ -110,6 +118,7 @@ impl LLMEngineOutput {
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Length
),
finish_reason
:
Some
(
FinishReason
::
Length
),
index
:
None
,
}
}
}
}
...
@@ -121,6 +130,7 @@ impl LLMEngineOutput {
...
@@ -121,6 +130,7 @@ impl LLMEngineOutput {
cum_log_probs
:
None
,
cum_log_probs
:
None
,
log_probs
:
None
,
log_probs
:
None
,
finish_reason
:
Some
(
FinishReason
::
Error
(
err_msg
)),
finish_reason
:
Some
(
FinishReason
::
Error
(
err_msg
)),
index
:
None
,
}
}
}
}
}
}
lib/llm/src/protocols/common/preprocessor.rs
View file @
fc16a79b
...
@@ -26,6 +26,10 @@ pub struct PreprocessedRequest {
...
@@ -26,6 +26,10 @@ pub struct PreprocessedRequest {
/// Type of prompt
/// Type of prompt
pub
token_ids
:
Vec
<
TokenIdType
>
,
pub
token_ids
:
Vec
<
TokenIdType
>
,
/// Batch token IDs for batch completion requests (i.e., prompts using the ArrayOfIntegerArray type from OpenAI /completions)
#[builder(default)]
pub
batch_token_ids
:
Option
<
Vec
<
Vec
<
TokenIdType
>>>
,
/// StopConditions are conditions that the inference engine will use to stop generation.
/// StopConditions are conditions that the inference engine will use to stop generation.
pub
stop_conditions
:
StopConditions
,
pub
stop_conditions
:
StopConditions
,
...
...
lib/llm/src/protocols/openai/completions/delta.rs
View file @
fc16a79b
...
@@ -131,8 +131,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe
...
@@ -131,8 +131,9 @@ impl crate::protocols::openai::DeltaGeneratorExt<CompletionResponse> for DeltaGe
};
};
// create choice
// create choice
let
index
=
0
;
let
index
=
delta
.index
.unwrap_or
(
0
)
.into
();
Ok
(
self
.create_choice
(
index
,
delta
.text
,
finish_reason
))
let
response
=
self
.create_choice
(
index
,
delta
.text
.clone
(),
finish_reason
);
Ok
(
response
)
}
}
fn
get_isl
(
&
self
)
->
Option
<
u32
>
{
fn
get_isl
(
&
self
)
->
Option
<
u32
>
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment