OpenDAS / text-generation-inference · Commits

Commit 20c3c594 (unverified), authored Feb 03, 2023 by OlivierDehaene, committed by GitHub on Feb 03, 2023

feat(router): refactor API and add openAPI schemas (#53)

parent b1482d90
Showing 18 changed files with 304 additions and 165 deletions (+304 / -165):

router/src/lib.rs                               +84  -40
router/src/server.rs                            +98  -14
router/src/validation.rs                        +40  -23
server/README.md                                 +2   -2
server/pyproject.toml                            +2   -2
server/tests/models/test_bloom.py                +4   -8
server/tests/test_utils.py                       +4   -3
server/text_generation/cli.py                    +4   -4
server/text_generation/models/__init__.py       +13  -13
server/text_generation/models/bloom.py           +8  -10
server/text_generation/models/causal_lm.py       +3   -3
server/text_generation/models/galactica.py       +8  -10
server/text_generation/models/gpt_neox.py        +5   -7
server/text_generation/models/santacoder.py      +3   -3
server/text_generation/models/seq2seq_lm.py      +3   -3
server/text_generation/models/types.py           +2   -1
server/text_generation/server.py                 +4   -4
server/text_generation/utils.py                 +17  -15
router/src/lib.rs

 /// Text Generation Inference Webserver
 mod infer;
 mod queue;
 pub mod server;
 ...
 @@ -8,45 +7,55 @@ mod validation;
 use infer::Infer;
 use queue::{Entry, Queue};
 use serde::{Deserialize, Serialize};
+use utoipa::ToSchema;
 use validation::Validation;

-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateParameters {
-    #[serde(default = "default_temperature")]
-    pub temperature: f32,
-    #[serde(default = "default_repetition_penalty")]
-    pub repetition_penalty: f32,
-    #[serde(default = "default_top_k")]
-    pub top_k: i32,
-    #[serde(default = "default_top_p")]
-    pub top_p: f32,
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0.0, nullable = true, default = "null", example = 0.5)]
+    pub temperature: Option<f32>,
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0.0, nullable = true, default = "null", example = 1.03)]
+    pub repetition_penalty: Option<f32>,
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
+    pub top_k: Option<i32>,
+    #[serde(default)]
+    #[schema(exclusive_minimum = 0.0, maximum = 1.0, nullable = true, default = "null", example = 0.95)]
+    pub top_p: Option<f32>,
     #[serde(default = "default_do_sample")]
+    #[schema(default = "false", example = true)]
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
+    #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
     pub max_new_tokens: u32,
     #[serde(default)]
+    #[schema(inline, max_items = 4, example = json!(["photographer"]))]
     pub stop: Vec<String>,
     #[serde(default)]
+    #[schema(default = "true")]
     pub details: bool,
     #[serde(default)]
     pub seed: Option<u64>,
 }

-fn default_temperature() -> f32 {
-    1.0
-}
-
-fn default_repetition_penalty() -> f32 {
-    1.0
-}
-
-fn default_top_k() -> i32 {
-    0
-}
-
-fn default_top_p() -> f32 {
-    1.0
-}
-
 fn default_do_sample() -> bool {
     false
 }
 ...
 @@ -57,10 +66,10 @@ fn default_max_new_tokens() -> u32 {
 fn default_parameters() -> GenerateParameters {
     GenerateParameters {
-        temperature: default_temperature(),
-        repetition_penalty: default_repetition_penalty(),
-        top_k: default_top_k(),
-        top_p: default_top_p(),
+        temperature: None,
+        repetition_penalty: None,
+        top_k: None,
+        top_p: None,
         do_sample: default_do_sample(),
         max_new_tokens: default_max_new_tokens(),
         stop: vec![],
 ...
 @@ -69,42 +78,77 @@ fn default_parameters() -> GenerateParameters {
     }
 }

-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct GenerateRequest {
+    #[schema(example = "My name is Olivier and I")]
     pub inputs: String,
     #[serde(default = "default_parameters")]
     pub parameters: GenerateParameters,
 }

-#[derive(Debug, Serialize)]
-pub struct Token(u32, String, f32);
+#[derive(Debug, Serialize, ToSchema)]
+pub struct Token {
+    #[schema(example = 0)]
+    id: u32,
+    #[schema(example = "test")]
+    text: String,
+    #[schema(nullable = true, example = -0.34)]
+    logprob: f32,
+}

+#[derive(Serialize, ToSchema)]
+#[serde(rename_all(serialize = "snake_case"))]
+pub(crate) enum FinishReason {
+    #[schema(rename = "length")]
+    Length,
+    #[serde(rename = "eos_token")]
+    #[schema(rename = "eos_token")]
+    EndOfSequenceToken,
+    #[schema(rename = "stop_sequence")]
+    StopSequence,
+}

-#[derive(Serialize)]
+#[derive(Serialize, ToSchema)]
 pub(crate) struct Details {
-    pub finish_reason: String,
+    #[schema(example = "length")]
+    pub finish_reason: FinishReason,
+    #[schema(example = 1)]
     pub generated_tokens: u32,
+    #[schema(example = 42)]
     pub seed: Option<u64>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub prefill: Option<Vec<Token>>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub tokens: Option<Vec<Token>>,
 }

-#[derive(Serialize)]
+#[derive(Serialize, ToSchema)]
 pub(crate) struct GenerateResponse {
+    #[schema(example = "test")]
     pub generated_text: String,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub details: Option<Details>,
 }

-#[derive(Serialize)]
+#[derive(Serialize, ToSchema)]
+pub(crate) struct StreamDetails {
+    #[schema(example = "length")]
+    pub finish_reason: FinishReason,
+    #[schema(example = 1)]
+    pub generated_tokens: u32,
+    #[schema(example = 42)]
+    pub seed: Option<u64>,
+}

+#[derive(Serialize, ToSchema)]
 pub(crate) struct StreamResponse {
     pub token: Token,
+    #[schema(nullable = true, default = "null", example = "test")]
     pub generated_text: Option<String>,
-    pub details: Option<Details>,
+    #[schema(nullable = true, default = "null")]
+    pub details: Option<StreamDetails>,
 }

-#[derive(Serialize)]
+#[derive(Serialize, ToSchema)]
 pub(crate) struct ErrorResponse {
+    #[schema(inline)]
     pub error: String,
 }
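
For orientation, a minimal client-side sketch (not part of the diff) of a request matching the new GenerateRequest / GenerateParameters schema above. The base URL and port are assumptions; every sampling parameter is now optional and nullable:

# Sketch: call the refactored /generate endpoint (base URL is an assumption).
import requests

payload = {
    "inputs": "My name is Olivier and I",
    "parameters": {
        # All sampling parameters may be omitted or null; these values mirror the schema examples.
        "temperature": 0.5,
        "repetition_penalty": 1.03,
        "top_k": 10,
        "top_p": 0.95,
        "do_sample": True,
        "max_new_tokens": 20,
        "stop": ["photographer"],
        "details": True,
    },
}

response = requests.post("http://127.0.0.1:3000/generate", json=payload, timeout=60)
response.raise_for_status()
print(response.json())  # follows the GenerateResponse schema (generated_text, optional details)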
router/src/server.rs

 /// HTTP Server logic
 use crate::infer::{InferError, InferStreamResponse};
 use crate::{
-    Details, ErrorResponse, GenerateParameters, GenerateRequest, GenerateResponse, Infer,
-    StreamResponse, Validation,
+    Details, ErrorResponse, FinishReason, GenerateParameters, GenerateRequest, GenerateResponse,
+    Infer, StreamDetails, StreamResponse, Token, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};
 ...
 @@ -19,6 +19,8 @@ use tokio::signal;
 use tokio::time::Instant;
 use tokio_stream::StreamExt;
 use tracing::instrument;
+use utoipa::OpenApi;
+use utoipa_swagger_ui::SwaggerUi;

 /// Health check method
 #[instrument(skip(infer))]
 ...
 @@ -32,13 +34,13 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
         .generate(GenerateRequest {
             inputs: "liveness".to_string(),
             parameters: GenerateParameters {
-                temperature: 1.0,
-                repetition_penalty: 1.0,
-                top_k: 0,
-                top_p: 1.0,
+                temperature: None,
+                repetition_penalty: None,
+                top_k: None,
+                top_p: None,
                 do_sample: false,
                 max_new_tokens: 1,
-                stop: vec![],
+                stop: Vec::new(),
                 details: false,
                 seed: None,
             },
 ...
 @@ -47,7 +49,24 @@ async fn health(infer: Extension<Infer>) -> Result<(), (StatusCode, Json<ErrorRe
     Ok(())
 }

-/// Generate method
+/// Generate tokens
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/generate",
+    request_body = GenerateRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = [GenerateResponse]),
+        (status = 424, description = "Generation Error", body = [ErrorResponse],
+            example = json!({"error": "Request failed during generation"})),
+        (status = 429, description = "Model is overloaded", body = [ErrorResponse],
+            example = json!({"error": "Model is overloaded"})),
+        (status = 422, description = "Input validation error", body = [ErrorResponse],
+            example = json!({"error": "Input validation error"})),
+        (status = 500, description = "Incomplete generation", body = [ErrorResponse],
+            example = json!({"error": "Incomplete generation"})),
+    )
+)]
 #[instrument(
     skip(infer),
     fields(
 ...
 @@ -76,7 +95,7 @@ async fn generate(
     // Token details
     let details = match details {
         true => Some(Details {
-            finish_reason: response.generated_text.finish_reason,
+            finish_reason: FinishReason::from(response.generated_text.finish_reason),
             generated_tokens: response.generated_text.generated_tokens,
             prefill: Some(response.prefill),
             tokens: Some(response.tokens),
 ...
 @@ -132,7 +151,29 @@ async fn generate(
     Ok((headers, Json(response)))
 }

-/// Generate stream method
+/// Generate a stream of token using Server Side Events
+#[utoipa::path(
+    post,
+    tag = "Text Generation Inference",
+    path = "/generate_stream",
+    request_body = GenerateRequest,
+    responses(
+        (status = 200, description = "Generated Text", body = [StreamResponse],
+            content_type = "text/event-stream"),
+        (status = 424, description = "Generation Error", body = [ErrorResponse],
+            example = json!({"error": "Request failed during generation"}),
+            content_type = "text/event-stream"),
+        (status = 429, description = "Model is overloaded", body = [ErrorResponse],
+            example = json!({"error": "Model is overloaded"}),
+            content_type = "text/event-stream"),
+        (status = 422, description = "Input validation error", body = [ErrorResponse],
+            example = json!({"error": "Input validation error"}),
+            content_type = "text/event-stream"),
+        (status = 500, description = "Incomplete generation", body = [ErrorResponse],
+            example = json!({"error": "Incomplete generation"}),
+            content_type = "text/event-stream"),
+    )
+)]
 #[instrument(
     skip(infer),
     fields(
 ...
 @@ -185,11 +226,9 @@ async fn generate_stream(
                         } => {
                             // Token details
                             let details = match details {
-                                true => Some(Details {
-                                    finish_reason: generated_text.finish_reason,
+                                true => Some(StreamDetails {
+                                    finish_reason: FinishReason::from(generated_text.finish_reason),
                                     generated_tokens: generated_text.generated_tokens,
-                                    prefill: None,
-                                    tokens: None,
                                     seed: generated_text.seed,
                                 }),
                                 false => None,
 ...
 @@ -265,6 +304,39 @@ pub async fn run(
     validation_workers: usize,
     addr: SocketAddr,
 ) {
+    // OpenAPI documentation
+    #[derive(OpenApi)]
+    #[openapi(
+        paths(
+            generate,
+            generate_stream,
+        ),
+        components(
+            schemas(
+                GenerateRequest,
+                GenerateParameters,
+                Token,
+                GenerateResponse,
+                Details,
+                FinishReason,
+                StreamResponse,
+                StreamDetails,
+                ErrorResponse,
+            )
+        ),
+        tags(
+            (name = "Text Generation Inference", description = "Hugging Face Text Generation Inference API")
+        ),
+        info(
+            title = "Text Generation Inference",
+            license(name = "Apache 2.0", url = "https://www.apache.org/licenses/LICENSE-2.0")
+        )
+    )]
+    struct ApiDoc;
+
     // Create state
     let validation = Validation::new(validation_workers, tokenizer, max_input_length);
     let infer = Infer::new(
 ...
 @@ -277,6 +349,7 @@ pub async fn run(
     // Create router
     let app = Router::new()
+        .merge(SwaggerUi::new("/docs").url("/api-doc/openapi.json", ApiDoc::openapi()))
         .route("/", post(generate))
         .route("/generate", post(generate))
         .route("/generate_stream", post(generate_stream))
 ...
 @@ -320,6 +393,17 @@ async fn shutdown_signal() {
     tracing::info!("signal received, starting graceful shutdown");
 }

+impl From<i32> for FinishReason {
+    fn from(finish_reason: i32) -> Self {
+        let finish_reason = text_generation_client::FinishReason::from_i32(finish_reason).unwrap();
+        match finish_reason {
+            text_generation_client::FinishReason::Length => FinishReason::Length,
+            text_generation_client::FinishReason::EosToken => FinishReason::EndOfSequenceToken,
+            text_generation_client::FinishReason::StopSequence => FinishReason::StopSequence,
+        }
+    }
+}
+
 /// Convert to Axum supported formats
 impl From<InferError> for (StatusCode, Json<ErrorResponse>) {
     fn from(err: InferError) -> Self {
 ...
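
A hedged sketch (not part of the diff) of exercising the routes wired up above: the Swagger UI is mounted at /docs, the raw schema at /api-doc/openapi.json, and /generate_stream responds with text/event-stream. The base URL is an assumption:

# Sketch: fetch the utoipa-generated OpenAPI document and consume the SSE stream.
import json
import requests

base = "http://127.0.0.1:3000"  # assumption: wherever the router is listening

openapi = requests.get(f"{base}/api-doc/openapi.json", timeout=10).json()
print(sorted(openapi["paths"]))  # expect "/generate" and "/generate_stream"

# Each SSE "data:" line of /generate_stream carries a StreamResponse JSON object.
with requests.post(
    f"{base}/generate_stream",
    json={"inputs": "My name is Olivier and I", "parameters": {"max_new_tokens": 20}},
    stream=True,
    timeout=60,
) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data:"):
            event = json.loads(line[len(b"data:"):])
            print(event["token"]["text"], end="", flush=True)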
router/src/validation.rs

 ...
 @@ -110,30 +110,58 @@ fn validate(
     max_input_length: usize,
     rng: &mut ThreadRng,
 ) -> Result<ValidGenerateRequest, ValidationError> {
-    if request.parameters.temperature <= 0.0 {
+    let GenerateParameters {
+        temperature,
+        repetition_penalty,
+        top_k,
+        top_p,
+        do_sample,
+        max_new_tokens,
+        stop: stop_sequences,
+        seed,
+        ..
+    } = request.parameters;
+
+    let temperature = temperature.unwrap_or(1.0);
+    if temperature <= 0.0 {
         return Err(ValidationError::Temperature);
     }
-    if request.parameters.repetition_penalty <= 0.0 {
+
+    let repetition_penalty = repetition_penalty.unwrap_or(1.0);
+    if repetition_penalty <= 0.0 {
         return Err(ValidationError::RepetitionPenalty);
     }
-    if request.parameters.top_p <= 0.0 || request.parameters.top_p > 1.0 {
+
+    let top_p = top_p.unwrap_or(1.0);
+    if top_p <= 0.0 || top_p > 1.0 {
         return Err(ValidationError::TopP);
     }
-    if request.parameters.top_k < 0 {
-        return Err(ValidationError::TopK);
-    }
-    if request.parameters.max_new_tokens > MAX_MAX_NEW_TOKENS {
+
+    // Different because the proto default value is 0 while it is not a valid value
+    // for the user
+    let top_k: u32 = match top_k {
+        None => Ok(0),
+        Some(top_k) => {
+            if top_k <= 0 {
+                return Err(ValidationError::TopK);
+            }
+            Ok(top_k as u32)
+        }
+    }?;
+
+    if max_new_tokens == 0 || max_new_tokens > MAX_MAX_NEW_TOKENS {
         return Err(ValidationError::MaxNewTokens(MAX_MAX_NEW_TOKENS));
     }
-    if request.parameters.stop.len() > MAX_STOP_SEQUENCES {
+
+    if stop_sequences.len() > MAX_STOP_SEQUENCES {
         return Err(ValidationError::StopSequence(
             MAX_STOP_SEQUENCES,
-            request.parameters.stop.len(),
+            stop_sequences.len(),
         ));
     }

     // If seed is None, assign a random one
-    let seed = match request.parameters.seed {
+    let seed = match seed {
         None => rng.gen(),
         Some(seed) => seed,
     };
 ...
 @@ -147,21 +175,10 @@ fn validate(
         Err(ValidationError::InputLength(input_length, max_input_length))
     } else {
-        // Return ValidGenerateRequest
-        let GenerateParameters {
-            temperature,
-            repetition_penalty,
-            top_k,
-            top_p,
-            do_sample,
-            max_new_tokens,
-            stop: stop_sequences,
-            ..
-        } = request.parameters;
-
         let parameters = NextTokenChooserParameters {
             temperature,
             repetition_penalty,
-            top_k: top_k as u32,
+            top_k,
             top_p,
             do_sample,
             seed,
 ...
 @@ -206,7 +223,7 @@ pub enum ValidationError {
     TopP,
     #[error("top_k must be strictly positive")]
     TopK,
-    #[error("max_new_tokens must be <= {0}")]
+    #[error("max_new_tokens must be strictly positive and <= {0}")]
     MaxNewTokens(u32),
     #[error("inputs must have less than {1} tokens. Given: {0}")]
     InputLength(usize, usize),
 ...
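
To make the tightened checks concrete, a small sketch (not part of the diff) of inputs the validator now rejects: top_k, when set, must be strictly positive, and max_new_tokens must be strictly positive and <= the configured maximum. Per the responses declared in server.rs, these surface as 422 "Input validation error". The base URL is an assumption:

# Sketch: requests the refactored validation rejects with HTTP 422.
import requests

base = "http://127.0.0.1:3000"  # assumption

bad_payloads = [
    # top_k, when provided, must now be strictly positive (the proto default 0 is reserved).
    {"inputs": "test", "parameters": {"top_k": 0}},
    # max_new_tokens == 0 is rejected, as are values above the server-side maximum.
    {"inputs": "test", "parameters": {"max_new_tokens": 0}},
]

for payload in bad_payloads:
    r = requests.post(f"{base}/generate", json=payload, timeout=30)
    print(r.status_code, r.json().get("error"))  # expect 422 and a validation message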
server/README.md

-# BLOOM Inference Python gRPC Server
+# Text Generation Inference Python gRPC Server

-A Python gRPC server for BLOOM Inference
+A Python gRPC server for Text Generation Inference

 ## Install
 ...
server/pyproject.toml

 [tool.poetry]
 name = "text-generation"
-version = "0.1.0"
-description = "BLOOM Inference Python gRPC Server"
+version = "0.2.0"
+description = "Text Generation Inference Python gRPC Server"
 authors = ["Olivier Dehaene <olivier@huggingface.co>"]

 [tool.poetry.scripts]
 ...
server/tests/models/test_bloom.py

 ...
 @@ -140,8 +140,7 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
 ...
 @@ -187,8 +186,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
 ...
 @@ -283,8 +281,7 @@ def test_batch_concatenate(
     assert len(generations) == 2
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert generations[0].request_id == default_bloom_batch.requests[0].id
     assert (
 ...
 @@ -306,8 +303,7 @@ def test_batch_concatenate(
     assert len(generations) == 1
     assert (
-        generations[0].generated_text.text
-        == "TestTestTestTestTestTestTestTestTestTest"
+        generations[0].generated_text.text == "TestTestTestTestTestTestTestTestTestTest"
     )
     assert (
         generations[0].request_id == default_multi_requests_bloom_batch.requests[0].id
 ...
server/tests/test_utils.py

 ...
 @@ -9,6 +9,7 @@ from text_generation.utils import (
     StopSequenceCriteria,
     StoppingCriteria,
     LocalEntryNotFoundError,
+    FinishReason,
 )
 ...
 @@ -24,13 +25,13 @@ def test_stop_sequence_criteria():
 def test_stopping_criteria():
     criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
     assert criteria(65827, "/test") == (False, None)
-    assert criteria(30, ";") == (True, "stop_sequence")
+    assert criteria(30, ";") == (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)


 def test_stopping_criteria_eos():
     criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
     assert criteria(1, "") == (False, None)
-    assert criteria(0, "") == (True, "eos_token")
+    assert criteria(0, "") == (True, FinishReason.FINISH_REASON_EOS_TOKEN)


 def test_stopping_criteria_max():
 ...
 @@ -39,7 +40,7 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (False, None)
-    assert criteria(1, "") == (True, "length")
+    assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)


 def test_weight_hub_files():
 ...
server/text_generation/cli.py

 ...
 @@ -13,7 +13,7 @@ app = typer.Typer()

 @app.command()
 def serve(
-    model_name: str,
+    model_id: str,
     revision: Optional[str] = None,
     sharded: bool = False,
     quantize: bool = False,
 ...
 @@ -46,16 +46,16 @@ def serve(
         os.getenv("MASTER_PORT", None) is not None
     ), "MASTER_PORT must be set when sharded is True"

-    server.serve(model_name, revision, sharded, quantize, uds_path)
+    server.serve(model_id, revision, sharded, quantize, uds_path)


 @app.command()
 def download_weights(
-    model_name: str,
+    model_id: str,
     revision: Optional[str] = None,
     extension: str = ".safetensors",
 ):
-    utils.download_weights(model_name, revision, extension)
+    utils.download_weights(model_id, revision, extension)


 if __name__ == "__main__":
 ...
server/text_generation/models/__init__.py

 ...
 @@ -30,31 +30,31 @@ torch.backends.cudnn.allow_tf32 = True

 def get_model(
-    model_name: str, revision: Optional[str], sharded: bool, quantize: bool
+    model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
-    config = AutoConfig.from_pretrained(model_name, revision=revision)
+    config = AutoConfig.from_pretrained(model_id, revision=revision)

     if config.model_type == "bloom":
         if sharded:
-            return BLOOMSharded(model_name, revision, quantize=quantize)
+            return BLOOMSharded(model_id, revision, quantize=quantize)
         else:
-            return BLOOM(model_name, revision, quantize=quantize)
+            return BLOOM(model_id, revision, quantize=quantize)
     elif config.model_type == "gpt_neox":
         if sharded:
-            return GPTNeoxSharded(model_name, revision, quantize=quantize)
+            return GPTNeoxSharded(model_id, revision, quantize=quantize)
         else:
-            return GPTNeox(model_name, revision, quantize=quantize)
-    elif model_name.startswith("facebook/galactica"):
+            return GPTNeox(model_id, revision, quantize=quantize)
+    elif model_id.startswith("facebook/galactica"):
         if sharded:
-            return GalacticaSharded(model_name, revision, quantize=quantize)
+            return GalacticaSharded(model_id, revision, quantize=quantize)
         else:
-            return Galactica(model_name, revision, quantize=quantize)
-    elif "santacoder" in model_name:
-        return SantaCoder(model_name, revision, quantize)
+            return Galactica(model_id, revision, quantize=quantize)
+    elif "santacoder" in model_id:
+        return SantaCoder(model_id, revision, quantize)
     else:
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")
         try:
-            return CausalLM(model_name, revision, quantize=quantize)
+            return CausalLM(model_id, revision, quantize=quantize)
         except Exception:
-            return Seq2SeqLM(model_name, revision, quantize=quantize)
+            return Seq2SeqLM(model_id, revision, quantize=quantize)
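
As a usage note, a sketch (not part of the diff) of the renamed entry point: callers now pass a Hub model id as the first argument to get_model:

# Sketch: get_model dispatches on model_id (renamed from model_name).
# The id below is only an example; any id whose config resolves to a supported model_type works.
from text_generation.models import get_model

model = get_model("bigscience/bloom-560m", revision=None, sharded=False, quantize=False)
print(type(model).__name__)  # BLOOM here, since config.model_type == "bloom" and sharded is False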
server/text_generation/models/bloom.py

 ...
 @@ -57,10 +57,10 @@ class BLOOM(CausalLM):

 class BLOOMSharded(BLOOM):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_name.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_name} is not supported")
+        if not model_id.startswith("bigscience/bloom"):
+            raise ValueError(f"Model {model_id} is not supported")

         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
 ...
 @@ -72,22 +72,20 @@ class BLOOMSharded(BLOOM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )

         config = AutoConfig.from_pretrained(
-            model_name, revision=revision, slow_but_exact=False, tp_parallel=True
+            model_id, revision=revision, slow_but_exact=False, tp_parallel=True
         )
         config.pad_token_id = 3

         # Only download weights for small models
-        if self.master and model_name == "bigscience/bloom-560m":
-            download_weights(model_name, revision=revision, extension=".safetensors")
+        if self.master and model_id == "bigscience/bloom-560m":
+            download_weights(model_id, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_name, revision=revision, extension=".safetensors")
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")
 ...
server/text_generation/models/causal_lm.py

 ...
 @@ -232,7 +232,7 @@ class CausalLMBatch(Batch):

 class CausalLM(Model):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
 ...
 @@ -244,10 +244,10 @@ class CausalLM(Model):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
+            model_id,
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
 ...
server/text_generation/models/galactica.py

 ...
 @@ -149,10 +149,10 @@ class Galactica(CausalLM):

 class GalacticaSharded(Galactica):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
-        if not model_name.startswith("facebook/galactica"):
-            raise ValueError(f"Model {model_name} is not supported")
+        if not model_id.startswith("facebook/galactica"):
+            raise ValueError(f"Model {model_id} is not supported")

         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
 ...
 @@ -164,22 +164,20 @@ class GalacticaSharded(Galactica):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )

         config = AutoConfig.from_pretrained(
-            model_name, revision=revision, tp_parallel=True
+            model_id, revision=revision, tp_parallel=True
         )
         tokenizer.pad_token_id = config.pad_token_id

         # Only download weights for small models
-        if self.master and model_name == "facebook/galactica-125m":
-            download_weights(model_name, revision=revision, extension=".safetensors")
+        if self.master and model_id == "facebook/galactica-125m":
+            download_weights(model_id, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_name, revision=revision, extension=".safetensors")
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")
 ...
server/text_generation/models/gpt_neox.py

 ...
 @@ -49,7 +49,7 @@ class GPTNeox(CausalLM):

 class GPTNeoxSharded(GPTNeox):
     def __init__(
-        self, model_name: str, revision: Optional[str] = None, quantize: bool = False
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
     ):
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
 ...
 @@ -61,22 +61,20 @@ class GPTNeoxSharded(GPTNeox):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         tokenizer.pad_token = tokenizer.eos_token

         config = AutoConfig.from_pretrained(
-            model_name, revision=revision, tp_parallel=True
+            model_id, revision=revision, tp_parallel=True
         )

         # Only master download weights
         if self.master:
-            download_weights(model_name, revision=revision, extension=".safetensors")
+            download_weights(model_id, revision=revision, extension=".safetensors")

         torch.distributed.barrier(group=self.process_group)
-        filenames = weight_files(model_name, revision=revision, extension=".safetensors")
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         if not filenames:
             raise ValueError("No safetensors weights found")
 ...
server/text_generation/models/santacoder.py

 ...
 @@ -14,7 +14,7 @@ EOD = "<|endoftext|>"

 class SantaCoder(CausalLM):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
 ...
 @@ -26,7 +26,7 @@ class SantaCoder(CausalLM):
             dtype = torch.float32

         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         tokenizer.add_special_tokens(
             {
 ...
 @@ -43,7 +43,7 @@ class SantaCoder(CausalLM):
         self.model = (
             AutoModelForCausalLM.from_pretrained(
-                model_name,
+                model_id,
                 revision=revision,
                 torch_dtype=dtype,
                 load_in_8bit=quantize,
 ...
server/text_generation/models/seq2seq_lm.py

 ...
 @@ -289,7 +289,7 @@ class Seq2SeqLMBatch(Batch):

 class Seq2SeqLM(Model):
-    def __init__(self, model_name: str, revision: Optional[str] = None, quantize=False):
+    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
 ...
 @@ -301,14 +301,14 @@ class Seq2SeqLM(Model):
             dtype = torch.float32

         self.model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_name,
+            model_id,
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(
-            model_name, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left"
         )
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
 ...
server/text_generation/models/types.py

 ...
 @@ -7,6 +7,7 @@ from typing import List, Optional
 from transformers import PreTrainedTokenizerBase

 from text_generation.pb import generate_pb2
+from text_generation.pb.generate_pb2 import FinishReason


 class Batch(ABC):
 ...
 @@ -38,7 +39,7 @@ class Batch(ABC):
 class GeneratedText:
     text: str
     generated_tokens: int
-    finish_reason: str
+    finish_reason: FinishReason
     seed: Optional[int]

     def to_pb(self) -> generate_pb2.GeneratedText:
 ...
server/text_generation/server.py

 ...
 @@ -66,14 +66,14 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

 def serve(
-    model_name: str,
+    model_id: str,
     revision: Optional[str],
     sharded: bool,
     quantize: bool,
     uds_path: Path,
 ):
     async def serve_inner(
-        model_name: str,
+        model_id: str,
         revision: Optional[str],
         sharded: bool = False,
         quantize: bool = False,
 ...
 @@ -89,7 +89,7 @@ def serve(
         local_url = unix_socket_template.format(uds_path, 0)
         server_urls = [local_url]

-        model = get_model(model_name, revision, sharded, quantize)
+        model = get_model(model_id, revision, sharded, quantize)

         server = aio.server(interceptors=[ExceptionInterceptor()])
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
 ...
 @@ -109,4 +109,4 @@ def serve(
         logger.info("Signal received. Shutting down")
         await server.stop(0)

-    asyncio.run(serve_inner(model_name, revision, sharded, quantize))
+    asyncio.run(serve_inner(model_id, revision, sharded, quantize))
server/text_generation/utils.py

 ...
 @@ -24,9 +24,11 @@ from transformers.generation.logits_process import (
 )

 from text_generation.pb import generate_pb2
+from text_generation.pb.generate_pb2 import FinishReason

 WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)


 class Sampling:
     def __init__(self, seed: int, device: str = "cpu"):
         self.generator = torch.Generator(device)
 ...
 @@ -129,15 +131,15 @@ class StoppingCriteria:
     def __call__(self, last_token: int, last_output: str) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
-            return True, "length"
+            return True, FinishReason.FINISH_REASON_LENGTH

         if last_token == self.eos_token_id:
-            return True, "eos_token"
+            return True, FinishReason.FINISH_REASON_EOS_TOKEN

         self.current_output += last_output
         for stop_sequence_criteria in self.stop_sequence_criterias:
             if stop_sequence_criteria(self.current_output):
-                return True, "stop_sequence"
+                return True, FinishReason.FINISH_REASON_STOP_SEQUENCE

         return False, None
 ...
 @@ -180,20 +182,20 @@ def initialize_torch_distributed():
     return torch.distributed.distributed_c10d._get_default_group(), rank, world_size


-def weight_hub_files(model_name, revision=None, extension=".safetensors"):
+def weight_hub_files(model_id, revision=None, extension=".safetensors"):
     """Get the safetensors filenames on the hub"""
     api = HfApi()
-    info = api.model_info(model_name, revision=revision)
+    info = api.model_info(model_id, revision=revision)
     filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
     return filenames


-def try_to_load_from_cache(model_name, revision, filename):
+def try_to_load_from_cache(model_id, revision, filename):
     """Try to load a file from the Hugging Face cache"""
     if revision is None:
         revision = "main"

-    object_id = model_name.replace("/", "--")
+    object_id = model_id.replace("/", "--")
     repo_cache = Path(HUGGINGFACE_HUB_CACHE) / f"models--{object_id}"

     if not repo_cache.is_dir():
 ...
 @@ -228,38 +230,38 @@ def try_to_load_from_cache(model_name, revision, filename):
     return str(cached_file) if cached_file.is_file() else None


-def weight_files(model_name, revision=None, extension=".safetensors"):
+def weight_files(model_id, revision=None, extension=".safetensors"):
     """Get the local safetensors filenames"""
     if WEIGHTS_CACHE_OVERRIDE is not None:
         return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

-    filenames = weight_hub_files(model_name, revision, extension)
+    filenames = weight_hub_files(model_id, revision, extension)
     files = []
     for filename in filenames:
         cache_file = try_to_load_from_cache(
-            model_name, revision=revision, filename=filename
+            model_id, revision=revision, filename=filename
         )
         if cache_file is None:
             raise LocalEntryNotFoundError(
-                f"File {filename} of model {model_name} not found in "
+                f"File {filename} of model {model_id} not found in "
                 f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
-                f"Please run `text-generation-server download-weights {model_name}` first."
+                f"Please run `text-generation-server download-weights {model_id}` first."
             )
         files.append(cache_file)

     return files


-def download_weights(model_name, revision=None, extension=".safetensors"):
+def download_weights(model_id, revision=None, extension=".safetensors"):
     """Download the safetensors files from the hub"""
     if WEIGHTS_CACHE_OVERRIDE is not None:
         return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))

-    filenames = weight_hub_files(model_name, revision, extension)
+    filenames = weight_hub_files(model_id, revision, extension)
     download_function = partial(
         hf_hub_download,
-        repo_id=model_name,
+        repo_id=model_id,
         local_files_only=False,
     )
 ...
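
The stopping criteria now report the protobuf FinishReason enum rather than bare strings, which is what the updated tests assert. A minimal sketch of the new return values, with constructor arguments taken from server/tests/test_utils.py:

# Sketch: StoppingCriteria returns (should_stop, FinishReason | None) after this change.
from text_generation.utils import FinishReason, StopSequenceCriteria, StoppingCriteria

criteria = StoppingCriteria(0, [StopSequenceCriteria("/test;")], max_new_tokens=5)
print(criteria(65827, "/test"))  # (False, None): no stop condition hit yet
print(criteria(30, ";"))         # (True, FinishReason.FINISH_REASON_STOP_SEQUENCE)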