OpenDAS / text-generation-inference · Commits

Commit 32a25306 (unverified)
Authored Dec 15, 2022 by OlivierDehaene, committed by GitHub on Dec 15, 2022

feat: Return logprobs (#8)

Parent: 718096f6

18 changed files with 247 additions and 94 deletions (+247 -94)
Files changed:

  .github/workflows/server-tests.yaml          +30  -0
  README.md                                     +1  -0
  proto/generate.proto                         +12  -6
  router/client/src/lib.rs                      +1  -1
  router/src/batcher.rs                        +10  -3
  router/src/db.rs                              +3  -3
  router/src/lib.rs                            +13  -1
  router/src/server.rs                         +26  -5
  server/tests/conftest.py                      +1  -1
  server/tests/models/test_bloom.py            +20 -12
  server/tests/models/test_causal_lm.py        +13 -12
  server/tests/models/test_seq2seq_lm.py       +12 -12
  server/text_generation/models/bloom.py        +3  -7
  server/text_generation/models/causal_lm.py   +44  -9
  server/text_generation/models/galactica.py    +3  -7
  server/text_generation/models/seq2seq_lm.py  +40 -10
  server/text_generation/models/types.py        +9  -3
  server/text_generation/utils.py               +6  -2
.github/workflows/server-tests.yaml  (new file, mode 100644)

name: Server Tests

on:
  pull_request:
    paths:
      - "server/**"
      - "proto/**"

jobs:
  run_tests:
    runs-on: ubuntu-20.04
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v1
        with:
          python-version: 3.9
      - name: Loading cache.
        uses: actions/cache@v2
        id: model_cache
        with:
          path: ~/.cache/huggingface/
          key: models
      - name: Install server dependencies
        run: |
          make install-server
      - name: Run tests
        run: |
          pip install pytest
          pytest -sv server/tests
README.md

@@ -17,6 +17,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - 45ms per token generation for BLOOM with 8xA100 80GB
 - Logits warpers (temperature scaling, topk ...)
 - Stop sequences
+- Log probabilities

 ## Officially supported models
proto/generate.proto

@@ -27,7 +27,7 @@ message ClearCacheRequest {}
 /// Empty response
 message ClearCacheResponse {}

-message LogitsWarperParameters {
+message NextTokenChooserParameters {
   /// exponential scaling output probability distribution
   float temperature = 1;
   /// restricting to the k highest probability elements

@@ -52,8 +52,8 @@ message Request {
   string inputs = 2;
   /// The number of tokens inside inputs
   uint32 input_length = 3;
-  /// Logits Warper Parameters
-  LogitsWarperParameters parameters = 4;
+  /// Next Token Chooser Parameters
+  NextTokenChooserParameters parameters = 4;
   /// Stopping Criteria Parameters
   StoppingCriteriaParameters stopping_parameters = 5;
 }

@@ -71,11 +71,17 @@ message GeneratedText {
   /// Request
   Request request = 1;
   /// Output
-  string output = 2;
+  string output_text = 2;
   /// Number of generated tokens
-  uint32 tokens = 3;
+  uint32 generated_tokens = 3;
+  /// Tokens
+  repeated string tokens = 4;
+  /// Token IDs
+  repeated uint32 token_ids = 5;
+  /// Logprobs
+  repeated float logprobs = 6;
   /// Finish reason
-  string finish_reason = 4;
+  string finish_reason = 7;
 }

 message GenerateRequest {
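The three new repeated fields are parallel: entry i of tokens, token_ids, and logprobs all describe the same position. A minimal sketch of how a client could pair them back up, assuming the protoc-generated generate_pb2 module is importable (as in server/tests/conftest.py) and gt is a decoded GeneratedText message:

    from text_generation.pb import generate_pb2

    def token_details(gt: generate_pb2.GeneratedText):
        # The repeated fields are parallel lists, so zipping them rebuilds
        # (token_id, token_text, logprob) triples in generation order.
        return list(zip(gt.token_ids, gt.tokens, gt.logprobs))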
router/client/src/lib.rs

@@ -7,7 +7,7 @@ mod sharded_client;
 pub use client::Client;
 pub use pb::generate::v1::{
-    Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+    Batch, GeneratedText, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
router/src/batcher.rs

@@ -187,9 +187,13 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let entry = db
             .remove(&output.request.unwrap().id)
             .expect("ID not found in db. This is a bug.");
         let response = InferResponse {
-            output: output.output,
+            output_text: output.output_text,
+            generated_tokens: output.generated_tokens,
+            token_ids: output.token_ids,
             tokens: output.tokens,
+            logprobs: output.logprobs,
             finish_reason: output.finish_reason,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid

@@ -202,8 +206,11 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
 #[derive(Debug)]
 pub(crate) struct InferResponse {
-    pub(crate) output: String,
-    pub(crate) tokens: u32,
+    pub(crate) output_text: String,
+    pub(crate) generated_tokens: u32,
+    pub(crate) token_ids: Vec<u32>,
+    pub(crate) tokens: Vec<String>,
+    pub(crate) logprobs: Vec<f32>,
     pub(crate) finish_reason: String,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
router/src/db.rs

@@ -5,7 +5,7 @@ use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
 use text_generation_client::{
-    Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+    Batch, ClientError, NextTokenChooserParameters, Request, StoppingCriteriaParameters,
 };
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;

@@ -71,7 +71,7 @@ impl State {
                 id: *id,
                 inputs: entry.request.inputs.clone(),
                 input_length: entry.input_length as u32,
-                parameters: Some(LogitsWarperParameters::from(
+                parameters: Some(NextTokenChooserParameters::from(
                     entry.request.parameters.clone(),
                 )),
                 stopping_parameters: Some(StoppingCriteriaParameters::from(

@@ -162,7 +162,7 @@ impl Db {
     }
 }

-impl From<GenerateParameters> for LogitsWarperParameters {
+impl From<GenerateParameters> for NextTokenChooserParameters {
     fn from(parameters: GenerateParameters) -> Self {
         Self {
             temperature: parameters.temperature,
router/src/lib.rs

@@ -21,7 +21,10 @@ pub(crate) struct GenerateParameters {
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
     pub max_new_tokens: u32,
+    #[serde(default)]
     pub stop: Vec<String>,
+    #[serde(default)]
+    pub details: bool,
 }

 fn default_temperature() -> f32 {

@@ -52,6 +55,7 @@ fn default_parameters() -> GenerateParameters {
         do_sample: default_do_sample(),
         max_new_tokens: default_max_new_tokens(),
         stop: vec![],
+        details: false,
     }
 }

@@ -62,10 +66,18 @@ pub(crate) struct GenerateRequest {
     pub parameters: GenerateParameters,
 }

+#[derive(Serialize)]
+pub(crate) struct Details {
+    pub finish_reason: String,
+    pub generated_tokens: u32,
+    pub tokens: Vec<(u32, String, f32)>,
+}
+
 #[derive(Serialize)]
 pub(crate) struct GeneratedText {
     pub generated_text: String,
-    pub finish_reason: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub details: Option<Details>,
 }

 #[derive(Serialize)]
router/src/server.rs

 use crate::{
-    Batcher, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
+    Batcher, Details, ErrorResponse, GenerateParameters, GenerateRequest, GeneratedText, Validation,
 };
 use axum::extract::Extension;
 use axum::http::{HeaderMap, StatusCode};

@@ -54,6 +54,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
             do_sample: false,
             max_new_tokens: 1,
             stop: vec![],
+            details: false,
         },
     },
 )

@@ -89,6 +90,7 @@ async fn generate(
     })?;

     // Validate request
+    let details = req.0.parameters.details;
     let (input_length, validated_request) =
         state.validation.validate(req.0).await.map_err(|err| {
             tracing::error!("{}", err.to_string());

@@ -105,12 +107,31 @@ async fn generate(
             err
         })?;

+    // Token details
+    let details = match details {
+        true => {
+            let tokens = response
+                .token_ids
+                .into_iter()
+                .zip(response.tokens.into_iter())
+                .zip(response.logprobs.into_iter())
+                .map(|((id, text), logprob)| (id, text, logprob))
+                .collect();
+            Some(Details {
+                finish_reason: response.finish_reason,
+                generated_tokens: response.generated_tokens,
+                tokens,
+            })
+        }
+        false => None,
+    };
+
     // Timings
     let total_time = start_time.elapsed();
     let validation_time = response.queued - start_time;
     let queue_time = response.start - response.queued;
     let inference_time = response.end - response.start;
-    let time_per_token = inference_time / response.tokens;
+    let time_per_token = inference_time / response.generated_tokens;

     // Headers
     let mut headers = HeaderMap::new();

@@ -141,12 +162,12 @@ async fn generate(
     tracing::Span::current().record("queue_time", format!("{:?}", queue_time));
     tracing::Span::current().record("inference_time", format!("{:?}", inference_time));
     tracing::Span::current().record("time_per_token", format!("{:?}", time_per_token));
-    tracing::info!("Output: {}", response.output);
+    tracing::info!("Output: {}", response.output_text);

     // Send response
     let response = vec![GeneratedText {
-        generated_text: response.output,
-        finish_reason: response.finish_reason,
+        generated_text: response.output_text,
+        details,
     }];
     Ok((headers, Json(response)))
 }
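With details enabled, the router folds finish_reason, generated_tokens, and per-token (id, text, logprob) triples into an optional details object next to generated_text. A hedged client sketch: the host, port, and route below are placeholders for whatever your deployment exposes, not values taken from this diff; the response shape (a JSON array with a details field whose tokens are [id, text, logprob] triples, since serde serializes the Rust tuples as arrays) does follow from the code above.

    import requests

    # Hypothetical endpoint; adjust host, port, and path to your deployment.
    resp = requests.post(
        "http://127.0.0.1:3000/generate",
        json={
            "inputs": "Testing logprobs",
            "parameters": {"max_new_tokens": 5, "details": True},
        },
    )
    resp.raise_for_status()
    body = resp.json()[0]

    # details.tokens serializes Vec<(u32, String, f32)> as [id, text, logprob] triples.
    for token_id, token_text, logprob in body["details"]["tokens"]:
        print(token_id, token_text, logprob)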
server/tests/conftest.py

@@ -7,7 +7,7 @@ from text_generation.pb import generate_pb2

 @pytest.fixture
 def default_pb_parameters():
-    return generate_pb2.LogitsWarperParameters(
+    return generate_pb2.NextTokenChooserParameters(
         temperature=1.0,
         top_k=0,
         top_p=1.0,
server/tests/models/test_bloom.py

@@ -128,10 +128,12 @@ def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch)
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )

@@ -151,10 +153,10 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTest"
+    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

@@ -170,10 +172,12 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )

@@ -240,10 +244,10 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTest"
+    assert generated_texts[0].output_text == "TestTestTestTestTestTest"
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
     )

@@ -259,10 +263,12 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_bloom_batch.stopping_criterias[0].max_new_tokens
     )

@@ -279,9 +285,11 @@ def test_batch_concatenate(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
+    assert (
+        generated_texts[0].output_text == "TestTestTestTestTestTestTestTestTestTestTest"
+    )
     assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
     )
server/tests/models/test_causal_lm.py

@@ -127,10 +127,11 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
+    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
     assert generated_texts[0].request == default_causal_lm_batch.requests[0]
+    assert len(generated_texts[0].tokens) == len(generated_texts[0].logprobs)
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )

@@ -150,12 +151,12 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784)"
+    assert generated_texts[0].output_text == "Test.java:784)"
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
     )
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

@@ -171,12 +172,12 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
+    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
     )
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )

@@ -242,12 +243,12 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784)"
+    assert generated_texts[0].output_text == "Test.java:784)"
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
     )
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
     )

@@ -263,10 +264,10 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
+    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
     assert generated_texts[0].request == default_causal_lm_batch.requests[0]
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )

@@ -283,11 +284,11 @@ def test_batch_concatenate(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
+    assert generated_texts[0].output_text == "Test.java:784) at net.minecraft."
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
     )
     assert (
-        generated_texts[0].tokens
+        generated_texts[0].generated_tokens
         == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
     )
server/tests/models/test_seq2seq_lm.py

@@ -148,9 +148,9 @@ def test_seq2seq_lm_generate_token_completion(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few weeks"
+    assert generated_texts[0].output_text == "a few weeks"
     assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].tokens == 7
+    assert generated_texts[0].generated_tokens == 7


 def test_seq2seq_lm_generate_token_completion_multi(

@@ -166,12 +166,12 @@ def test_seq2seq_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few "
+    assert generated_texts[0].output_text == "a few "
     assert (
         generated_texts[0].request
         == default_multi_requests_seq2seq_lm_batch.requests[1]
     )
-    assert generated_texts[0].tokens == 5
+    assert generated_texts[0].generated_tokens == 5

     generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert generated_texts == []

@@ -180,12 +180,12 @@ def test_seq2seq_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few weeks"
+    assert generated_texts[0].output_text == "a few weeks"
     assert (
         generated_texts[0].request
         == default_multi_requests_seq2seq_lm_batch.requests[0]
     )
-    assert generated_texts[0].tokens == 7
+    assert generated_texts[0].generated_tokens == 7


 def test_batch_concatenate(

@@ -287,28 +287,28 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few "
+    assert generated_texts[0].output_text == "a few "
     assert (
         generated_texts[0].request
         == default_multi_requests_seq2seq_lm_batch.requests[1]
     )
-    assert generated_texts[0].tokens == 5
+    assert generated_texts[0].generated_tokens == 5

     generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few weeks"
+    assert generated_texts[0].output_text == "a few weeks"
     assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
-    assert generated_texts[0].tokens == 7
+    assert generated_texts[0].generated_tokens == 7

     generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "a few weeks"
+    assert generated_texts[0].output_text == "a few weeks"
     assert (
         generated_texts[0].request
         == default_multi_requests_seq2seq_lm_batch.requests[0]
     )
-    assert generated_texts[0].tokens == 7
+    assert generated_texts[0].generated_tokens == 7
server/text_generation/models/bloom.py

@@ -246,12 +246,8 @@ class BLOOMSharded(BLOOM):
             )

             # Logits are sharded, so we need to gather them
-            logits_shard = outputs.logits[:, -1, :].contiguous()
-
-            batch_size, vocab_shard_size = logits_shard.shape
-            vocab_size = self.world_size * vocab_shard_size
-            logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
-            torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
-            logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
+            logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            logits = torch.cat(logits, dim=2)

             return logits, outputs.past_key_values
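Gathering the full sharded logits (rather than only the last position) is what makes prompt logprobs possible: concatenating the shards along the vocabulary axis rebuilds the complete distribution for every position. A small single-process sketch of the shape arithmetic, with made-up sizes standing in for the real sharded tensors and plain tensors standing in for the all_gather output:

    import torch

    world_size, batch_size, seq_len, vocab_shard = 4, 2, 7, 250
    # One logits shard per rank, as torch.distributed.all_gather would fill them in.
    shards = [torch.randn(batch_size, seq_len, vocab_shard) for _ in range(world_size)]

    # Concatenating along dim=2 (the vocabulary axis) recovers the full distribution
    # at every sequence position, not just the last one.
    logits = torch.cat(shards, dim=2)
    assert logits.shape == (batch_size, seq_len, world_size * vocab_shard)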
server/text_generation/models/causal_lm.py

@@ -22,6 +22,7 @@ class CausalLMBatch:
     # All tokens
     all_input_ids: List[torch.Tensor]
+    all_logprobs: List[Optional[torch.Tensor]]

     # Lengths of all generations present in the batch
     input_lengths: List[int]

@@ -52,6 +53,7 @@ class CausalLMBatch:
         next_token_choosers = []
         stopping_criterias = []
         input_lengths = []
+        all_logprobs = []

         # Parse batch
         for r in pb.requests:

@@ -61,6 +63,7 @@ class CausalLMBatch:
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
+            all_logprobs.append(None)

         pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(

@@ -78,6 +81,7 @@ class CausalLMBatch:
             attention_mask=tokenized_inputs["attention_mask"],
             past_key_values=None,
             all_input_ids=all_input_ids,
+            all_logprobs=all_logprobs,
             input_lengths=input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,

@@ -95,6 +99,7 @@ class CausalLMBatch:
         requests = []
         input_lengths = []
         all_input_ids = []
+        all_logprobs = []
         next_token_choosers = []
         stopping_criterias = []

@@ -110,6 +115,7 @@ class CausalLMBatch:
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             all_input_ids.extend(batch.all_input_ids)
+            all_logprobs.extend(batch.all_logprobs)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)

@@ -217,6 +223,7 @@ class CausalLMBatch:
             attention_mask=attention_mask,
             past_key_values=past_key_values,
             all_input_ids=all_input_ids,
+            all_logprobs=all_logprobs,
             input_lengths=input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,

@@ -291,6 +298,7 @@ class CausalLM(Model):
         next_batch_input_lengths = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []
+        next_batch_all_logprobs = []

         # Metadata
         next_batch_size = 0

@@ -307,6 +315,7 @@ class CausalLM(Model):
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
+            batch.all_logprobs,
         )

         # For each member of the batch

@@ -316,34 +325,59 @@ class CausalLM(Model):
             logits,
             next_token_chooser,
             stopping_criteria,
-            all_tokens,
+            all_input_ids,
+            all_logprobs,
         ) in enumerate(iterator):
             # Select next token
-            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
+            tokens, logprobs = next_token_chooser(all_input_ids, logits)
+            next_token = tokens[-1].view(1, 1)

             # Append next token to all tokens
-            all_tokens = torch.cat([all_tokens, next_token])
+            all_input_ids = torch.cat([all_input_ids, next_token])
+            new_input_length = input_length + 1
+
+            if all_logprobs is None:
+                # logprobs of all prompt tokens (except the first one) and the generated token
+                all_logprobs = logprobs.gather(1, all_input_ids[1:])
+            else:
+                # logprob of the generated token
+                next_token_logprob = logprobs[-1, next_token]
+                all_logprobs = torch.cat([all_logprobs, next_token_logprob])

             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(all_tokens)
+            stop, reason = stopping_criteria(all_input_ids)
             if stop:
                 # Decode all tokens
-                output = self.tokenizer.decode(
-                    all_tokens.squeeze(-1), skip_special_tokens=True
+                output_text = self.tokenizer.decode(
+                    all_input_ids.squeeze(-1), skip_special_tokens=True
                 )
+                # Slice with input_length to remove padding
+                token_ids = all_input_ids[-new_input_length:]
+                tokens = self.tokenizer.batch_decode(token_ids)
+                # Add NaN for the first prompt token
+                logprobs = [float("nan")] + all_logprobs[-new_input_length:].squeeze(1).tolist()
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(
-                        request, output, stopping_criteria.current_tokens, reason
-                    )
+                    GeneratedText(
+                        request=request,
+                        output_text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        tokens=tokens,
+                        token_ids=token_ids.squeeze(1).tolist(),
+                        logprobs=logprobs,
+                        reason=reason,
+                    )
                 )
             # add to the next batch
             else:
                 next_batch_keep_indices.append(i)
                 next_batch_input_ids.append(next_token)
-                next_batch_all_input_ids.append(all_tokens)
+                next_batch_all_input_ids.append(all_input_ids)
+                next_batch_all_logprobs.append(all_logprobs)
                 next_batch_size += 1
-                new_input_length = input_length + 1
                 next_batch_input_lengths.append(new_input_length)
                 next_batch_max_sequence_length = max(
                     next_batch_max_sequence_length, new_input_length

@@ -397,6 +431,7 @@ class CausalLM(Model):
             attention_mask=next_batch_attention_mask,
             past_key_values=next_batch_past_key_values,
             all_input_ids=next_batch_all_input_ids,
+            all_logprobs=next_batch_all_logprobs,
             input_lengths=next_batch_input_lengths,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
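The logprob bookkeeping works in two phases: on the first generation step, logprobs.gather(1, all_input_ids[1:]) scores every prompt token under the distribution predicted from its left context (plus the newly generated token), and on later steps only the new token's logprob is appended; the response then prepends NaN because the first prompt token has no left context. A minimal sketch with toy tensors, not the model's real forward pass:

    import torch

    # Toy shapes: a 4-token prompt over a 10-token vocabulary.
    prompt_len, vocab = 4, 10
    all_input_ids = torch.randint(vocab, (prompt_len, 1))             # column vector of ids
    logprobs = torch.log_softmax(torch.randn(prompt_len, vocab), -1)  # one row per position

    # Row i of `logprobs` is the distribution predicted after reading token i, so pairing
    # row i with the id at position i + 1 scores each token under its left context.
    next_token = logprobs[-1].argmax().view(1, 1)                     # greedy pick for the sketch
    all_input_ids = torch.cat([all_input_ids, next_token])            # now prompt_len + 1 ids

    # First step: logprobs of every prompt token except the first, plus the new token.
    all_logprobs = logprobs.gather(1, all_input_ids[1:])              # shape (prompt_len, 1)

    # Later steps would only append the newly generated token's logprob:
    #   all_logprobs = torch.cat([all_logprobs, logprobs[-1, next_token]])

    # The response pads the front with NaN: the first prompt token has no left context.
    response_logprobs = [float("nan")] + all_logprobs.squeeze(1).tolist()
    print(response_logprobs)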
server/text_generation/models/galactica.py

@@ -321,12 +321,8 @@ class GalacticaSharded(Galactica):
             )

             # Logits are sharded, so we need to gather them
-            logits_shard = outputs.logits[:, -1, :].contiguous()
-
-            batch_size, vocab_shard_size = logits_shard.shape
-            vocab_size = self.world_size * vocab_shard_size
-            logits = [torch.empty_like(logits_shard) for _ in range(self.world_size)]
-            torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
-            logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)
+            logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
+            logits = torch.cat(logits, dim=2)

             return logits, outputs.past_key_values
server/text_generation/models/seq2seq_lm.py

@@ -30,6 +30,7 @@ class Seq2SeqLMBatch:
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
+    decoder_logprobs: List[Optional[torch.Tensor]]

     # Generation helpers
     next_token_choosers: List[NextTokenChooser]

@@ -60,6 +61,7 @@ class Seq2SeqLMBatch:
         decoder_input_ids = []
         decoder_input_lengths = []
+        decoder_logprobs = []

         # Parse batch
         for r in pb.requests:

@@ -72,6 +74,7 @@ class Seq2SeqLMBatch:
             stopping_criterias.append(
                 StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )
+            decoder_logprobs.append(None)

         # Tokenize batch
         pad_to_multiple_of = 8 if "gpu" in str(device) else None

@@ -95,6 +98,7 @@ class Seq2SeqLMBatch:
             past_key_values=None,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
+            decoder_logprobs=decoder_logprobs,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),

@@ -117,6 +121,7 @@ class Seq2SeqLMBatch:
         requests = []
         input_lengths = []
         decoder_input_lengths = []
+        decoder_logprobs = []
         next_token_choosers = []
         stopping_criterias = []

@@ -137,6 +142,7 @@ class Seq2SeqLMBatch:
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
+            decoder_logprobs.extend(batch.decoder_logprobs)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)

@@ -286,6 +292,7 @@ class Seq2SeqLMBatch:
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
+            decoder_logprobs=decoder_logprobs,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,

@@ -385,6 +392,7 @@ class Seq2SeqLM(Model):
         next_batch_input_lengths = []
         next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []
+        next_batch_decoder_logprobs = []

         # Metadata
         next_batch_size = 0

@@ -399,6 +407,7 @@ class Seq2SeqLM(Model):
             batch.requests,
             batch.input_lengths,
             batch.decoder_input_lengths,
+            batch.decoder_logprobs,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,

@@ -411,38 +420,58 @@ class Seq2SeqLM(Model):
             request,
             input_length,
             decoder_input_length,
+            decoder_logprobs,
             logits,
             next_token_chooser,
             stopping_criteria,
             input_tokens,
-            decoder_tokens,
+            decoder_input_ids,
         ) in enumerate(iterator):
-            all_tokens = torch.cat([input_tokens, decoder_tokens])
             # Select next token
-            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
+            next_token, logprobs = next_token_chooser(decoder_input_ids, logits)

             # Append next token to decoder tokens
-            decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])
+            decoder_input_ids = torch.cat([decoder_input_ids, next_token])
+            new_decoder_input_length = decoder_input_length + 1
+
+            next_token_logprob = logprobs[-1, next_token]
+            if decoder_logprobs is None:
+                decoder_logprobs = next_token_logprob
+            else:
+                decoder_logprobs = torch.cat([decoder_logprobs, next_token_logprob])

             # Evaluate stopping criteria
-            stop, reason = stopping_criteria(decoder_tokens)
+            stop, reason = stopping_criteria(decoder_input_ids)
             if stop:
-                # Decode tokens
-                output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
+                # Slice with decoder_input_length to remove padding
+                # Decode all tokens
+                token_ids = decoder_input_ids[-new_decoder_input_length:]
+                output_text = self.tokenizer.decode(token_ids, skip_special_tokens=True)
+                tokens = self.tokenizer.batch_decode(token_ids)
+                # Add NaN for the bos token
+                logprobs = [float("nan")] + decoder_logprobs[
+                    -new_decoder_input_length:
+                ].tolist()
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(
-                        request, output, stopping_criteria.current_tokens, reason
-                    )
+                    GeneratedText(
+                        request=request,
+                        output_text=output_text,
+                        generated_tokens=stopping_criteria.current_tokens,
+                        tokens=tokens,
+                        token_ids=token_ids.tolist(),
+                        logprobs=logprobs,
+                        reason=reason,
+                    )
                 )
             # add to the next batch
             else:
                 next_batch_keep_indices.append(i)
-                next_batch_decoder_input_ids.append(decoder_tokens.unsqueeze(0))
+                next_batch_decoder_input_ids.append(decoder_input_ids.unsqueeze(0))
                 next_batch_size += 1
-                new_decoder_input_length = decoder_input_length + 1
                 next_batch_input_lengths.append(input_length)
                 next_batch_decoder_input_lengths.append(new_decoder_input_length)
+                next_batch_decoder_logprobs.append(decoder_logprobs)
                 next_batch_max_input_length = max(
                     next_batch_max_input_length, input_length
                 )

@@ -515,6 +544,7 @@ class Seq2SeqLM(Model):
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
             decoder_input_lengths=next_batch_decoder_input_lengths,
+            decoder_logprobs=next_batch_decoder_logprobs,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
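For the encoder-decoder path only the decoder tokens carry logprobs: the first entry is NaN because it stands for the decoder start (bos) token, and slicing with the new decoder length drops any left padding introduced by batching. A toy illustration of that slicing, using plain tensors in place of the real decoder state:

    import torch

    # Pretend the decoder ids were left-padded to length 6 while only the last 4 are real.
    padded_decoder_ids = torch.tensor([0, 0, 2, 17, 42, 9])
    new_decoder_input_length = 4

    token_ids = padded_decoder_ids[-new_decoder_input_length:]    # drops the padding
    decoder_logprobs = torch.tensor([-0.7, -1.2, -0.3])           # one logprob per generated token

    # NaN stands in for the decoder start token, which has no logprob of its own.
    logprobs = [float("nan")] + decoder_logprobs[-new_decoder_input_length:].tolist()
    assert len(logprobs) == len(token_ids)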
server/text_generation/models/types.py

@@ -30,14 +30,20 @@ class Batch(ABC):

 @dataclass
 class GeneratedText:
     request: generate_pb2.Request
-    output: str
-    tokens: int
+    output_text: str
+    generated_tokens: int
+    tokens: List[str]
+    token_ids: List[int]
+    logprobs: List[float]
     reason: str

     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
             request=self.request,
-            output=self.output,
+            output_text=self.output_text,
+            generated_tokens=self.generated_tokens,
             tokens=self.tokens,
+            token_ids=self.token_ids,
+            logprobs=self.logprobs,
             finish_reason=self.reason,
         )
server/text_generation/utils.py

@@ -55,12 +55,16 @@ class NextTokenChooser:
         self.choice = Sampling() if sampling else Greedy()

     def __call__(self, input_ids, scores):
+        # Warp logits
         scores = self.warpers(input_ids, scores)
+        # Compute logprobs
+        logprobs = torch.log_softmax(scores, -1)
+        # Choose tokens
         next_ids = self.choice(scores)
-        return next_ids.unsqueeze(-1)
+        return next_ids, logprobs

     @classmethod
-    def from_pb(cls, pb: generate_pb2.LogitsWarperParameters) -> "NextTokenChooser":
+    def from_pb(cls, pb: generate_pb2.NextTokenChooserParameters) -> "NextTokenChooser":
         return NextTokenChooser(
             temperature=pb.temperature,
             top_k=pb.top_k,
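NextTokenChooser now returns the chosen ids together with the log-softmax of the already warped scores, so callers can look up per-token logprobs without recomputing the softmax. A self-contained sketch of that contract, with a trivial greedy choice standing in for the real Sampling/Greedy helpers:

    import torch

    def choose_next_token(input_ids, scores: torch.Tensor):
        """Mimic the new NextTokenChooser.__call__ contract: return (ids, logprobs).

        `scores` is assumed to already be warped (temperature / top-k / top-p);
        logprobs are taken over the warped distribution, as in this commit.
        """
        logprobs = torch.log_softmax(scores, -1)
        next_ids = scores.argmax(dim=-1)      # greedy stand-in for Sampling()/Greedy()
        return next_ids, logprobs

    scores = torch.randn(3, 10)               # 3 positions over a 10-token vocabulary
    next_ids, logprobs = choose_next_token(None, scores)
    next_token_logprob = logprobs[-1, next_ids[-1]]
    print(next_ids.tolist(), float(next_token_logprob))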