OpenDAS / text-generation-inference · Commits · b49dbf2d

Unverified commit b49dbf2d, authored Mar 16, 2023 by OlivierDehaene, committed by GitHub on Mar 16, 2023

fix(server): use server tokenizer as gt (#128)

parent 8ad60b75

Showing 10 changed files with 46 additions and 54 deletions (+46 -54)
proto/generate.proto                                  +2 -4
router/src/infer.rs                                   +4 -2
router/src/queue.rs                                   +2 -3
router/src/validation.rs                              +0 -2
server/tests/models/test_bloom.py                     +3 -4
server/tests/models/test_causal_lm.py                 +3 -4
server/tests/models/test_santacoder.py                +0 -2
server/tests/models/test_seq2seq_lm.py                +0 -1
server/text_generation_server/models/causal_lm.py     +26 -26
server/text_generation_server/models/seq2seq_lm.py    +6 -6
proto/generate.proto
@@ -58,12 +58,10 @@ message Request {
   uint64 id = 1;
   /// The generation context
   string inputs = 2;
-  /// The number of tokens inside inputs
-  uint32 input_length = 3;
   /// Next Token Chooser Parameters
-  NextTokenChooserParameters parameters = 4;
+  NextTokenChooserParameters parameters = 3;
   /// Stopping Criteria Parameters
-  StoppingCriteriaParameters stopping_parameters = 5;
+  StoppingCriteriaParameters stopping_parameters = 4;
 }

 message Batch {
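The input_length field is removed from Request (and the later fields renumbered) because, with this commit, the Python server treats its own tokenizer as the ground truth for prompt lengths instead of trusting a value computed on the router side. A minimal sketch of the idea, assuming a Hugging Face tokenizer and a made-up list of prompts rather than the repository's actual batching code:

    from transformers import AutoTokenizer

    # Hypothetical checkpoint; any tokenizer with a pad token behaves the same way here.
    tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

    inputs = ["Test", "A slightly longer prompt"]
    tokenized = tokenizer(
        inputs, return_tensors="pt", padding=True, return_token_type_ids=False
    )

    # The true length of each request is the number of non-padding positions,
    # so the client no longer needs to ship an input_length over gRPC.
    input_lengths = tokenized["attention_mask"].sum(1)
    max_input_length = input_lengths.max().item()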
router/src/infer.rs
@@ -278,7 +278,8 @@ async fn batching_task(
                         // because a new batch is being computed
                         let entry_waiting_span =
                             info_span!(parent: &entry.span, "waiting", batch_size = new_batch_size);
-                        // Add relationship
+                        // Add relationships
                         span.follows_from(&entry_waiting_span);
+                        entry_waiting_span.follows_from(&span);
                         // Update entry
                         entry.temp_span = Some(entry_waiting_span);

@@ -305,7 +306,8 @@ async fn batching_task(
                     // Create a new span to link the batch back to this entry
                     let entry_batch_span =
                         info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-                    // Add relationship
+                    // Add relationships
                     next_batch_span.follows_from(&entry_batch_span);
+                    entry_batch_span.follows_from(&next_batch_span);
                     // Update entry
                     entry.temp_span = Some(entry_batch_span);
router/src/queue.rs
@@ -165,7 +165,8 @@ impl State {
             // Create a new span to link the batch back to this entry
             let entry_batch_span =
                 info_span!(parent: &entry.span, "infer", batch_size = next_batch_size);
-            // Add relationship
+            // Add relationships
             next_batch_span.follows_from(&entry_batch_span);
+            entry_batch_span.follows_from(&next_batch_span);
             // Update entry
             entry.temp_span = Some(entry_batch_span);

@@ -173,7 +174,6 @@ impl State {
             batch_requests.push(Request {
                 id,
                 inputs: entry.request.inputs.clone(),
-                input_length: entry.request.input_length,
                 parameters: Some(entry.request.parameters.clone()),
                 stopping_parameters: Some(entry.request.stopping_parameters.clone()),
             });

@@ -226,7 +226,6 @@ mod tests {
             Entry {
                 request: ValidGenerateRequest {
                     inputs: "".to_string(),
-                    input_length: 0,
                     parameters: NextTokenChooserParameters {
                         temperature: 0.0,
                         top_k: 0,
router/src/validation.rs
@@ -322,7 +322,6 @@ fn validate(
     Ok(ValidGenerateRequest {
         inputs,
-        input_length: input_length as u32,
         parameters,
         stopping_parameters,
     })

@@ -337,7 +336,6 @@ type ValidationRequest = (
 #[derive(Debug)]
 pub(crate) struct ValidGenerateRequest {
     pub inputs: String,
-    pub input_length: u32,
     pub parameters: NextTokenChooserParameters,
     pub stopping_parameters: StoppingCriteriaParameters,
 }
server/tests/models/test_bloom.py
@@ -24,7 +24,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )

@@ -77,7 +76,7 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     assert batch.size == default_pb_batch.size
     assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
-    assert batch.max_sequence_length == batch.input_lengths[0]
+    assert batch.max_input_length == batch.input_lengths[0]


 def test_batch_concatenate_no_prefill(default_bloom_batch):

@@ -110,7 +109,7 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert next_batch.input_ids[0, 0] == 10264
     assert next_batch.input_lengths == [2]
-    assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+    assert next_batch.max_input_length == next_batch.input_lengths[0]
     assert next_batch.past_key_values is not None
     assert all(

@@ -222,7 +221,7 @@ def test_batch_concatenate(
     assert torch.all(next_batch.input_ids == 10264)
     assert next_batch.input_lengths == [3, 2, 2]
-    assert next_batch.max_sequence_length == 3
+    assert next_batch.max_input_length == 3
     assert next_batch.requests[0] == next_batch_0.requests[0]
     assert next_batch.requests[1:] == next_batch_1.requests
server/tests/models/test_causal_lm.py
@@ -25,7 +25,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )

@@ -74,7 +73,7 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.size == default_pb_batch.size
     assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size
-    assert batch.max_sequence_length == batch.input_lengths[0]
+    assert batch.max_input_length == batch.input_lengths[0]


 def test_batch_concatenate_no_prefill(default_causal_lm_batch):

@@ -107,7 +106,7 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert next_batch.input_ids[0, 0] == 13
     assert next_batch.input_lengths == [2]
-    assert next_batch.max_sequence_length == next_batch.input_lengths[0]
+    assert next_batch.max_input_length == next_batch.input_lengths[0]
     assert next_batch.past_key_values is not None
     assert all(

@@ -220,7 +219,7 @@ def test_batch_concatenate(
     assert torch.all(next_batch.input_ids[1:] == 13)
     assert next_batch.input_lengths == [3, 2, 2]
-    assert next_batch.max_sequence_length == 3
+    assert next_batch.max_input_length == 3
     assert next_batch.requests[0] == next_batch_0.requests[0]
     assert next_batch.requests[1:] == next_batch_1.requests
server/tests/models/test_santacoder.py
@@ -15,7 +15,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="def",
-        input_length=1,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )

@@ -31,7 +30,6 @@ def default_fim_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="<fim-prefix>def<fim-suffix>world<fim-middle>",
-        input_length=5,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
server/tests/models/test_seq2seq_lm.py
@@ -28,7 +28,6 @@ def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
     return generate_pb2.Request(
         id=0,
         inputs="Test",
-        input_length=2,
         parameters=default_pb_parameters,
         stopping_parameters=default_pb_stop_parameters,
     )
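Across all four test modules the Request fixtures drop the hard-coded input_length for the same reason: the server recomputes it. A sketch of the updated fixture shape, assuming the repository's generated generate_pb2 stubs and the existing default_pb_parameters / default_pb_stop_parameters fixtures (so it only runs inside the server test suite):

    import pytest
    from text_generation_server.pb import generate_pb2

    @pytest.fixture
    def default_pb_request(default_pb_parameters, default_pb_stop_parameters):
        # No input_length: the server derives it from its own tokenizer.
        return generate_pb2.Request(
            id=0,
            inputs="Test",
            parameters=default_pb_parameters,
            stopping_parameters=default_pb_stop_parameters,
        )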
server/text_generation_server/models/causal_lm.py
@@ -41,7 +41,7 @@ class CausalLMBatch(Batch):
     # Metadata used for padding
     size: int
-    max_sequence_length: int
+    max_input_length: int
     padding_right_offset: int

     # Past metadata

@@ -67,17 +67,14 @@ class CausalLMBatch(Batch):
-        input_lengths = []

         # Parse batch
-        max_sequence_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
-            input_lengths.append(r.input_length)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
             stopping_criterias.append(stopping_criteria)
-            max_sequence_length = max(max_sequence_length, r.input_length)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )

@@ -89,13 +86,16 @@ class CausalLMBatch(Batch):
             return_token_type_ids=False,
         ).to(device)

+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         input_ids = tokenized_inputs["input_ids"]
         # Allocate maximum attention_mask
         attention_mask = input_ids.new_zeros(
-            (pb.size, max_sequence_length + padding_right_offset)
+            (pb.size, max_input_length + padding_right_offset)
         )
         # Copy tokenizer attention_mask into fully allocated attention_mask
-        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+        attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]

         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)

@@ -109,11 +109,11 @@ class CausalLMBatch(Batch):
             position_ids=position_ids,
             past_key_values=None,
             all_input_ids=all_input_ids,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths.tolist(),
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length.item(),
             padding_right_offset=padding_right_offset,
         )
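CausalLMBatch.from_pb now lets the tokenizer pad the batch, then over-allocates the attention mask by padding_right_offset columns so future generated tokens can be appended without reallocating. A self-contained sketch of that allocation trick with made-up lengths (not the repository's exact objects):

    import torch

    # Hypothetical batch: prompts of 2 and 4 tokens, at most 3 new tokens each.
    tokenizer_attention_mask = torch.tensor(
        [[0, 0, 1, 1],  # shorter prompt, left-padded as for a causal LM
         [1, 1, 1, 1]]
    )
    padding_right_offset = 3  # largest max_new_tokens across the batch

    input_lengths = tokenizer_attention_mask.sum(1)  # tensor([2, 4])
    max_input_length = int(input_lengths.max())      # 4

    # Allocate room for the prompt plus every token that may be generated later ...
    attention_mask = tokenizer_attention_mask.new_zeros(
        (2, max_input_length + padding_right_offset)
    )
    # ... and copy the tokenizer's mask into the left-hand part.
    attention_mask[:, :max_input_length] = tokenizer_attention_mask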
@@ -122,11 +122,11 @@ class CausalLMBatch(Batch):
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
         # Used for padding
         total_batch_size = 0
-        max_sequence_length = 0
+        max_input_length = 0
         padding_right_offset = 0
         for batch in batches:
             total_batch_size += batch.size
-            max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
+            max_input_length = max(max_input_length, batch.max_input_length)
             padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

         # Batch attributes

@@ -170,15 +170,15 @@ class CausalLMBatch(Batch):
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_sequence_length + padding_right_offset),
+                    (total_batch_size, max_input_length + padding_right_offset),
                 )

             # We need to slice the attention mask to remove padding from previous steps
             # and to remove unused allocated space
-            left_offset = max_sequence_length - batch.max_sequence_length
+            left_offset = max_input_length - batch.max_input_length
             batch_left_offset = (
-                batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
+                batch.attention_mask.shape[1] - batch.max_input_length - batch.padding_right_offset
             )
             attention_mask[

@@ -209,7 +209,7 @@ class CausalLMBatch(Batch):
                 padded_past_values_shape = (
                     total_batch_size,
                     num_heads,
-                    max_sequence_length - 1,
+                    max_input_length - 1,
                     head_dim,
                 )

@@ -221,7 +221,7 @@ class CausalLMBatch(Batch):
                         total_batch_size,
                         num_heads,
                         head_dim,
-                        max_sequence_length - 1,
+                        max_input_length - 1,
                     )

                 # This will run only once per layer

@@ -235,20 +235,20 @@ class CausalLMBatch(Batch):
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
-                        -(batch.max_sequence_length - 1):,
+                        -(batch.max_input_length - 1):,
                         :,
-                    ] = past_keys[:, :, -(batch.max_sequence_length - 1):, :]
+                    ] = past_keys[:, :, -(batch.max_input_length - 1):, :]
                 else:
                     past_key_values[j][0][
                         start_index:end_index,
                         :,
                         :,
-                        -(batch.max_sequence_length - 1):,
-                    ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
+                        -(batch.max_input_length - 1):,
+                    ] = past_keys[:, :, :, -(batch.max_input_length - 1):]

                 past_key_values[j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
+                    start_index:end_index, :, -(batch.max_input_length - 1):, :
+                ] = past_values[:, :, -(batch.max_input_length - 1):, :]

                 start_index += batch.size

@@ -264,7 +264,7 @@ class CausalLMBatch(Batch):
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
-            max_sequence_length=max_sequence_length,
+            max_input_length=max_input_length,
             padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
         )

@@ -352,7 +352,7 @@ class CausalLM(Model):
         # Metadata
         next_batch_size = 0
-        next_batch_max_sequence_length = 0
+        next_batch_max_input_length = 0

         # Results
         generations: List[Generation] = []

@@ -420,8 +420,8 @@ class CausalLM(Model):
                 next_batch_all_input_ids.append(all_input_ids)
                 next_batch_size += 1
                 next_batch_input_lengths.append(new_input_length)
-                next_batch_max_sequence_length = max(
-                    next_batch_max_sequence_length, new_input_length
+                next_batch_max_input_length = max(
+                    next_batch_max_input_length, new_input_length
                 )

         # Prefill

@@ -506,7 +506,7 @@ class CausalLM(Model):
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
-            max_sequence_length=next_batch_max_sequence_length,
+            max_input_length=next_batch_max_input_length,
             padding_right_offset=batch.padding_right_offset - 1,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
server/text_generation_server/models/seq2seq_lm.py
@@ -68,17 +68,14 @@ class Seq2SeqLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        input_lengths = []

         decoder_input_ids = []
         decoder_input_lengths = []

         # Parse batch
-        max_input_length = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
-            input_lengths.append(r.input_length)
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)

@@ -87,7 +84,6 @@ class Seq2SeqLMBatch(Batch):
                 r.stopping_parameters, tokenizer
             )
             stopping_criterias.append(stopping_criteria)
-            max_input_length = max(max_input_length, r.input_length)
             padding_right_offset = max(
                 padding_right_offset, stopping_criteria.max_new_tokens
             )

@@ -99,6 +95,10 @@ class Seq2SeqLMBatch(Batch):
             padding=True,
             return_token_type_ids=False,
         ).to(device)

+        input_lengths = tokenized_inputs["attention_mask"].sum(1)
+        max_input_length = input_lengths.max()
+
         # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids, device=device).unsqueeze(-1)

@@ -111,12 +111,12 @@ class Seq2SeqLMBatch(Batch):
             decoder_attention_mask=None,
             encoder_last_hidden_state=None,
             past_key_values=None,
-            input_lengths=input_lengths,
+            input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
-            max_input_length=max(input_lengths),
+            max_input_length=max_input_length.item(),
             max_decoder_input_length=1,
             padding_right_offset=padding_right_offset,
         )
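Seq2SeqLMBatch.from_pb follows the same pattern on the encoder side, while each decoder sequence still starts as a single start token of length 1. A small sketch with a hypothetical BART checkpoint (the bos_token_id use mirrors the diff; the model choice is an assumption, not the repository's code):

    import torch
    from transformers import AutoTokenizer

    # Hypothetical seq2seq checkpoint whose tokenizer defines bos and pad tokens.
    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

    inputs = ["Test", "summarize: a slightly longer input"]
    tokenized = tokenizer(
        inputs, return_tensors="pt", padding=True, return_token_type_ids=False
    )

    # Encoder lengths come from the server-side tokenizer, not from the request.
    input_lengths = tokenized["attention_mask"].sum(1)
    max_input_length = input_lengths.max()

    # Each decoder sequence initially holds one start token, shaped [batch_size, 1].
    decoder_input_ids = torch.tensor(
        [tokenizer.bos_token_id] * len(inputs)
    ).unsqueeze(-1)
    decoder_input_lengths = [1] * len(inputs)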