OpenDAS / text-generation-inference · Commits

Commit ab7ccf5b (Unverified)
Authored Nov 21, 2024 by OlivierDehaene; committed by GitHub on Nov 21, 2024

feat: add payload limit (#2726)

* feat: add payload limit
* update launcher
parent d5bc6a20
Showing 8 changed files with 45 additions and 9 deletions:
- backends/trtllm/src/main.rs (+5 -0)
- backends/v2/src/main.rs (+4 -0)
- backends/v3/src/main.rs (+4 -0)
- docs/source/reference/launcher.md (+11 -0)
- launcher/src/main.rs (+8 -0)
- router/src/server.rs (+5 -1)
- server/text_generation_server/models/flash_causal_lm.py (+5 -6)
- server/text_generation_server/models/metadata_kernels.py (+3 -2)
backends/trtllm/src/main.rs

```diff
@@ -62,6 +62,8 @@ struct Args {
     executor_worker: PathBuf,
     #[clap(default_value = "on", long, env)]
     usage_stats: usage_stats::UsageStatsLevel,
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
 }
 
 async fn get_tokenizer(
@@ -217,6 +219,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         auth_token,
         executor_worker,
         usage_stats,
+        payload_limit,
     } = args;
 
     // Launch Tokio runtime
@@ -287,6 +290,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         tokenizer_name,
         tokenizer_config_path,
         revision,
         false,
         hostname,
         port,
         cors_allow_origin,
@@ -296,6 +300,7 @@ async fn main() -> Result<(), TensorRtLlmBackendError> {
         true,
         max_client_batch_size,
         usage_stats,
+        payload_limit,
     )
     .await?;
     Ok(())
```
backends/v2/src/main.rs

```diff
@@ -70,6 +70,8 @@ struct Args {
     max_client_batch_size: usize,
     #[clap(default_value = "on", long, env)]
     usage_stats: usage_stats::UsageStatsLevel,
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
 }
 
 #[derive(Debug, Subcommand)]
@@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
+        payload_limit,
     } = args;
 
     if let Some(Commands::PrintSchema) = command {
@@ -194,6 +197,7 @@ async fn main() -> Result<(), RouterError> {
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
+        payload_limit,
     )
     .await?;
     Ok(())
```
backends/v3/src/main.rs

```diff
@@ -70,6 +70,8 @@ struct Args {
     max_client_batch_size: usize,
     #[clap(default_value = "on", long, env)]
     usage_stats: usage_stats::UsageStatsLevel,
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
 }
 
 #[derive(Debug, Subcommand)]
@@ -114,6 +116,7 @@ async fn main() -> Result<(), RouterError> {
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
+        payload_limit,
     } = args;
 
     if let Some(Commands::PrintSchema) = command {
@@ -210,6 +213,7 @@ async fn main() -> Result<(), RouterError> {
         disable_grammar_support,
         max_client_batch_size,
         usage_stats,
+        payload_limit,
     )
     .await?;
     Ok(())
```
docs/source/reference/launcher.md

````diff
@@ -456,6 +456,17 @@ Options:
           - off: Disables all collection of usage statistics
           - no-stack: Doesn't send the error stack trace or error type, but allows sending a crash event
 
 ```
+## PAYLOAD_LIMIT
+```shell
+      --payload-limit <PAYLOAD_LIMIT>
+          Payload size limit in bytes
+
+          Default is 2MB
+
+          [env: PAYLOAD_LIMIT=]
+          [default: 2000000]
+
+```
 ## HELP
 ```shell
````
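The new `--payload-limit` flag (or the `PAYLOAD_LIMIT` environment variable) caps the size of a single request body accepted by the router, defaulting to 2,000,000 bytes. As a client-side illustration, here is a minimal sketch assuming a local instance reachable at a hypothetical `127.0.0.1:3000` (adjust the address to your `--hostname`/`--port`) and the `requests` package being installed: a body above the limit should be refused before any generation work starts, while ordinary requests are unaffected.

```python
import requests  # third-party HTTP client, assumed to be installed

# Hypothetical local text-generation-inference endpoint.
URL = "http://127.0.0.1:3000/generate"

# A normal prompt is far below the default 2 MB limit and should go through.
small = {"inputs": "What is Deep Learning?", "parameters": {"max_new_tokens": 20}}
print(requests.post(URL, json=small).status_code)  # expected: 200

# A body larger than --payload-limit is rejected by the router's body-size
# layer before generation starts (typically 413 Payload Too Large).
huge = {"inputs": "x" * 3_000_000, "parameters": {"max_new_tokens": 20}}
print(requests.post(URL, json=huge).status_code)  # expected: 413
```

Raising the limit, for example with `--payload-limit 4000000` or `PAYLOAD_LIMIT=4000000`, lets correspondingly larger bodies through.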
launcher/src/main.rs

```diff
@@ -692,6 +692,12 @@ struct Args {
     /// Defaul is on.
     #[clap(default_value = "on", long, env)]
     usage_stats: UsageStatsLevel,
+
+    /// Payload size limit in bytes
+    ///
+    /// Default is 2MB
+    #[clap(default_value = "2000000", long, env)]
+    payload_limit: usize,
 }
 
 #[derive(Debug)]
@@ -1479,6 +1485,8 @@ fn spawn_webserver(
         format!("{}-0", args.shard_uds_path),
         "--tokenizer-name".to_string(),
         args.model_id,
+        "--payload-limit".to_string(),
+        args.payload_limit.to_string(),
     ];
     if let Some(max_input_tokens) = max_input_tokens {
         router_args.extend_from_slice(&[
```
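The launcher change simply forwards the parsed value onto the command line of the router process it spawns from `spawn_webserver`. A minimal Python sketch of the same pattern (a hypothetical mini-launcher; the real launcher does this in Rust with `std::process::Command`, and the binary name below is only illustrative):

```python
import argparse
import shutil
import subprocess

# Hypothetical mini-launcher: parse a flag, then re-emit it on the router's argv.
parser = argparse.ArgumentParser()
parser.add_argument("--payload-limit", type=int, default=2_000_000)
args = parser.parse_args()

router_args = ["text-generation-router", "--payload-limit", str(args.payload_limit)]

if shutil.which(router_args[0]):
    # Forward the flag to the spawned router, as the launcher does.
    subprocess.run(router_args + ["--help"], check=False)
else:
    print("router binary not found; would run:", " ".join(router_args))
```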
router/src/server.rs

```diff
@@ -30,7 +30,7 @@ use crate::{
 use crate::{FunctionDefinition, HubPreprocessorConfig, ToolCall, ToolChoice};
 use crate::{ModelInfo, ModelsInfo};
 use async_stream::__private::AsyncStream;
-use axum::extract::Extension;
+use axum::extract::{DefaultBodyLimit, Extension};
 use axum::http::{HeaderMap, HeaderValue, Method, StatusCode};
 use axum::response::sse::{Event, KeepAlive, Sse};
 use axum::response::{IntoResponse, Response};
@@ -1674,6 +1674,7 @@ pub async fn run(
     disable_grammar_support: bool,
     max_client_batch_size: usize,
     usage_stats_level: usage_stats::UsageStatsLevel,
+    payload_limit: usize,
 ) -> Result<(), WebServerError> {
     // CORS allowed origins
     // map to go inside the option and then map to parse from String to HeaderValue
@@ -1928,6 +1929,7 @@ pub async fn run(
         model_info,
         compat_return_full_text,
         allow_origin,
+        payload_limit,
     )
     .await;
@@ -1987,6 +1989,7 @@ async fn start(
     model_info: HubModelInfo,
     compat_return_full_text: bool,
     allow_origin: Option<AllowOrigin>,
+    payload_limit: usize,
 ) -> Result<(), WebServerError> {
     // Determine the server port based on the feature and environment variable.
     let port = if cfg!(feature = "google") {
@@ -2384,6 +2387,7 @@ async fn start(
         .layer(Extension(compute_type))
         .layer(Extension(prom_handle.clone()))
         .layer(OtelAxumLayer::default())
+        .layer(DefaultBodyLimit::max(payload_limit))
         .layer(cors_layer);
 
     tracing::info!("Connected");
```
server/text_generation_server/models/flash_causal_lm.py

```diff
@@ -962,9 +962,9 @@ class FlashCausalLMBatch(Batch):
         self.input_lengths_tensor = torch.tensor(
             self.input_lengths, dtype=torch.int32, device=device
         )
-        self.cu_seqlen_prefill = torch.nn.functional.pad(
-            torch.cumsum(self.input_lengths_tensor, dim=0), (1, 0)
-        ).to(torch.int32)
+        cu_seqlen_prefill = self.input_lengths_tensor.new_zeros(len(self) + 1)
+        torch.cumsum(self.input_lengths_tensor, out=cu_seqlen_prefill[1:], dim=0)
+        self.cu_seqlen_prefill = cu_seqlen_prefill.to(torch.int32)
         self.cache_lengths_tensor = torch.tensor(
             self.cache_lengths, dtype=torch.int32, device=device
         )
@@ -2020,9 +2020,8 @@ class FlashCausalLM(Model):
             # For each member of the batch
             # Cumulative length
-            cu_accepted_ids = torch.nn.functional.pad(
-                torch.cumsum(accepted_ids, dim=0), (1, 0)
-            )
+            cu_accepted_ids = accepted_ids.new_zeros(accepted_ids.shape[0] + 1)
+            torch.cumsum(accepted_ids, dim=0, out=cu_accepted_ids[1:])
             cumulative_length = 0
             for i, (
                 request,
```
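Here and in metadata_kernels.py below, the `torch.nn.functional.pad(torch.cumsum(...), (1, 0))` pattern is replaced by preallocating an `n + 1` zero tensor and writing the cumulative sum into its tail slice via `out=`. Both forms produce the same exclusive prefix sum (a leading 0 followed by the running totals); the new form presumably avoids materialising an intermediate cumsum tensor and copying it into a padded one. A small standalone sketch with made-up lengths:

```python
import torch

# Hypothetical per-request input lengths for a batch of four sequences.
input_lengths = torch.tensor([5, 3, 7, 2])

# Old formulation: cumsum first, then pad a zero in front (extra allocation and copy).
old = torch.nn.functional.pad(torch.cumsum(input_lengths, dim=0), (1, 0))

# New formulation: preallocate n + 1 zeros and write the cumsum into the tail slice.
new = input_lengths.new_zeros(input_lengths.shape[0] + 1)
torch.cumsum(input_lengths, dim=0, out=new[1:])

print(old.tolist())  # [0, 5, 8, 15, 17]
print(new.tolist())  # [0, 5, 8, 15, 17]
assert torch.equal(old, new)
```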
server/text_generation_server/models/metadata_kernels.py

```diff
@@ -66,8 +66,9 @@ def block_tables_to_ragged(
     )
 
     if has_triton():
-        cu_seqlen = torch.nn.functional.pad(
-            torch.cumsum(input_lengths_tensor + cache_lengths_tensor, dim=0), (1, 0)
+        cu_seqlen = input_lengths_tensor.new_zeros(input_lengths_tensor.shape[0] + 1)
+        torch.cumsum(
+            input_lengths_tensor + cache_lengths_tensor, out=cu_seqlen[1:], dim=0
         )
 
         def grid(meta):
```