Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
6796d38c
Unverified
Commit
6796d38c
authored
Feb 17, 2023
by
OlivierDehaene
Committed by
GitHub
Feb 17, 2023
Browse files
feat(router): add cors allow origin options (#73)
parent
c720555a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
41 additions
and
3 deletions
+41
-3
Cargo.lock
Cargo.lock
+1
-0
launcher/src/main.rs
launcher/src/main.rs
+9
-0
router/Cargo.toml
router/Cargo.toml
+1
-0
router/src/main.rs
router/src/main.rs
+17
-0
router/src/server.rs
router/src/server.rs
+13
-3
No files found.
Cargo.lock
View file @
6796d38c
...
...
@@ -2275,6 +2275,7 @@ dependencies = [
"tokenizers",
"tokio",
"tokio-stream",
"tower-http",
"tracing",
"tracing-opentelemetry",
"tracing-subscriber",
...
...
launcher/src/main.rs
View file @
6796d38c
...
...
@@ -53,6 +53,8 @@ struct Args {
json_output
:
bool
,
#[clap(long,
env)]
otlp_endpoint
:
Option
<
String
>
,
#[clap(long,
env)]
cors_allow_origin
:
Vec
<
String
>
,
}
fn
main
()
->
ExitCode
{
...
...
@@ -85,6 +87,7 @@ fn main() -> ExitCode {
disable_custom_kernels
,
json_output
,
otlp_endpoint
,
cors_allow_origin
,
}
=
args
;
// Signal handler
...
...
@@ -320,6 +323,12 @@ fn main() -> ExitCode {
argv
.push
(
otlp_endpoint
);
}
// CORS origins
for
origin
in
cors_allow_origin
.into_iter
()
{
argv
.push
(
"--cors-allow-origin"
.to_string
());
argv
.push
(
origin
);
}
let
mut
webserver
=
match
Popen
::
create
(
&
argv
,
PopenConfig
{
...
...
router/Cargo.toml
View file @
6796d38c
...
...
@@ -32,6 +32,7 @@ thiserror = "1.0.38"
tokenizers
=
"0.13.2"
tokio
=
{
version
=
"1.25.0"
,
features
=
[
"rt"
,
"rt-multi-thread"
,
"parking_lot"
,
"signal"
,
"sync"
]
}
tokio-stream
=
"0.1.11"
tower-http
=
{
version
=
"0.3.5"
,
features
=
["cors"]
}
tracing
=
"0.1.37"
tracing-opentelemetry
=
"0.18.0"
tracing-subscriber
=
{
version
=
"0.3.16"
,
features
=
[
"json"
,
"env-filter"
]
}
...
...
router/src/main.rs
View file @
6796d38c
/// Text Generation Inference webserver entrypoint
use
axum
::
http
::
HeaderValue
;
use
clap
::
Parser
;
use
opentelemetry
::
sdk
::
propagation
::
TraceContextPropagator
;
use
opentelemetry
::
sdk
::
trace
;
...
...
@@ -10,6 +11,7 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use
text_generation_client
::
ShardedClient
;
use
text_generation_router
::
server
;
use
tokenizers
::
Tokenizer
;
use
tower_http
::
cors
::
AllowOrigin
;
use
tracing_subscriber
::
layer
::
SubscriberExt
;
use
tracing_subscriber
::
util
::
SubscriberInitExt
;
use
tracing_subscriber
::{
EnvFilter
,
Layer
};
...
...
@@ -42,6 +44,8 @@ struct Args {
json_output
:
bool
,
#[clap(long,
env)]
otlp_endpoint
:
Option
<
String
>
,
#[clap(long,
env)]
cors_allow_origin
:
Option
<
Vec
<
String
>>
,
}
fn
main
()
->
Result
<
(),
std
::
io
::
Error
>
{
...
...
@@ -61,12 +65,24 @@ fn main() -> Result<(), std::io::Error> {
validation_workers
,
json_output
,
otlp_endpoint
,
cors_allow_origin
,
}
=
args
;
if
validation_workers
==
0
{
panic!
(
"validation_workers must be > 0"
);
}
// CORS allowed origins
// map to go inside the option and then map to parse from String to HeaderValue
// Finally, convert to AllowOrigin
let
cors_allow_origin
:
Option
<
AllowOrigin
>
=
cors_allow_origin
.map
(|
cors_allow_origin
|
{
AllowOrigin
::
list
(
cors_allow_origin
.iter
()
.map
(|
origin
|
origin
.parse
::
<
HeaderValue
>
()
.unwrap
()),
)
});
// Download and instantiate tokenizer
// This will only be used to validate payloads
//
...
...
@@ -107,6 +123,7 @@ fn main() -> Result<(), std::io::Error> {
tokenizer
,
validation_workers
,
addr
,
cors_allow_origin
,
)
.await
;
Ok
(())
...
...
router/src/server.rs
View file @
6796d38c
...
...
@@ -5,11 +5,11 @@ use crate::{
Infer
,
StreamDetails
,
StreamResponse
,
Token
,
Validation
,
};
use
axum
::
extract
::
Extension
;
use
axum
::
http
::{
HeaderMap
,
StatusCode
};
use
axum
::
http
::{
HeaderMap
,
Method
,
StatusCode
};
use
axum
::
response
::
sse
::{
Event
,
KeepAlive
,
Sse
};
use
axum
::
response
::
IntoResponse
;
use
axum
::
routing
::{
get
,
post
};
use
axum
::{
Json
,
Router
};
use
axum
::{
http
,
Json
,
Router
};
use
axum_tracing_opentelemetry
::
opentelemetry_tracing_layer
;
use
futures
::
Stream
;
use
metrics_exporter_prometheus
::{
PrometheusBuilder
,
PrometheusHandle
};
...
...
@@ -20,6 +20,7 @@ use tokenizers::Tokenizer;
use
tokio
::
signal
;
use
tokio
::
time
::
Instant
;
use
tokio_stream
::
StreamExt
;
use
tower_http
::
cors
::{
AllowOrigin
,
CorsLayer
};
use
tracing
::{
info_span
,
instrument
,
Instrument
};
use
utoipa
::
OpenApi
;
use
utoipa_swagger_ui
::
SwaggerUi
;
...
...
@@ -334,6 +335,7 @@ pub async fn run(
tokenizer
:
Tokenizer
,
validation_workers
:
usize
,
addr
:
SocketAddr
,
allow_origin
:
Option
<
AllowOrigin
>
,
)
{
// OpenAPI documentation
#[derive(OpenApi)]
...
...
@@ -391,6 +393,13 @@ pub async fn run(
.install_recorder
()
.expect
(
"failed to install metrics recorder"
);
// CORS layer
let
allow_origin
=
allow_origin
.unwrap_or
(
AllowOrigin
::
any
());
let
cors_layer
=
CorsLayer
::
new
()
.allow_methods
([
Method
::
GET
,
Method
::
POST
])
.allow_headers
([
http
::
header
::
CONTENT_TYPE
])
.allow_origin
(
allow_origin
);
// Create router
let
app
=
Router
::
new
()
.merge
(
SwaggerUi
::
new
(
"/docs"
)
.url
(
"/api-doc/openapi.json"
,
ApiDoc
::
openapi
()))
...
...
@@ -402,7 +411,8 @@ pub async fn run(
.layer
(
Extension
(
infer
))
.route
(
"/metrics"
,
get
(
metrics
))
.layer
(
Extension
(
prom_handle
))
.layer
(
opentelemetry_tracing_layer
());
.layer
(
opentelemetry_tracing_layer
())
.layer
(
cors_layer
);
// Run server
axum
::
Server
::
bind
(
&
addr
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment