OpenDAS / text-generation-inference · Commit 718096f6 (unverified)
Authored Dec 12, 2022 by OlivierDehaene; committed by GitHub on Dec 12, 2022
Parent: 042180d8

feat: Support stop sequences (#7)

Showing 18 changed files with 254 additions and 107 deletions
README.md                                    +2   -0
proto/generate.proto                         +15  -2
router/client/src/lib.rs                     +3   -1
router/src/batcher.rs                        +2   -0
router/src/db.rs                             +15  -2
router/src/lib.rs                            +3   -0
router/src/server.rs                         +4   -5
router/src/validation.rs                     +10  -0
server/tests/conftest.py                     +5   -0
server/tests/models/test_bloom.py            +15  -7
server/tests/models/test_causal_lm.py        +25  -28
server/tests/models/test_seq2seq_lm.py       +15  -7
server/tests/test_utils.py                   +45  -0
server/text_generation/models/causal_lm.py   +13  -16
server/text_generation/models/galactica.py   +3   -16
server/text_generation/models/seq2seq_lm.py  +11  -14
server/text_generation/models/types.py       +5   -1
server/text_generation/utils.py              +63  -8
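For context, the new `stop` field surfaces in the router's HTTP API through `GenerateParameters` (see router/src/lib.rs below). A minimal client sketch follows; the host, port, and route are assumptions for illustration and are not part of this diff:

    # Hypothetical request against a running text-generation-inference router.
    # The URL "http://localhost:3000/generate" is assumed, not taken from this commit.
    import requests

    payload = {
        "inputs": "def hello():",
        "parameters": {
            "do_sample": False,
            "max_new_tokens": 60,
            "stop": ["\n\n"],  # new field; the router rejects more than 4 stop sequences
        },
    }
    resp = requests.post("http://localhost:3000/generate", json=payload)
    # The response now also carries the finish reason, e.g.
    # [{"generated_text": "...", "finish_reason": "stop_sequence"}] or "length".
    print(resp.json())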
README.md  (view file @ 718096f6)

@@ -15,6 +15,8 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - Quantization with [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB
+- Logits warpers (temperature scaling, topk ...)
+- Stop sequences

 ## Officially supported models
proto/generate.proto  (view file @ 718096f6)

@@ -28,12 +28,23 @@ message ClearCacheRequest {}
 message ClearCacheResponse {}

 message LogitsWarperParameters {
+  /// exponential scaling output probability distribution
   float temperature = 1;
+  /// restricting to the k highest probability elements
   uint32 top_k = 2;
+  /// restricting to top tokens summing to prob_cut_off <= prob_cut_off
   float top_p = 3;
+  /// apply sampling on the logits
   bool do_sample = 4;
 }

+message StoppingCriteriaParameters {
+  /// Maximum number of generated tokens
+  uint32 max_new_tokens = 1;
+  /// Optional stopping sequences
+  repeated string stop_sequences = 2;
+}
+
 message Request {
   /// Request ID
   uint64 id = 1;

@@ -43,8 +54,8 @@ message Request {
   uint32 input_length = 3;
   /// Logits Warper Parameters
   LogitsWarperParameters parameters = 4;
-  /// Stopping criteria
-  uint32 max_new_tokens = 5;
+  /// Stopping Criteria Parameters
+  StoppingCriteriaParameters stopping_parameters = 5;
 }

 message Batch {

@@ -63,6 +74,8 @@ message GeneratedText {
   string output = 2;
   /// Number of generated tokens
   uint32 tokens = 3;
+  /// Finish reason
+  string finish_reason = 4;
 }

 message GenerateRequest {
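A minimal sketch of how the new messages are built from Python, mirroring the updated test fixtures in this commit (the concrete parameter values are illustrative):

    # Sketch only: constructs the new protobuf messages defined above.
    from text_generation.pb import generate_pb2

    stopping = generate_pb2.StoppingCriteriaParameters(
        stop_sequences=["\n\n"],  # optional, repeated string
        max_new_tokens=10,        # moved here from the Request message
    )

    request = generate_pb2.Request(
        id=0,
        inputs="Test",
        input_length=1,
        parameters=generate_pb2.LogitsWarperParameters(
            temperature=1.0, top_k=0, top_p=1.0, do_sample=False
        ),
        stopping_parameters=stopping,  # replaces the old `uint32 max_new_tokens = 5`
    )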
router/client/src/lib.rs  (view file @ 718096f6)

@@ -6,7 +6,9 @@ mod pb;
 mod sharded_client;

 pub use client::Client;
-pub use pb::generate::v1::{Batch, GeneratedText, LogitsWarperParameters, Request};
+pub use pb::generate::v1::{
+    Batch, GeneratedText, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+};
 pub use sharded_client::ShardedClient;
 use thiserror::Error;
 use tonic::transport;
router/src/batcher.rs  (view file @ 718096f6)

@@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
         let response = InferResponse {
             output: output.output,
             tokens: output.tokens,
+            finish_reason: output.finish_reason,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
             end: Instant::now(),

@@ -203,6 +204,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
 pub(crate) struct InferResponse {
     pub(crate) output: String,
     pub(crate) tokens: u32,
+    pub(crate) finish_reason: String,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,
router/src/db.rs  (view file @ 718096f6)

@@ -4,7 +4,9 @@ use crate::{GenerateParameters, GenerateRequest};
 use parking_lot::Mutex;
 use std::collections::BTreeMap;
 use std::sync::Arc;
-use text_generation_client::{Batch, ClientError, LogitsWarperParameters, Request};
+use text_generation_client::{
+    Batch, ClientError, LogitsWarperParameters, Request, StoppingCriteriaParameters,
+};
 use tokio::sync::oneshot::Sender;
 use tokio::time::Instant;

@@ -72,7 +74,9 @@ impl State {
                     parameters: Some(LogitsWarperParameters::from(
                         entry.request.parameters.clone(),
                     )),
-                    max_new_tokens: entry.request.parameters.max_new_tokens,
+                    stopping_parameters: Some(StoppingCriteriaParameters::from(
+                        entry.request.parameters.clone(),
+                    )),
                 });
                 ids.push(*id);

@@ -168,3 +172,12 @@ impl From<GenerateParameters> for LogitsWarperParameters {
         }
     }
 }
+
+impl From<GenerateParameters> for StoppingCriteriaParameters {
+    fn from(parameters: GenerateParameters) -> Self {
+        Self {
+            stop_sequences: parameters.stop,
+            max_new_tokens: parameters.max_new_tokens,
+        }
+    }
+}
router/src/lib.rs  (view file @ 718096f6)

@@ -21,6 +21,7 @@ pub(crate) struct GenerateParameters {
     pub do_sample: bool,
     #[serde(default = "default_max_new_tokens")]
     pub max_new_tokens: u32,
+    pub stop: Vec<String>,
 }

 fn default_temperature() -> f32 {

@@ -50,6 +51,7 @@ fn default_parameters() -> GenerateParameters {
         top_p: default_top_p(),
         do_sample: default_do_sample(),
         max_new_tokens: default_max_new_tokens(),
+        stop: vec![],
     }
 }

@@ -63,6 +65,7 @@ pub(crate) struct GenerateRequest {
 #[derive(Serialize)]
 pub(crate) struct GeneratedText {
     pub generated_text: String,
+    pub finish_reason: String,
 }

 #[derive(Serialize)]
router/src/server.rs  (view file @ 718096f6)

@@ -53,6 +53,7 @@ async fn health(state: Extension<ServerState>) -> Result<(), (StatusCode, Json<E
                 top_p: 1.0,
                 do_sample: false,
                 max_new_tokens: 1,
+                stop: vec![],
             },
         },
     )

@@ -88,11 +89,8 @@ async fn generate(
     })?;

     // Validate request
-    let (input_length, validated_request) = state
-        .validation
-        .validate(req.0)
-        .await
-        .map_err(|err| {
+    let (input_length, validated_request) =
+        state.validation.validate(req.0).await.map_err(|err| {
             tracing::error!("{}", err.to_string());
             err
         })?;

@@ -148,6 +146,7 @@ async fn generate(
     // Send response
     let response = vec![GeneratedText {
         generated_text: response.output,
+        finish_reason: response.finish_reason,
     }];
     Ok((headers, Json(response)))
 }
router/src/validation.rs  (view file @ 718096f6)

@@ -121,6 +121,14 @@ fn validation_worker(
                     .unwrap_or(());
                 continue;
             }
+            if request.parameters.stop.len() > 4 {
+                response_tx
+                    .send(Err(ValidationError::StopSequence(
+                        request.parameters.stop.len(),
+                    )))
+                    .unwrap_or(());
+                continue;
+            }

             // Get the number of tokens in the input
             match tokenizer.encode(request.inputs.clone(), false) {

@@ -163,6 +171,8 @@ pub enum ValidationError {
     MaxNewTokens,
     #[error("inputs must have less than {1} tokens. Given: {0}")]
     InputLength(usize, usize),
+    #[error("stop supports up to 4 stop sequences. Given: {0}")]
+    StopSequence(usize),
     #[error("tokenizer error {0}")]
     Tokenizer(String),
 }
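The same constraint expressed as a small Python sketch; the router enforces this in Rust, and the helper name below is only an illustration of the rule, not part of the codebase:

    from typing import List

    MAX_STOP_SEQUENCES = 4  # limit enforced by the router's validation worker

    def validate_stop(stop: List[str]) -> None:
        # Mirrors ValidationError::StopSequence: "stop supports up to 4 stop sequences. Given: {0}"
        if len(stop) > MAX_STOP_SEQUENCES:
            raise ValueError(
                f"stop supports up to {MAX_STOP_SEQUENCES} stop sequences. Given: {len(stop)}"
            )

    validate_stop(["\n\n", "###"])             # accepted
    # validate_stop(["a", "b", "c", "d", "e"]) # would be rejected with the error above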
server/tests/conftest.py  (view file @ 718096f6)

@@ -15,6 +15,11 @@ def default_pb_parameters():
     )


+@pytest.fixture
+def default_pb_stop_parameters():
+    return generate_pb2.StoppingCriteriaParameters(stop_sequences=[], max_new_tokens=10)
+
+
 @pytest.fixture(scope="session")
 def bloom_560m_tokenizer():
     return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")
server/tests/models/test_bloom.py  (view file @ 718096f6)

@@ -9,13 +9,13 @@ from text_generation.models.bloom import BloomCausalLMBatch, BLOOM
 @pytest.fixture
-def default_pb_request(default_pb_parameters):
+def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
         input_length=1,
         parameters=default_pb_parameters,
-        max_new_tokens=10,
+        stopping_parameters=default_pb_stop_parameters,
     )

@@ -36,7 +36,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
     req_0 = copy(default_pb_request)
     req_1 = default_pb_request
     req_1.id = 1
-    req_1.max_new_tokens = 5
+    req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return BloomCausalLMBatch.from_pb(

@@ -56,7 +56,6 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     assert batch.requests == default_pb_batch.requests

     assert len(batch.input_ids) == default_pb_batch.size
-    assert len(batch.input_ids[0]) == 8
     assert batch.input_ids[0][-1] == 10264
     assert torch.all(batch.input_ids[0][:-1] == 3)

@@ -85,6 +84,7 @@ def test_causal_lm_batch_type(default_bloom):
 def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
+    sequence_length = len(default_bloom_batch.all_input_ids[0])
     generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)

     assert generated_texts == []

@@ -92,7 +92,11 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert not next_batch.keys_head_dim_last

     assert len(next_batch.all_input_ids) == next_batch.size
-    assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
+    assert (
+        len(next_batch.all_input_ids[0])
+        == len(next_batch.attention_mask[0])
+        == sequence_length + 1
+    )
     assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
     assert torch.all(next_batch.all_input_ids[0][:-2] == 3)

@@ -106,8 +110,12 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert next_batch.max_sequence_length == next_batch.input_lengths[0]

     assert next_batch.past_key_values is not None
-    assert all([p[0].shape == (16, 64, 8) for p in next_batch.past_key_values])
-    assert all([p[1].shape == (16, 8, 64) for p in next_batch.past_key_values])
+    assert all(
+        [p[0].shape == (16, 64, sequence_length) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[1].shape == (16, sequence_length, 64) for p in next_batch.past_key_values]
+    )


 def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
server/tests/models/test_causal_lm.py  (view file @ 718096f6)

@@ -8,13 +8,13 @@ from text_generation.models.causal_lm import CausalLM, CausalLMBatch
 @pytest.fixture
-def default_pb_request(default_pb_parameters):
+def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
         input_length=1,
         parameters=default_pb_parameters,
-        max_new_tokens=10,
+        stopping_parameters=default_pb_stop_parameters,
     )

@@ -33,7 +33,7 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
     req_0 = copy(default_pb_request)
     req_1 = default_pb_request
     req_1.id = 1
-    req_1.max_new_tokens = 5
+    req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))

@@ -51,7 +51,6 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.requests == default_pb_batch.requests

     assert len(batch.input_ids) == default_pb_batch.size
-    assert len(batch.input_ids[0]) == 8
     assert batch.input_ids[0][-1] == 14402
     assert torch.all(batch.input_ids[0][:-1] == 50256)

@@ -80,6 +79,7 @@ def test_causal_lm_batch_type(default_causal_lm):
 def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
+    sequence_length = len(default_causal_lm_batch.all_input_ids[0])
     generated_texts, next_batch = default_causal_lm.generate_token(
         default_causal_lm_batch
     )

@@ -88,8 +88,12 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert isinstance(next_batch, CausalLMBatch)

     assert len(next_batch.all_input_ids) == next_batch.size
-    assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
-    assert next_batch.all_input_ids[0][-1] == 6208
+    assert (
+        len(next_batch.all_input_ids[0])
+        == len(next_batch.attention_mask[0])
+        == sequence_length + 1
+    )
+    assert next_batch.all_input_ids[0][-1] == 13
     assert next_batch.all_input_ids[0][-2] == 14402
     assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)

@@ -97,14 +101,18 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert torch.all(next_batch.attention_mask[0][:-2] == 0)

     assert next_batch.input_ids.shape == (next_batch.size, 1)
-    assert next_batch.input_ids[0, 0] == 6208
+    assert next_batch.input_ids[0, 0] == 13

     assert next_batch.input_lengths == [2]
     assert next_batch.max_sequence_length == next_batch.input_lengths[0]

     assert next_batch.past_key_values is not None
-    assert all([p[0].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
-    assert all([p[1].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
+    assert all(
+        [p[0].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
+    )
+    assert all(
+        [p[1].shape == (1, 12, sequence_length, 64) for p in next_batch.past_key_values]
+    )


 def test_causal_lm_generate_token_completion(

@@ -119,10 +127,7 @@ def test_causal_lm_generate_token_completion(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output
-        == "Test Test Test Test Test Test Test Test Test Test Test"
-    )
+    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
     assert generated_texts[0].request == default_causal_lm_batch.requests[0]
     assert (
         generated_texts[0].tokens

@@ -145,7 +150,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test Test Test Test Test Test"
+    assert generated_texts[0].output == "Test.java:784)"
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
     )

@@ -166,10 +171,7 @@ def test_causal_lm_generate_token_completion_multi(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output
-        == "Test Test Test Test Test Test Test Test Test Test Test"
-    )
+    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
     )

@@ -200,7 +202,8 @@ def test_batch_concatenate(
     assert torch.all(next_batch.attention_mask[1:, :-2] == 0)

     assert next_batch.batch_id == 0
-    assert torch.all(next_batch.input_ids == 6208)
+    assert next_batch.input_ids[0, 0] == 12355
+    assert torch.all(next_batch.input_ids[1:] == 13)

     assert next_batch.input_lengths == [3, 2, 2]
     assert next_batch.max_sequence_length == 3

@@ -239,7 +242,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert generated_texts[0].output == "Test Test Test Test Test Test"
+    assert generated_texts[0].output == "Test.java:784)"
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[1]
     )

@@ -260,10 +263,7 @@ def test_batch_concatenate(
     assert next_batch is not None

     assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output
-        == "Test Test Test Test Test Test Test Test Test Test Test"
-    )
+    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
     assert generated_texts[0].request == default_causal_lm_batch.requests[0]
     assert (
         generated_texts[0].tokens

@@ -283,10 +283,7 @@ def test_batch_concatenate(
     assert next_batch is None

     assert len(generated_texts) == 1
-    assert (
-        generated_texts[0].output
-        == "Test Test Test Test Test Test Test Test Test Test Test"
-    )
+    assert generated_texts[0].output == "Test.java:784) at net.minecraft."
     assert (
         generated_texts[0].request == default_multi_requests_causal_lm_batch.requests[0]
     )
server/tests/models/test_seq2seq_lm.py  (view file @ 718096f6)

@@ -8,13 +8,13 @@ from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch
 @pytest.fixture
-def default_pb_request(default_pb_parameters):
+def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
         input_length=2,
         parameters=default_pb_parameters,
-        max_new_tokens=10,
+        stopping_parameters=default_pb_stop_parameters,
     )

@@ -35,7 +35,7 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
     req_0 = copy(default_pb_request)
     req_1 = default_pb_request
     req_1.id = 1
-    req_1.max_new_tokens = 5
+    req_1.stopping_parameters.max_new_tokens = 5

     batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
     return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))

@@ -48,11 +48,12 @@ def default_seq2seq_lm():
 def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
     batch = default_seq2seq_lm_batch
+    sequence_length = len(default_seq2seq_lm_batch.input_ids[0])

     assert batch.batch_id == default_pb_batch.id
     assert batch.requests == default_pb_batch.requests

-    assert batch.input_ids.shape == (default_pb_batch.size, 8)
+    assert batch.input_ids.shape == (default_pb_batch.size, sequence_length)
     assert batch.input_ids[0][-2] == 4268
     assert batch.input_ids[0][-1] == 1
     assert torch.all(batch.input_ids[0][:-2] == 0)

@@ -86,6 +87,7 @@ def test_seq2seq_lm_batch_type(default_seq2seq_lm):
 def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
+    sequence_length = len(default_seq2seq_lm_batch.input_ids[0])
     generated_texts, next_batch = default_seq2seq_lm.generate_token(
         default_seq2seq_lm_batch
     )

@@ -108,7 +110,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     assert next_batch.decoder_input_ids[0, 0] == 0
     assert next_batch.decoder_input_ids[0, 1] == 259
     assert next_batch.decoder_attention_mask is None
-    assert next_batch.encoder_last_hidden_state.shape == (1, 8, 512)
+    assert next_batch.encoder_last_hidden_state.shape == (1, sequence_length, 512)

     assert next_batch.decoder_input_lengths == [2]
     assert next_batch.max_decoder_input_length == 2

@@ -121,10 +123,16 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
         [p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
     )
     assert all(
-        [p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
+        [
+            p[2].shape == (next_batch.size, 6, sequence_length, 64)
+            for p in next_batch.past_key_values
+        ]
     )
     assert all(
-        [p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
+        [
+            p[3].shape == (next_batch.size, 6, sequence_length, 64)
+            for p in next_batch.past_key_values
+        ]
     )
server/tests/test_utils.py  (view file @ 718096f6)

@@ -4,10 +4,55 @@ from text_generation.utils import (
     weight_hub_files,
     download_weights,
     weight_files,
+    StopSequenceCriteria,
+    StoppingCriteria,
     LocalEntryNotFoundError,
 )


+def test_stop_sequence_criteria():
+    criteria = StopSequenceCriteria([1, 2, 3])
+
+    assert not criteria(1)
+    assert criteria.current_token_idx == 1
+    assert not criteria(2)
+    assert criteria.current_token_idx == 2
+    assert criteria(3)
+    assert criteria.current_token_idx == 3
+
+
+def test_stop_sequence_criteria_reset():
+    criteria = StopSequenceCriteria([1, 2, 3])
+
+    assert not criteria(1)
+    assert criteria.current_token_idx == 1
+    assert not criteria(2)
+    assert criteria.current_token_idx == 2
+    assert not criteria(4)
+    assert criteria.current_token_idx == 0
+
+
+def test_stop_sequence_criteria_empty():
+    with pytest.raises(ValueError):
+        StopSequenceCriteria([])
+
+
+def test_stopping_criteria():
+    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
+    assert criteria([1]) == (False, None)
+    assert criteria([1, 2]) == (False, None)
+    assert criteria([1, 2, 3]) == (True, "stop_sequence")
+
+
+def test_stopping_criteria_max():
+    criteria = StoppingCriteria([StopSequenceCriteria([1, 2, 3])], max_new_tokens=5)
+    assert criteria([1]) == (False, None)
+    assert criteria([1, 1]) == (False, None)
+    assert criteria([1, 1, 1]) == (False, None)
+    assert criteria([1, 1, 1, 1]) == (False, None)
+    assert criteria([1, 1, 1, 1, 1]) == (True, "length")
+
+
 def test_weight_hub_files():
     filenames = weight_hub_files("bigscience/bloom-560m")
     assert filenames == ["model.safetensors"]
server/text_generation/models/causal_lm.py  (view file @ 718096f6)

@@ -57,23 +57,17 @@ class CausalLMBatch:
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser(
-                    temperature=r.parameters.temperature,
-                    top_k=r.parameters.top_k,
-                    top_p=r.parameters.top_p,
-                    do_sample=r.parameters.do_sample,
-                )
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
             stopping_criterias.append(
-                StoppingCriteria(
-                    eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
-                )
+                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )

         pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
+            inputs,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=pad_to_multiple_of,
         ).to(device)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

@@ -123,8 +117,8 @@ class CausalLMBatch:
             end_index = start_index + batch.size

             # We only concatenate batches that did at least one step
-            if batch.input_ids.shape[1] > 1:
-                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
+            if batch.past_key_values is None:
+                raise ValueError("only concatenate prefilled batches")

             # Create empty tensor
             # input_ids is always of shape [batch_size, 1]

@@ -331,14 +325,17 @@ class CausalLM(Model):
             all_tokens = torch.cat([all_tokens, next_token])

             # Evaluate stopping criteria
-            if stopping_criteria(all_tokens):
+            stop, reason = stopping_criteria(all_tokens)
+            if stop:
                 # Decode all tokens
                 output = self.tokenizer.decode(
                     all_tokens.squeeze(-1), skip_special_tokens=True
                 )

                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:
server/text_generation/models/galactica.py  (view file @ 718096f6)

@@ -94,18 +94,9 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
             input_lengths.append(r.input_length)
-            next_token_choosers.append(
-                NextTokenChooser(
-                    temperature=r.parameters.temperature,
-                    top_k=r.parameters.top_k,
-                    top_p=r.parameters.top_p,
-                    do_sample=r.parameters.do_sample,
-                )
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
             stopping_criterias.append(
-                StoppingCriteria(
-                    eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
-                )
+                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )

         tokenized_inputs = tokenizer(

@@ -207,11 +198,7 @@ class GalacticaSharded(Galactica):
                 continue

             module_name, param_name = name.rsplit(".", 1)
-            try:
-                module = model.get_submodule(module_name)
-            except Exception as e:
-                print(type(model), name, module_name, param_name)
-                raise e
+            module = model.get_submodule(module_name)

             current_tensor = parameters[name]
             slice_ = f.get_slice(name)
server/text_generation/models/seq2seq_lm.py  (view file @ 718096f6)

@@ -68,24 +68,18 @@ class Seq2SeqLMBatch:
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
-            next_token_choosers.append(
-                NextTokenChooser(
-                    temperature=r.parameters.temperature,
-                    top_k=r.parameters.top_k,
-                    top_p=r.parameters.top_p,
-                    do_sample=r.parameters.do_sample,
-                )
-            )
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters))
             stopping_criterias.append(
-                StoppingCriteria(
-                    eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
-                )
+                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             )

         # Tokenize batch
         pad_to_multiple_of = 8 if "gpu" in str(device) else None
         tokenized_inputs = tokenizer(
-            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=pad_to_multiple_of
+            inputs,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=pad_to_multiple_of,
         ).to(device)

         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

@@ -431,12 +425,15 @@ class Seq2SeqLM(Model):
             decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])

             # Evaluate stopping criteria
-            if stopping_criteria(decoder_tokens):
+            stop, reason = stopping_criteria(decoder_tokens)
+            if stop:
                 # Decode tokens
                 output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)

                 # Add to the list of finished generations with the original request
                 generated_texts.append(
-                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                    GeneratedText(
+                        request, output, stopping_criteria.current_tokens, reason
+                    )
                 )
             # add to the next batch
             else:
server/text_generation/models/types.py  (view file @ 718096f6)

@@ -32,8 +32,12 @@ class GeneratedText:
     request: generate_pb2.Request
     output: str
     tokens: int
+    reason: str

     def to_pb(self) -> generate_pb2.GeneratedText:
         return generate_pb2.GeneratedText(
-            request=self.request, output=self.output, tokens=self.tokens
+            request=self.request,
+            output=self.output,
+            tokens=self.tokens,
+            finish_reason=self.reason,
         )
server/text_generation/utils.py  (view file @ 718096f6)

 import concurrent
 import os
-import signal
 import torch
 import torch.distributed

@@ -11,6 +10,8 @@ from functools import partial
 from huggingface_hub import HfApi, hf_hub_download, try_to_load_from_cache
 from huggingface_hub.utils import LocalEntryNotFoundError
 from tqdm import tqdm
+from typing import List, Optional, Tuple
+from transformers import AutoTokenizer
 from transformers.generation.logits_process import (
     LogitsProcessorList,
     TemperatureLogitsWarper,

@@ -18,6 +19,8 @@ from transformers.generation.logits_process import (
     TopKLogitsWarper,
 )

+from text_generation.pb import generate_pb2
+

 class Sampling:
     def __call__(self, logits):

@@ -56,20 +59,72 @@ class NextTokenChooser:
         next_ids = self.choice(scores)
         return next_ids.unsqueeze(-1)

+    @classmethod
+    def from_pb(cls, pb: generate_pb2.LogitsWarperParameters) -> "NextTokenChooser":
+        return NextTokenChooser(
+            temperature=pb.temperature,
+            top_k=pb.top_k,
+            top_p=pb.top_p,
+            do_sample=pb.do_sample,
+        )
+
+
+class StopSequenceCriteria:
+    def __init__(self, tokens: List[int]):
+        if not tokens:
+            raise ValueError("tokens cannot be empty")
+
+        self.tokens = tokens
+        self.current_token_idx = 0
+
+    def __call__(self, last_token: int) -> bool:
+        if last_token == self.tokens[self.current_token_idx]:
+            # Increase idx to go to next token
+            self.current_token_idx += 1
+        else:
+            # Reset to first token of the stopping sequence
+            self.current_token_idx = 0
+
+        if self.current_token_idx == len(self.tokens):
+            # We matched the entire sequence without resetting
+            return True
+        return False
+

 class StoppingCriteria:
-    def __init__(self, eos_token_id, max_new_tokens=20):
-        self.eos_token_id = eos_token_id
+    def __init__(
+        self, stop_sequence_criterias: List[StopSequenceCriteria], max_new_tokens=20
+    ):
+        self.stop_sequence_criterias = stop_sequence_criterias
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0

-    def __call__(self, all_ids):
+    def __call__(self, all_ids) -> Tuple[bool, Optional[str]]:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
-            return True
-        if self.eos_token_id is not None and all_ids[-1] == self.eos_token_id:
-            return True
-        return False
+            return True, "length"
+
+        last_token = all_ids[-1]
+        for stop_sequence_criteria in self.stop_sequence_criterias:
+            if stop_sequence_criteria(last_token):
+                return True, "stop_sequence"
+
+        return False, None
+
+    @classmethod
+    def from_pb(
+        cls, pb: generate_pb2.StoppingCriteriaParameters, tokenizer: AutoTokenizer
+    ) -> "StoppingCriteria":
+        stop_sequence_criterias = []
+        for stop_sequence in pb.stop_sequences:
+            tokens = tokenizer(
+                stop_sequence, padding=False, return_attention_mask=False
+            ).input_ids
+            if tokens:
+                stop_sequence_criterias.append(StopSequenceCriteria(tokens))
+        stop_sequence_criterias.append(StopSequenceCriteria([tokenizer.eos_token_id]))
+
+        return StoppingCriteria(stop_sequence_criterias, pb.max_new_tokens)


 def initialize_torch_distributed():
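A minimal usage sketch of the classes added above, assuming a local install of the server package; the stop sequence is illustrative, and the tokenizer is the same one the test suite uses:

    from transformers import AutoTokenizer

    from text_generation.pb import generate_pb2
    from text_generation.utils import StoppingCriteria

    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    pb = generate_pb2.StoppingCriteriaParameters(stop_sequences=["\n\n"], max_new_tokens=20)

    # from_pb tokenizes each stop sequence and always appends the EOS token as a stop sequence.
    criteria = StoppingCriteria.from_pb(pb, tokenizer)

    # The model loop calls the criteria after every generated token with the running ids;
    # it returns (False, None) until a stop sequence completes ("stop_sequence") or the
    # token budget is exhausted ("length").
    stop, reason = criteria([10264])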