Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
252f42c1
Unverified
Commit
252f42c1
authored
Apr 19, 2023
by
OlivierDehaene
Committed by
GitHub
Apr 19, 2023
Browse files
fix(router): add auth token to get model info (#207)
parent
6837b2eb
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
9 deletions
+29
-9
launcher/src/main.rs
launcher/src/main.rs
+9
-0
router/src/main.rs
router/src/main.rs
+20
-9
No files found.
launcher/src/main.rs
View file @
252f42c1
...
...
@@ -414,6 +414,14 @@ fn main() -> ExitCode {
argv
.push
(
origin
);
}
// Copy current process env
let
mut
env
:
Vec
<
(
OsString
,
OsString
)
>
=
env
::
vars_os
()
.collect
();
// Parse Inference API token
if
let
Ok
(
api_token
)
=
env
::
var
(
"HF_API_TOKEN"
)
{
env
.push
((
"HUGGING_FACE_HUB_TOKEN"
.into
(),
api_token
.into
()))
};
let
mut
webserver
=
match
Popen
::
create
(
&
argv
,
PopenConfig
{
...
...
@@ -421,6 +429,7 @@ fn main() -> ExitCode {
stderr
:
Redirection
::
Pipe
,
// Needed for the shutdown procedure
setpgid
:
true
,
env
:
Some
(
env
),
..
Default
::
default
()
},
)
{
...
...
router/src/main.rs
View file @
252f42c1
...
...
@@ -90,6 +90,9 @@ fn main() -> Result<(), std::io::Error> {
)
});
// Parse Huggingface hub token
let
authorization_token
=
std
::
env
::
var
(
"HUGGING_FACE_HUB_TOKEN"
)
.ok
();
// Tokenizer instance
// This will only be used to validate payloads
let
local_path
=
Path
::
new
(
&
tokenizer_name
);
...
...
@@ -102,6 +105,7 @@ fn main() -> Result<(), std::io::Error> {
// We need to download it outside of the Tokio runtime
let
params
=
FromPretrainedParameters
{
revision
:
revision
.clone
(),
auth_token
:
authorization_token
.clone
(),
..
Default
::
default
()
};
Tokenizer
::
from_pretrained
(
tokenizer_name
.clone
(),
Some
(
params
))
.ok
()
...
...
@@ -129,7 +133,7 @@ fn main() -> Result<(), std::io::Error> {
sha
:
None
,
pipeline_tag
:
None
,
},
false
=>
get_model_info
(
&
tokenizer_name
,
&
revision
)
.await
,
false
=>
get_model_info
(
&
tokenizer_name
,
&
revision
,
authorization_token
)
.await
,
};
// if pipeline-tag == text-generation we default to return_full_text = true
...
...
@@ -233,14 +237,21 @@ fn init_logging(otlp_endpoint: Option<String>, json_output: bool) {
}
/// get model info from the Huggingface Hub
pub
async
fn
get_model_info
(
model_id
:
&
str
,
revision
:
&
str
)
->
ModelInfo
{
let
model_info
=
reqwest
::
get
(
format!
(
pub
async
fn
get_model_info
(
model_id
:
&
str
,
revision
:
&
str
,
token
:
Option
<
String
>
)
->
ModelInfo
{
let
client
=
reqwest
::
Client
::
new
();
let
mut
builder
=
client
.get
(
format!
(
"https://huggingface.co/api/models/{model_id}/revision/{revision}"
))
.await
.expect
(
"Could not connect to hf.co"
)
.text
()
.await
.expect
(
"error when retrieving model info from hf.co"
);
));
if
let
Some
(
token
)
=
token
{
builder
=
builder
.bearer_auth
(
token
);
}
let
model_info
=
builder
.send
()
.await
.expect
(
"Could not connect to hf.co"
)
.text
()
.await
.expect
(
"error when retrieving model info from hf.co"
);
serde_json
::
from_str
(
&
model_info
)
.expect
(
"unable to parse model info"
)
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment