Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
f59fb8b6
Unverified
Commit
f59fb8b6
authored
Jun 16, 2023
by
OlivierDehaene
Committed by
GitHub
Jun 16, 2023
Browse files
feat(router): add ngrok integration (#453)
parent
5ce89059
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
564 additions
and
224 deletions
+564
-224
Cargo.lock
Cargo.lock
+431
-212
launcher/src/main.rs
launcher/src/main.rs
+44
-0
router/Cargo.toml
router/Cargo.toml
+5
-0
router/src/main.rs
router/src/main.rs
+20
-0
router/src/queue.rs
router/src/queue.rs
+3
-3
router/src/server.rs
router/src/server.rs
+61
-8
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
...erver/models/custom_modeling/flash_santacoder_modeling.py
+0
-1
No files found.
Cargo.lock
View file @
f59fb8b6
This diff is collapsed.
Click to expand it.
launcher/src/main.rs
View file @
f59fb8b6
...
@@ -229,6 +229,26 @@ struct Args {
...
@@ -229,6 +229,26 @@ struct Args {
#[clap(long,
env)]
#[clap(long,
env)]
watermark_delta
:
Option
<
f32
>
,
watermark_delta
:
Option
<
f32
>
,
/// Enable ngrok tunneling
#[clap(long,
env)]
ngrok
:
bool
,
/// ngrok authentication token
#[clap(long,
env)]
ngrok_authtoken
:
Option
<
String
>
,
/// ngrok domain name where the axum webserver will be available at
#[clap(long,
env)]
ngrok_domain
:
Option
<
String
>
,
/// ngrok basic auth username
#[clap(long,
env)]
ngrok_username
:
Option
<
String
>
,
/// ngrok basic auth password
#[clap(long,
env)]
ngrok_password
:
Option
<
String
>
,
/// Display a lot of information about your runtime environment
/// Display a lot of information about your runtime environment
#[clap(long,
short,
action)]
#[clap(long,
short,
action)]
env
:
bool
,
env
:
bool
,
...
@@ -845,6 +865,30 @@ fn spawn_webserver(
...
@@ -845,6 +865,30 @@ fn spawn_webserver(
argv
.push
(
origin
);
argv
.push
(
origin
);
}
}
// Ngrok
if
args
.ngrok
{
let
authtoken
=
args
.ngrok_authtoken
.ok_or_else
(||
{
tracing
::
error!
(
"`ngrok-authtoken` must be set when using ngrok tunneling"
);
LauncherError
::
WebserverCannotStart
})
?
;
argv
.push
(
"--ngrok"
.to_string
());
argv
.push
(
"--ngrok-authtoken"
.to_string
());
argv
.push
(
authtoken
);
if
let
Some
(
domain
)
=
args
.ngrok_domain
{
argv
.push
(
"--ngrok-domain"
.to_string
());
argv
.push
(
domain
);
}
if
let
(
Some
(
username
),
Some
(
password
))
=
(
args
.ngrok_username
,
args
.ngrok_password
)
{
argv
.push
(
"--ngrok-username"
.to_string
());
argv
.push
(
username
);
argv
.push
(
"--ngrok-password"
.to_string
());
argv
.push
(
password
);
}
}
// Copy current process env
// Copy current process env
let
mut
env
:
Vec
<
(
OsString
,
OsString
)
>
=
env
::
vars_os
()
.collect
();
let
mut
env
:
Vec
<
(
OsString
,
OsString
)
>
=
env
::
vars_os
()
.collect
();
...
...
router/Cargo.toml
View file @
f59fb8b6
...
@@ -40,6 +40,11 @@ tracing-opentelemetry = "0.18.0"
...
@@ -40,6 +40,11 @@ tracing-opentelemetry = "0.18.0"
tracing-subscriber
=
{
version
=
"0.3.16"
,
features
=
[
"json"
,
"env-filter"
]
}
tracing-subscriber
=
{
version
=
"0.3.16"
,
features
=
[
"json"
,
"env-filter"
]
}
utoipa
=
{
version
=
"3.0.1"
,
features
=
["axum_extras"]
}
utoipa
=
{
version
=
"3.0.1"
,
features
=
["axum_extras"]
}
utoipa-swagger-ui
=
{
version
=
"3.0.2"
,
features
=
["axum"]
}
utoipa-swagger-ui
=
{
version
=
"3.0.2"
,
features
=
["axum"]
}
ngrok
=
{
version
=
"0.12.3"
,
features
=
["axum"]
,
optional
=
true
}
[build-dependencies]
[build-dependencies]
vergen
=
{
version
=
"8.0.0"
,
features
=
[
"build"
,
"git"
,
"gitcl"
]
}
vergen
=
{
version
=
"8.0.0"
,
features
=
[
"build"
,
"git"
,
"gitcl"
]
}
[features]
default
=
["ngrok"]
ngrok
=
["dep:ngrok"]
\ No newline at end of file
router/src/main.rs
View file @
f59fb8b6
...
@@ -56,6 +56,16 @@ struct Args {
...
@@ -56,6 +56,16 @@ struct Args {
otlp_endpoint
:
Option
<
String
>
,
otlp_endpoint
:
Option
<
String
>
,
#[clap(long,
env)]
#[clap(long,
env)]
cors_allow_origin
:
Option
<
Vec
<
String
>>
,
cors_allow_origin
:
Option
<
Vec
<
String
>>
,
#[clap(long,
env)]
ngrok
:
bool
,
#[clap(long,
env)]
ngrok_authtoken
:
Option
<
String
>
,
#[clap(long,
env)]
ngrok_domain
:
Option
<
String
>
,
#[clap(long,
env)]
ngrok_username
:
Option
<
String
>
,
#[clap(long,
env)]
ngrok_password
:
Option
<
String
>
,
}
}
fn
main
()
->
Result
<
(),
std
::
io
::
Error
>
{
fn
main
()
->
Result
<
(),
std
::
io
::
Error
>
{
...
@@ -80,6 +90,11 @@ fn main() -> Result<(), std::io::Error> {
...
@@ -80,6 +90,11 @@ fn main() -> Result<(), std::io::Error> {
json_output
,
json_output
,
otlp_endpoint
,
otlp_endpoint
,
cors_allow_origin
,
cors_allow_origin
,
ngrok
,
ngrok_authtoken
,
ngrok_domain
,
ngrok_username
,
ngrok_password
,
}
=
args
;
}
=
args
;
if
validation_workers
==
0
{
if
validation_workers
==
0
{
...
@@ -198,6 +213,11 @@ fn main() -> Result<(), std::io::Error> {
...
@@ -198,6 +213,11 @@ fn main() -> Result<(), std::io::Error> {
validation_workers
,
validation_workers
,
addr
,
addr
,
cors_allow_origin
,
cors_allow_origin
,
ngrok
,
ngrok_authtoken
,
ngrok_domain
,
ngrok_username
,
ngrok_password
,
)
)
.await
;
.await
;
Ok
(())
Ok
(())
...
...
router/src/queue.rs
View file @
f59fb8b6
...
@@ -49,7 +49,7 @@ impl Queue {
...
@@ -49,7 +49,7 @@ impl Queue {
// Send append command to the background task managing the state
// Send append command to the background task managing the state
// Unwrap is safe here
// Unwrap is safe here
self
.queue_sender
self
.queue_sender
.send
(
QueueCommand
::
Append
(
entry
,
Span
::
current
()))
.send
(
QueueCommand
::
Append
(
Box
::
new
(
entry
)
,
Span
::
current
()))
.unwrap
();
.unwrap
();
}
}
...
@@ -85,7 +85,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
...
@@ -85,7 +85,7 @@ async fn queue_task(requires_padding: bool, receiver: flume::Receiver<QueueComma
while
let
Ok
(
cmd
)
=
receiver
.recv_async
()
.await
{
while
let
Ok
(
cmd
)
=
receiver
.recv_async
()
.await
{
match
cmd
{
match
cmd
{
QueueCommand
::
Append
(
entry
,
span
)
=>
{
QueueCommand
::
Append
(
entry
,
span
)
=>
{
span
.in_scope
(||
state
.append
(
entry
));
span
.in_scope
(||
state
.append
(
*
entry
));
metrics
::
increment_gauge!
(
"tgi_queue_size"
,
1.0
);
metrics
::
increment_gauge!
(
"tgi_queue_size"
,
1.0
);
}
}
QueueCommand
::
NextBatch
{
QueueCommand
::
NextBatch
{
...
@@ -256,7 +256,7 @@ type NextBatch = (IntMap<u64, Entry>, Batch, Span);
...
@@ -256,7 +256,7 @@ type NextBatch = (IntMap<u64, Entry>, Batch, Span);
#[derive(Debug)]
#[derive(Debug)]
enum
QueueCommand
{
enum
QueueCommand
{
Append
(
Entry
,
Span
),
Append
(
Box
<
Entry
>
,
Span
),
NextBatch
{
NextBatch
{
min_size
:
Option
<
usize
>
,
min_size
:
Option
<
usize
>
,
token_budget
:
u32
,
token_budget
:
u32
,
...
...
router/src/server.rs
View file @
f59fb8b6
use
crate
::
health
::
Health
;
/// HTTP Server logic
/// HTTP Server logic
use
crate
::
health
::
Health
;
use
crate
::
infer
::{
InferError
,
InferResponse
,
InferStreamResponse
};
use
crate
::
infer
::{
InferError
,
InferResponse
,
InferStreamResponse
};
use
crate
::
validation
::
ValidationError
;
use
crate
::
validation
::
ValidationError
;
use
crate
::{
use
crate
::{
...
@@ -520,6 +520,11 @@ pub async fn run(
...
@@ -520,6 +520,11 @@ pub async fn run(
validation_workers
:
usize
,
validation_workers
:
usize
,
addr
:
SocketAddr
,
addr
:
SocketAddr
,
allow_origin
:
Option
<
AllowOrigin
>
,
allow_origin
:
Option
<
AllowOrigin
>
,
ngrok
:
bool
,
ngrok_authtoken
:
Option
<
String
>
,
ngrok_domain
:
Option
<
String
>
,
ngrok_username
:
Option
<
String
>
,
ngrok_password
:
Option
<
String
>
,
)
{
)
{
// OpenAPI documentation
// OpenAPI documentation
#[derive(OpenApi)]
#[derive(OpenApi)]
...
@@ -683,13 +688,61 @@ pub async fn run(
...
@@ -683,13 +688,61 @@ pub async fn run(
.layer
(
opentelemetry_tracing_layer
())
.layer
(
opentelemetry_tracing_layer
())
.layer
(
cors_layer
);
.layer
(
cors_layer
);
// Run server
if
ngrok
{
axum
::
Server
::
bind
(
&
addr
)
#[cfg(feature
=
"ngrok"
)]
.serve
(
app
.into_make_service
())
{
// Wait until all requests are finished to shut down
use
ngrok
::
config
::
TunnelBuilder
;
.with_graceful_shutdown
(
shutdown_signal
())
use
ngrok
::
tunnel
::
UrlTunnel
;
.await
.unwrap
();
let
_
=
addr
;
let
authtoken
=
ngrok_authtoken
.expect
(
"`ngrok-authtoken` must be set when using ngrok tunneling"
);
let
mut
tunnel
=
ngrok
::
Session
::
builder
()
.authtoken
(
authtoken
)
.connect
()
.await
.unwrap
()
.http_endpoint
();
if
let
Some
(
domain
)
=
ngrok_domain
{
tunnel
=
tunnel
.domain
(
domain
);
}
if
let
(
Some
(
username
),
Some
(
password
))
=
(
ngrok_username
,
ngrok_password
)
{
tunnel
=
tunnel
.basic_auth
(
username
,
password
);
}
let
listener
=
tunnel
.listen
()
.await
.unwrap
();
// Run server
tracing
::
info!
(
"Ingress URL: {:?}"
,
listener
.url
());
axum
::
Server
::
builder
(
listener
)
.serve
(
app
.into_make_service
())
//Wait until all requests are finished to shut down
.with_graceful_shutdown
(
shutdown_signal
())
.await
.unwrap
();
}
#[cfg(not(feature
=
"ngrok"
))]
{
let
_
ngrok_authtoken
=
ngrok_authtoken
;
let
_
ngrok_domain
=
ngrok_domain
;
let
_
ngrok_username
=
ngrok_username
;
let
_
ngrok_password
=
ngrok_password
;
panic!
(
"`text-generation-router` was compiled without the `ngrok` feature"
);
}
}
else
{
// Run server
axum
::
Server
::
bind
(
&
addr
)
.serve
(
app
.into_make_service
())
// Wait until all requests are finished to shut down
.with_graceful_shutdown
(
shutdown_signal
())
.await
.unwrap
();
}
}
}
/// Shutdown signal handler
/// Shutdown signal handler
...
...
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
View file @
f59fb8b6
...
@@ -47,7 +47,6 @@ def load_multi_mqa(
...
@@ -47,7 +47,6 @@ def load_multi_mqa(
shape
=
slice_
.
get_shape
()
shape
=
slice_
.
get_shape
()
block_size
=
(
shape
[
0
]
-
2
*
head_size
)
//
world_size
block_size
=
(
shape
[
0
]
-
2
*
head_size
)
//
world_size
assert
(
shape
[
0
]
-
2
*
head_size
)
%
world_size
==
0
assert
(
shape
[
0
]
-
2
*
head_size
)
%
world_size
==
0
q_tensor
=
slice_
[
start
:
stop
]
start
=
rank
*
block_size
start
=
rank
*
block_size
stop
=
(
rank
+
1
)
*
block_size
stop
=
(
rank
+
1
)
*
block_size
q_tensor
=
slice_
[
start
:
stop
]
q_tensor
=
slice_
[
start
:
stop
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment