OpenDAS / text-generation-inference · Commits

Commit 50b495f3 (unverified)
feat: add more latency metrics in forward (#1346)

Authored Dec 14, 2023 by OlivierDehaene; committed by GitHub on Dec 14, 2023
Parent: 44b267ab

Showing 17 changed files with 238 additions and 108 deletions (+238, -108)
benchmark/src/generation.rs                                 +1   -1
proto/generate.proto                                        +14  -0
router/client/src/client.rs                                 +52  -4
router/client/src/sharded_client.rs                         +33  -18
router/src/infer.rs                                         +15  -2
router/src/validation.rs                                    +2   -2
server/tests/models/test_bloom.py                           +16  -16
server/tests/models/test_causal_lm.py                       +18  -16
server/tests/models/test_santacoder.py                      +4   -4
server/tests/models/test_seq2seq_lm.py                      +14  -14
server/text_generation_server/models/causal_lm.py           +12  -5
server/text_generation_server/models/flash_causal_lm.py     +15  -8
server/text_generation_server/models/idefics_causal_lm.py   +11  -10
server/text_generation_server/models/model.py               +3   -1
server/text_generation_server/models/seq2seq_lm.py          +12  -4
server/text_generation_server/server.py                     +15  -2
server/text_generation_server/utils/tokens.py               +1   -1
benchmark/src/generation.rs  (view file @ 50b495f3)

@@ -163,7 +163,7 @@ async fn prefill(
     // Run prefill
     let start_time = Instant::now();
-    let (_, decode_batch) = client.prefill(batch.clone()).await?;
+    let (_, decode_batch, _) = client.prefill(batch.clone()).await?;

     // Get latency
     let latency = start_time.elapsed();
proto/generate.proto  (view file @ 50b495f3)

@@ -182,6 +182,12 @@ message PrefillResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
 }

 message DecodeRequest {

@@ -194,6 +200,14 @@ message DecodeResponse {
     repeated Generation generations = 1;
     /// Next batch (cached)
     optional CachedBatch batch = 2;
+    /// Forward elapsed time in nanoseconds
+    uint64 forward_ns = 3;
+    /// Decode elapsed time in nanoseconds
+    uint64 decode_ns = 4;
+    /// Total elapsed time in nanoseconds
+    uint64 total_ns = 5;
+    /// Concatenate elapsed time in nanoseconds
+    optional uint64 concat_ns = 6;
 }

 message WarmupRequest {
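The new response fields are plain nanosecond counters, so a consumer only needs an integer-to-seconds conversion to use them. A minimal, illustrative sketch follows; the dataclass merely stands in for the generated PrefillResponse message and is not part of the commit:

from dataclasses import dataclass

# Stand-in for the generated PrefillResponse message; the field names follow the
# proto above, everything else is illustrative only.
@dataclass
class PrefillResponseStub:
    forward_ns: int   # time spent in the model forward pass
    decode_ns: int    # time spent turning logits into tokens/text
    total_ns: int     # wall clock for the whole Prefill RPC handler

def to_seconds(ns: int) -> float:
    return ns / 1e9

def summarize(resp: PrefillResponseStub) -> dict:
    # Anything not covered by forward/decode is batching and serialization overhead.
    overhead_ns = max(resp.total_ns - resp.forward_ns - resp.decode_ns, 0)
    return {
        "forward_s": to_seconds(resp.forward_ns),
        "decode_s": to_seconds(resp.decode_ns),
        "total_s": to_seconds(resp.total_ns),
        "other_s": to_seconds(overhead_ns),
    }

print(summarize(PrefillResponseStub(forward_ns=12_000_000, decode_ns=3_000_000, total_ns=16_000_000)))

The same conversion applies to DecodeResponse; its optional concat_ns is simply absent when a decode step did not have to merge cached batches.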
router/client/src/client.rs  (view file @ 50b495f3)

@@ -4,6 +4,7 @@ use crate::pb::generate::v2::*;
 use crate::Result;
 use grpc_metadata::InjectTelemetryContext;
 use std::cmp::min;
+use std::time::Duration;
 use tonic::transport::{Channel, Uri};
 use tracing::instrument;

@@ -157,10 +158,14 @@ impl Client {
     pub async fn prefill(
         &mut self,
         batch: Batch,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context();
         let response = self.stub.prefill(request).await?.into_inner();
-        Ok((response.generations, response.batch))
+        Ok((
+            response.generations,
+            response.batch,
+            PrefillTimings::new(response.forward_ns, response.decode_ns, response.total_ns),
+        ))
     }

     /// Generate one token for each request in the given cached batches

@@ -171,9 +176,52 @@ impl Client {
     pub async fn decode(
         &mut self,
         batches: Vec<CachedBatch>,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let request = tonic::Request::new(DecodeRequest { batches }).inject_context();
         let response = self.stub.decode(request).await?.into_inner();
-        Ok((response.generations, response.batch))
+        Ok((
+            response.generations,
+            response.batch,
+            DecodeTimings::new(
+                response.concat_ns,
+                response.forward_ns,
+                response.decode_ns,
+                response.total_ns,
+            ),
+        ))
+    }
+}
+
+pub struct PrefillTimings {
+    pub forward: Duration,
+    pub decode: Duration,
+    pub total: Duration,
+}
+
+impl PrefillTimings {
+    fn new(forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+        Self {
+            forward: Duration::from_nanos(forward_ns),
+            decode: Duration::from_nanos(decode_ns),
+            total: Duration::from_nanos(total_ns),
+        }
+    }
+}
+
+pub struct DecodeTimings {
+    pub concat: Option<Duration>,
+    pub forward: Duration,
+    pub decode: Duration,
+    pub total: Duration,
+}
+
+impl DecodeTimings {
+    fn new(concat_ns: Option<u64>, forward_ns: u64, decode_ns: u64, total_ns: u64) -> Self {
+        Self {
+            concat: concat_ns.map(|v| Duration::from_nanos(v)),
+            forward: Duration::from_nanos(forward_ns),
+            decode: Duration::from_nanos(decode_ns),
+            total: Duration::from_nanos(total_ns),
+        }
     }
 }
router/client/src/sharded_client.rs  (view file @ 50b495f3)

+use crate::client::{DecodeTimings, PrefillTimings};
 /// Multi shard Client
 use crate::{Batch, CachedBatch, Client, Generation, HealthResponse, ShardInfo};
 use crate::{ClientError, Result};

@@ -116,49 +117,63 @@ impl ShardedClient {
     ///
     /// Returns Generation for each request in batch
     /// and the next cached batch
     #[instrument(skip_all, fields(id = &batch.id, size = &batch.size))]
     pub async fn prefill(
         &mut self,
         batch: Batch,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.prefill(batch.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, PrefillTimings)>> =
             join_all(futures).await.into_iter().collect();
-        merge_generations(results?)
+        let mut results = results?;
+
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
+
+        // Merge generations from different model shards
+        for (mut shard_generations, _, shard_timings) in results.into_iter() {
+            generations.append(&mut shard_generations);
+            // Return the timings of the slowest shard
+            if shard_timings.total > timings.total {
+                timings = shard_timings;
+            }
+        }
+        Ok((generations, next_batch, timings))
     }

     /// Generate one token for each request in the given cached batches
     ///
     /// Returns Generation for each request in batches
     /// and the next cached batch
     #[instrument(skip_all, fields(size = batches.iter().map(|batch| {batch.size}).sum::<u32>()))]
     pub async fn decode(
         &mut self,
         batches: Vec<CachedBatch>,
-    ) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
+    ) -> Result<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)> {
         let futures: Vec<_> = self
             .clients
             .iter_mut()
             .map(|client| Box::pin(client.decode(batches.clone())))
             .collect();
-        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>)>> =
+        let results: Result<Vec<(Vec<Generation>, Option<CachedBatch>, DecodeTimings)>> =
             join_all(futures).await.into_iter().collect();
-        merge_generations(results?)
-    }
-}
-
-/// Merge generations from the different model shards
-fn merge_generations(
-    mut results: Vec<(Vec<Generation>, Option<CachedBatch>)>,
-) -> Result<(Vec<Generation>, Option<CachedBatch>)> {
-    let (mut generations, next_batch) = results.pop().ok_or(ClientError::EmptyResults)?;
-
-    for (mut shard_generations, _) in results.into_iter() {
-        generations.append(&mut shard_generations);
+        let mut results = results?;
+
+        let (mut generations, next_batch, mut timings) =
+            results.pop().ok_or(ClientError::EmptyResults)?;
+
+        // Merge generations from different model shards
+        for (mut shard_generations, _, shard_timings) in results.into_iter() {
+            generations.append(&mut shard_generations);
+            // Return the timings of the slowest shard
+            if shard_timings.total > timings.total {
+                timings = shard_timings;
+            }
+        }
+        Ok((generations, next_batch, timings))
     }
-
-    Ok((generations, next_batch))
 }
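With one timing report per shard, the sharded client keeps only the slowest shard's timings, since that shard is what bounds the latency of the whole step. A small Python illustration of the same selection rule; ShardTimings and merge are made up for the example, the real logic is the Rust above:

from dataclasses import dataclass

@dataclass
class ShardTimings:  # illustrative stand-in for PrefillTimings / DecodeTimings
    forward_ns: int
    decode_ns: int
    total_ns: int

def merge(shard_results):
    """shard_results: list of (generations, next_batch, ShardTimings), one per shard."""
    if not shard_results:
        raise ValueError("empty results")  # mirrors ClientError::EmptyResults
    generations, next_batch, timings = shard_results.pop()
    for shard_generations, _, shard_timings in shard_results:
        generations.extend(shard_generations)
        # Keep the timings of the slowest shard
        if shard_timings.total_ns > timings.total_ns:
            timings = shard_timings
    return generations, next_batch, timings

gens, batch, t = merge([
    (["a"], None, ShardTimings(10, 2, 13)),
    (["b"], None, ShardTimings(11, 2, 15)),
])
print(gens, t.total_ns)  # ['b', 'a'] 15 -- the slowest shard's timings win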
router/src/infer.rs  (view file @ 50b495f3)

@@ -379,15 +379,20 @@ async fn prefill(
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "prefill");

     match client.prefill(batch).await {
-        Ok((generations, next_batch)) => {
+        Ok((generations, next_batch, timings)) => {
             // Update health
             generation_health.store(true, Ordering::SeqCst);
+
+            let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

+            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "prefill");
+            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "prefill");
+            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "prefill");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "prefill");
             next_batch

@@ -416,15 +421,23 @@ async fn decode(
     metrics::increment_counter!("tgi_batch_inference_count", "method" => "decode");

     match client.decode(batches).await {
-        Ok((generations, next_batch)) => {
+        Ok((generations, next_batch, timings)) => {
             // Update health
             generation_health.store(true, Ordering::SeqCst);
+
+            let start_filtering_time = Instant::now();
             // Send generated tokens and filter stopped entries
             filter_send_generations(generations, entries);

             // Filter next batch and remove requests that were stopped
             let next_batch = filter_batch(client, next_batch, entries).await;

+            if let Some(concat_duration) = timings.concat {
+                metrics::histogram!("tgi_batch_concat_duration", concat_duration.as_secs_f64(), "method" => "decode");
+            }
+            metrics::histogram!("tgi_batch_forward_duration", timings.forward.as_secs_f64(), "method" => "decode");
+            metrics::histogram!("tgi_batch_decode_duration", timings.decode.as_secs_f64(), "method" => "decode");
+            metrics::histogram!("tgi_batch_filter_duration", start_filtering_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::histogram!("tgi_batch_inference_duration", start_time.elapsed().as_secs_f64(), "method" => "decode");
             metrics::increment_counter!("tgi_batch_inference_success", "method" => "decode");
             next_batch
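These timings feed the new per-phase histograms, all labelled by method (prefill or decode): tgi_batch_forward_duration, tgi_batch_decode_duration, tgi_batch_filter_duration and, for decode only, tgi_batch_concat_duration, next to the existing tgi_batch_inference_duration. A stdlib-only toy recorder that mimics the bookkeeping; the real router records through the metrics crate's histogram! macro, not this sketch:

import time
from collections import defaultdict

# Toy stand-in for a histogram backend: collect raw observations per (metric, method).
observations = defaultdict(list)

def histogram(name: str, value_s: float, method: str) -> None:
    observations[(name, method)].append(value_s)

# Simulate one decode step split into the phases measured by this commit.
start = time.monotonic()
time.sleep(0.002)                       # pretend: model forward pass
forward_s = time.monotonic() - start
time.sleep(0.001)                       # pretend: turning logits into tokens
decode_s = time.monotonic() - start - forward_s

histogram("tgi_batch_forward_duration", forward_s, method="decode")
histogram("tgi_batch_decode_duration", decode_s, method="decode")
histogram("tgi_batch_inference_duration", time.monotonic() - start, method="decode")

for (name, method), values in observations.items():
    print(f"{name}{{method={method}}}: mean {sum(values) / len(values):.4f}s over {len(values)} obs")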
router/src/validation.rs  (view file @ 50b495f3)

@@ -540,7 +540,7 @@ mod tests {
         let max_stop_sequence = 3;
         let max_top_n_tokens = 4;
         let max_input_length = 5;
-        let max_total_tokens = 6;
+        let max_total_tokens = 106;
         let workers = 1;
         let validation = Validation::new(
             workers,

@@ -600,7 +600,7 @@ mod tests {
         let max_stop_sequences = 3;
         let max_top_n_tokens = 4;
         let max_input_length = 5;
-        let max_total_tokens = 6;
+        let max_total_tokens = 106;
         let workers = 1;
         let validation = Validation::new(
             workers,
server/tests/models/test_bloom.py  (view file @ 50b495f3)

@@ -103,7 +103,7 @@ def test_causal_lm_batch_type(default_bloom):
 def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     sequence_length = len(default_bloom_batch.all_input_ids[0])
-    generations, next_batch = default_bloom.generate_token(default_bloom_batch)
+    generations, next_batch, _ = default_bloom.generate_token(default_bloom_batch)

     assert len(generations) == len(default_bloom_batch)
     assert isinstance(next_batch, CausalLMBatch)

@@ -153,10 +153,10 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
     next_batch = default_bloom_batch
     for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(default_bloom_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -178,10 +178,10 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(default_multi_requests_bloom_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -201,10 +201,10 @@ def test_causal_lm_generate_token_completion_multi(
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -224,11 +224,11 @@ def test_batch_concatenate(
     default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
 ):
     next_batch_0 = default_bloom_batch
-    _, next_batch_0 = default_bloom.generate_token(next_batch_0)
-    _, next_batch_0 = default_bloom.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_bloom.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_bloom.generate_token(next_batch_0)

     next_batch_1 = default_multi_requests_bloom_batch
-    _, next_batch_1 = default_bloom.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_bloom.generate_token(next_batch_1)

     # Clone past_key_values before concatenating to compare after,
     # because they are removed from the concatenated batches

@@ -288,10 +288,10 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 3

@@ -313,10 +313,10 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 2
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -337,10 +337,10 @@ def test_batch_concatenate(
         - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
         - 4
     ):
-        generations, next_batch = default_bloom.generate_token(next_batch)
+        generations, next_batch, _ = default_bloom.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_bloom.generate_token(next_batch)
+    generations, next_batch, _ = default_bloom.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
server/tests/models/test_causal_lm.py  (view file @ 50b495f3)

@@ -99,7 +99,9 @@ def test_causal_lm_batch_type(default_causal_lm):
 def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     sequence_length = len(default_causal_lm_batch.all_input_ids[0])
-    generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(
+        default_causal_lm_batch
+    )

     assert len(generations) == len(next_batch)
     assert isinstance(next_batch, CausalLMBatch)

@@ -151,10 +153,10 @@ def test_causal_lm_generate_token_completion(
 ):
     next_batch = default_causal_lm_batch
     for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -174,10 +176,10 @@ def test_causal_lm_generate_token_completion_multi(
     for i in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -200,10 +202,10 @@ def test_causal_lm_generate_token_completion_multi(
     for _ in range(
         stopping_criterias[0].max_new_tokens - stopping_criterias[1].max_new_tokens - 1
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -222,11 +224,11 @@ def test_batch_concatenate(
     default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
 ):
     next_batch_0 = default_causal_lm_batch
-    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
-    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_causal_lm.generate_token(next_batch_0)

     next_batch_1 = default_multi_requests_causal_lm_batch
-    _, next_batch_1 = default_causal_lm.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_causal_lm.generate_token(next_batch_1)

     # Clone past_key_values before concatenating to compare after,
     # because they are removed from the concatenated batches

@@ -285,10 +287,10 @@ def test_batch_concatenate(
     for _ in range(
         default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 3

@@ -311,10 +313,10 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
        - 2
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -333,10 +335,10 @@ def test_batch_concatenate(
         - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
        - 4
     ):
-        generations, next_batch = default_causal_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_causal_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_causal_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
server/tests/models/test_santacoder.py  (view file @ 50b495f3)

@@ -55,10 +55,10 @@ def test_santacoder_generate_token_completion(default_santacoder, default_pb_bat
     next_batch = batch
     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_santacoder.generate_token(next_batch)
+        generations, next_batch, _ = default_santacoder.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch, _ = default_santacoder.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -83,10 +83,10 @@ def test_fim_santacoder_generate_token_completion(
     next_batch = batch
     for _ in range(batch.stopping_criterias[0].max_new_tokens - 1):
-        generations, next_batch = default_santacoder.generate_token(next_batch)
+        generations, next_batch, _ = default_santacoder.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_santacoder.generate_token(next_batch)
+    generations, next_batch, _ = default_santacoder.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
server/tests/models/test_seq2seq_lm.py  (view file @ 50b495f3)

@@ -103,7 +103,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
     sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
-    generations, next_batch = default_seq2seq_lm.generate_token(
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(
         default_seq2seq_lm_batch
     )

@@ -173,10 +173,10 @@ def test_seq2seq_lm_generate_token_completion(
 ):
     next_batch = default_seq2seq_lm_batch
     for _ in range(6):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -191,10 +191,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     next_batch = default_multi_requests_seq2seq_lm_batch
     for i in range(4):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -207,10 +207,10 @@ def test_seq2seq_lm_generate_token_completion_multi(
     next_batch = next_batch.filter([next_batch.requests[0].id])

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1

@@ -228,11 +228,11 @@ def test_batch_concatenate(
     default_multi_requests_seq2seq_lm_batch,
 ):
     next_batch_0 = default_seq2seq_lm_batch
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
-    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)
+    _, next_batch_0, _ = default_seq2seq_lm.generate_token(next_batch_0)

     next_batch_1 = default_multi_requests_seq2seq_lm_batch
-    _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)
+    _, next_batch_1, _ = default_seq2seq_lm.generate_token(next_batch_1)

     # Copy hidden state because it is removed from the concatenated branches
     next_batch_0_encoder_last_hidden_state = next_batch_0.encoder_last_hidden_state

@@ -324,10 +324,10 @@ def test_batch_concatenate(
     )

     for _ in range(3):
-        generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+        generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
         assert len(generations) == len(next_batch)

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 3

@@ -342,7 +342,7 @@ def test_batch_concatenate(
         [next_batch.requests[0].id, next_batch.requests[1].id]
     )

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is not None

     assert len(generations) == 2

@@ -352,7 +352,7 @@ def test_batch_concatenate(
     next_batch = next_batch.filter([next_batch.requests[1].id])

-    generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
+    generations, next_batch, _ = default_seq2seq_lm.generate_token(next_batch)
     assert next_batch is None

     assert len(generations) == 1
server/text_generation_server/models/causal_lm.py  (view file @ 50b495f3)

-from text_generation_server.utils.tokens import batch_top_tokens
 import torch
-import inspect
+import time

 from dataclasses import dataclass
 from opentelemetry import trace

@@ -8,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenize
 from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models.types import (
     Batch,
     Tokens,

@@ -564,7 +564,8 @@ class CausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: CausalLMBatch
-    ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[CausalLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]

@@ -585,6 +586,8 @@ class CausalLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )

+        start_decode = time.time_ns()
+
         # Zipped iterator
         iterator = zip(
             batch.requests,

@@ -731,7 +734,9 @@ class CausalLM(Model):
         # We finished all generations in the batch; there is no next batch
         if stopped:
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]

@@ -747,4 +752,6 @@ class CausalLM(Model):
         # Update past key values
         batch.past_key_values = past

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
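All of the Python models follow the same measurement pattern: one timestamp before the forward pass, another once the logits are ready, and a (forward_ns, decode_ns) pair returned next to the generations. A stripped-down sketch of that bookkeeping; run_forward and decode_tokens are placeholders, not TGI functions:

import time

def run_forward(batch):
    return [x * 2 for x in batch]        # placeholder for the model forward pass

def decode_tokens(logits):
    return [str(x) for x in logits]      # placeholder for sampling / detokenization

def generate_token(batch):
    start = time.time_ns()
    logits = run_forward(batch)          # "forward" phase
    start_decode = time.time_ns()
    generations = decode_tokens(logits)  # "decode" phase
    forward_ns = start_decode - start
    decode_ns = time.time_ns() - start_decode
    # Same return shape as the updated CausalLM.generate_token
    return generations, batch, (forward_ns, decode_ns)

gens, next_batch, (forward_ns, decode_ns) = generate_token([1, 2, 3])
print(gens, forward_ns, decode_ns)

FlashCausalLM, IdeficsCausalLM and Seq2SeqLM below apply the identical pattern.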
server/text_generation_server/models/flash_causal_lm.py  (view file @ 50b495f3)

 import math
+import time
 import itertools
-from text_generation_server.utils.tokens import batch_top_tokens

 import torch
 import torch.distributed

@@ -9,9 +9,10 @@ import numpy as np
 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Union, Dict
+from typing import Optional, Tuple, List, Type, Dict

 from text_generation_server.models import Model
+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
     Batch,

@@ -689,7 +690,7 @@ class FlashCausalLM(Model):
                 self.dtype,
                 self.device,
             )
-            _, batch = self.generate_token(batch)
+            _, batch, _ = self.generate_token(batch)
         except torch.cuda.OutOfMemoryError as e:
             raise RuntimeError(
                 f"Not enough memory to handle {len(batch.input_ids)} prefill tokens. "

@@ -799,7 +800,8 @@ class FlashCausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: FlashCausalLMBatch
-    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[FlashCausalLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         prefill = batch.cu_seqlen_prefill is not None
         prefill_logprobs = batch.prefill_next_token_indices is not None

@@ -941,6 +943,8 @@ class FlashCausalLM(Model):
         # GPU <-> CPU sync
         next_token_logprobs = next_token_logprobs.tolist()
         next_token_ids = next_input_ids.tolist()
+        accepted_ids = accepted_ids.tolist()
+        start_decode = time.time_ns()

         # Zipped iterator
         iterator = zip(

@@ -977,7 +981,6 @@ class FlashCausalLM(Model):
             # Append next token to all tokens
             next_token_texts = []
             left = 0
-            before = stopping_criteria.current_tokens

             current_stopped = False
             for j in range(index, index + n_accepted_ids):

@@ -1092,7 +1095,7 @@ class FlashCausalLM(Model):
             generations.append(generation)

             # Update values
-            batch.input_lengths[i] = input_length + n_accepted_ids.item()
+            batch.input_lengths[i] = input_length + n_accepted_ids
             if batch.input_lengths[i] > batch.max_seqlen:
                 batch.max_seqlen = batch.input_lengths[i]
             batch.prefix_offsets[i] = prefix_offset

@@ -1102,10 +1105,14 @@ class FlashCausalLM(Model):
         if stopped:
             del batch
             # No need to return a batch if we know that all requests stopped
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         batch.prefill_cu_outlens = None
         batch.prefill_head_indices = None
         batch.prefill_next_token_indices = None

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
server/text_generation_server/models/idefics_causal_lm.py  (view file @ 50b495f3)

 import torch
-import inspect
+import time
 import re
-from io import BytesIO
-import base64
-from PIL import Image

 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import (
     AutoProcessor,
     AutoTokenizer,
-    AutoModelForCausalLM,
     PreTrainedTokenizerBase,
     ProcessorMixin,
 )

@@ -670,7 +664,8 @@ class IdeficsCausalLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: IdeficsCausalLMBatch
-    ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[IdeficsCausalLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         # slice the attention mask to the correct shape
         attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
         if batch.input_ids.size(1) == 1:

@@ -699,6 +694,8 @@ class IdeficsCausalLM(Model):
         # Hardcoded remove image tokens
         logits[:, 32000:32001] = torch.finfo(logits.dtype).min

+        start_decode = time.time_ns()
+
         # Results
         generations: List[Generation] = []
         stopped = True

@@ -827,7 +824,9 @@ class IdeficsCausalLM(Model):
         # We finished all generations in the batch; there is no next batch
         if stopped:
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         # Slice unused values from prefill
         batch.input_ids = batch.input_ids[:, :1]

@@ -847,4 +846,6 @@ class IdeficsCausalLM(Model):
         batch.past_key_values = past
         batch.image_hidden_states = image_hidden_states

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
server/text_generation_server/models/model.py  (view file @ 50b495f3)

@@ -65,7 +65,9 @@ class Model(ABC):
         raise NotImplementedError

     @abstractmethod
-    def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
+    def generate_token(
+        self, batch: B
+    ) -> Tuple[List[Generation], Optional[B], Tuple[int, int]]:
         raise NotImplementedError

     def warmup(self, batch: B) -> Optional[int]:
server/text_generation_server/models/seq2seq_lm.py  (view file @ 50b495f3)

-from text_generation_server.utils.tokens import batch_top_tokens
 import torch
+import time

 from dataclasses import dataclass
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
 from typing import Optional, Tuple, List, Type, Dict

+from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.models import Model
 from text_generation_server.models.types import (
     GeneratedText,

@@ -613,7 +614,8 @@ class Seq2SeqLM(Model):
     @tracer.start_as_current_span("generate_token")
     def generate_token(
         self, batch: Seq2SeqLMBatch
-    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch]]:
+    ) -> Tuple[List[Generation], Optional[Seq2SeqLMBatch], Tuple[int, int]]:
+        start = time.time_ns()
         if batch.decoder_attention_mask is not None:
             # slice to the correct shape
             decoder_attention_mask = batch.decoder_attention_mask[

@@ -644,6 +646,8 @@ class Seq2SeqLM(Model):
             torch.log_softmax(logits[:, -1], -1),
         )

+        start_decode = time.time_ns()
+
         # Finished requests
         generations: List[Generation] = []
         stopped = True

@@ -788,7 +792,9 @@ class Seq2SeqLM(Model):
         # We finished all generations in the batch; there is no next batch
         if stopped:
-            return generations, None
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)

         # We don't need input_ids after the prefill forward
         batch.input_ids = None

@@ -799,4 +805,6 @@ class Seq2SeqLM(Model):
         batch.decoder_attention_mask[:, -batch.padding_right_offset] = 1
         batch.padding_right_offset -= 1

-        return generations, batch
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
server/text_generation_server/server.py  (view file @ 50b495f3)

 import asyncio
 import os
 import torch
+import time

 from grpc import aio
 from loguru import logger

@@ -76,6 +77,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         )

     async def Prefill(self, request, context):
+        start = time.time_ns()
         if (
             self.model.batch_type == IdeficsCausalLMBatch
         ):  # Hack, i would rather use kwargs in the `from_pb` call

@@ -91,15 +93,19 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
                 request.batch, self.model.tokenizer, self.model.dtype, self.model.device
             )

-        generations, next_batch = self.model.generate_token(batch)
+        generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.PrefillResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
+            forward_ns=timings[0],
+            decode_ns=timings[1],
+            total_ns=time.time_ns() - start,
         )

     async def Decode(self, request, context):
+        start = time.time_ns()
         if len(request.batches) == 0:
             raise ValueError("Must provide at least one batch")

@@ -114,16 +120,23 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             raise ValueError("All batches are empty")

         if len(batches) > 1:
+            start_concat = time.time_ns()
             batch = self.model.batch_type.concatenate(batches)
+            concat_ns = time.time_ns() - start_concat
         else:
             batch = batches[0]
+            concat_ns = None

-        generations, next_batch = self.model.generate_token(batch)
+        generations, next_batch, timings = self.model.generate_token(batch)
         self.cache.set(next_batch)

         return generate_pb2.DecodeResponse(
             generations=[generation.to_pb() for generation in generations],
             batch=next_batch.to_pb() if next_batch else None,
+            concat_ns=concat_ns,
+            forward_ns=timings[0],
+            decode_ns=timings[1],
+            total_ns=time.time_ns() - start,
         )
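concat_ns is deliberately optional: it is only measured when a decode step actually has to concatenate several cached batches; otherwise None is sent and the router skips that histogram. A small sketch of that branch, where concatenate is just a placeholder for the real Batch.concatenate:

import time
from typing import Optional

def concatenate(batches):
    return [item for b in batches for item in b]  # placeholder for Batch.concatenate

def prepare_decode_batch(batches):
    concat_ns: Optional[int] = None
    if len(batches) > 1:
        start_concat = time.time_ns()
        batch = concatenate(batches)
        concat_ns = time.time_ns() - start_concat
    else:
        batch = batches[0]
    return batch, concat_ns

print(prepare_decode_batch([[1, 2]]))        # ([1, 2], None) -- nothing to concatenate
print(prepare_decode_batch([[1, 2], [3]]))   # ([1, 2, 3], <some nanoseconds>)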
server/text_generation_server/utils/tokens.py  (view file @ 50b495f3)

@@ -92,7 +92,7 @@ class NextTokenChooser:
 class StopSequenceCriteria:
     def __init__(self, stop_sequence: str):
         stop_sequence = re.escape(stop_sequence)
-        self.regex = re.compile(f".*{stop_sequence}$")
+        self.regex = re.compile(f"{stop_sequence}$")

     def __call__(self, output: str) -> bool:
         if self.regex.findall(output):
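The stop-sequence regex loses its leading .*: because the pattern is anchored with $, matching the escaped stop sequence at the end of the output is sufficient, and the wildcard prefix only added backtracking work on long outputs. A quick, self-contained check that the old and new patterns agree on whether a stop sequence was hit:

import re

stop_sequence = re.escape("</s>")
old_pattern = re.compile(f".*{stop_sequence}$")
new_pattern = re.compile(f"{stop_sequence}$")

for output in ["hello world</s>", "hello world", "</s>", "</s> trailing"]:
    # Both patterns detect the same outputs; only the amount of scanning differs.
    assert bool(old_pattern.findall(output)) == bool(new_pattern.findall(output))
    print(repr(output), "stops:", bool(new_pattern.findall(output)))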