OpenDAS / text-generation-inference · Commits

Commit 427d7cc4, authored Nov 04, 2022 by OlivierDehaene
Parent: c5665f5c

feat(server): Support AutoModelForSeq2SeqLM

Showing 11 changed files with 892 additions and 376 deletions:
README.md                                    +9    -2
proto/generate.proto                         +0    -2
router/src/db.rs                             +0    -4
server/text_generation/cache.py              +6    -4
server/text_generation/models/__init__.py    +7    -3
server/text_generation/models/bloom.py       +17   -14
server/text_generation/models/causal_lm.py   +339  -3
server/text_generation/models/model.py       +8    -122
server/text_generation/models/seq2seq_lm.py  +488  -0
server/text_generation/models/types.py       +14   -219
server/text_generation/server.py             +4    -3
README.md  (+9, -2)

@@ -15,12 +15,20 @@ A Rust and gRPC server for text generation inference.
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB
 
-## Supported models
+## Officially supported models
 
 - BLOOM
 - BLOOMZ
 - BLOOM-560m
 
+Other models are supported on a best effort basis using:
+`AutoModelForCausalLM.from_pretrained(<model>, device_map="auto")`
+or
+`AutoModelForSeq2SeqLM.from_pretrained(<model>, device_map="auto")`
+
 ## Load Tests for BLOOM
 
 See `k6/load_test.js`

@@ -81,7 +89,6 @@ make router-dev
 ## TODO:
 
-- [ ] Support AutoModelForSeq2SeqLM
 - [ ] Add tests for the `server/model` logic
 - [ ] Backport custom CUDA kernels to Transformers
 - [ ] Install safetensors with pip
\ No newline at end of file
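As a quick illustration of the best-effort path added to the README above, here is a minimal sketch. It assumes `transformers`, `accelerate`, and `torch` are installed; `t5-small` is only an example checkpoint, not one the README names.

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "t5-small"  # example checkpoint, not an officially supported model

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    device_map="auto" if torch.cuda.is_available() else None,
).eval()

inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```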
proto/generate.proto  (+0, -2)

@@ -54,8 +54,6 @@ message Batch {
   repeated Request requests = 2;
   /// Batch size (==len(requests))
   uint32 size = 3;
-  /// Length of the longest sequence within the batch (used for padding)
-  uint32 max_sequence_length = 4;
 }
 
 message GeneratedText {
router/src/db.rs  (+0, -4)

@@ -142,14 +142,10 @@ impl Db {
         // Batch size
         let size = requests.len();
 
-        // Longest input length for all requests in batch size
-        // Used for padding inside the inference server
-        let max_sequence_length = requests.iter().map(|r| r.input_length).max().unwrap();
 
         let batch = Batch {
             id: state.next_batch_id,
             requests,
             size: size as u32,
-            max_sequence_length,
         };
 
         // Update next_batch_start_id to the last id in the batch + 1
         state.next_batch_start_id = ids.last().unwrap() + 1;
server/text_generation/cache.py  (+6, -4)

-from typing import Dict, Optional
+from typing import Dict, Optional, TypeVar
 
 from text_generation.models.types import Batch
 
+B = TypeVar("B", bound=Batch)
+
 
 class Cache:
     def __init__(self):
-        self.cache: Dict[int, Batch] = {}
+        self.cache: Dict[int, B] = {}
 
-    def pop(self, batch_id: int) -> Optional[Batch]:
+    def pop(self, batch_id: int) -> Optional[B]:
         return self.cache.pop(batch_id, None)
 
-    def set(self, entry: Batch):
+    def set(self, entry: B):
         if entry is not None:
             self.cache[entry.batch_id] = entry
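The `TypeVar` bound to `Batch` makes the cache annotations generic over whichever batch class the loaded model produces. A hedged usage sketch, assuming the `text_generation` package is importable; `FakeBatch` is a stand-in for `CausalLMBatch` or `Seq2SeqLMBatch` and only needs a `batch_id` attribute:

```python
from dataclasses import dataclass
from text_generation.cache import Cache


@dataclass
class FakeBatch:  # stand-in for CausalLMBatch / Seq2SeqLMBatch
    batch_id: int


cache = Cache()
cache.set(FakeBatch(batch_id=0))
assert cache.pop(0) is not None  # returns the cached batch
assert cache.pop(0) is None      # second pop: the entry is gone
```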
server/text_generation/models/__init__.py  (+7, -3)

 from text_generation.models.model import Model
-from text_generation.models.bloom import BLOOMSharded
 from text_generation.models.causal_lm import CausalLM
+from text_generation.models.bloom import BLOOMSharded
+from text_generation.models.seq2seq_lm import Seq2SeqLM
 
-__all__ = ["Model", "BLOOMSharded", "CausalLM"]
+__all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
 
 
 def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:

@@ -18,4 +19,7 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
             raise ValueError("sharded is not supported for AutoModel")
         if quantize:
             raise ValueError("quantize is not supported for AutoModel")
-        return CausalLM(model_name)
+        try:
+            return CausalLM(model_name)
+        except Exception as e:
+            return Seq2SeqLM(model_name)
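A hedged usage sketch of the fallback added above, assuming the `text_generation` server package is installed and the checkpoint can be downloaded; `t5-small` is only an example model name:

```python
from text_generation.models import Seq2SeqLM, get_model

# get_model first tries CausalLM; for an encoder-decoder checkpoint that
# constructor raises, and the new except branch falls back to Seq2SeqLM.
model = get_model("t5-small", sharded=False, quantize=False)
assert isinstance(model, Seq2SeqLM)
```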
server/text_generation/models/bloom.py  (+17, -14)

@@ -12,7 +12,7 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelRowLinear,
 )
 
-from text_generation.models import Model
+from text_generation.models import CausalLM
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,

@@ -29,7 +29,7 @@ except Exception as e:
 torch.manual_seed(0)
 
 
-class BLOOMSharded(Model):
+class BLOOMSharded(CausalLM):
     def __init__(self, model_name: str, quantize: bool = False):
         if not model_name.startswith("bigscience/bloom"):
             raise ValueError(f"Model {model_name} is not supported")

@@ -78,8 +78,11 @@ class BLOOMSharded(Model):
         )
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
-        super(BLOOMSharded, self).__init__(
-            tokenizer=tokenizer, num_heads=config.n_head // self.process_group.size(), device=device
-        )
+        super(CausalLM, self).__init__(
+            tokenizer=tokenizer,
+            num_heads=config.n_head // self.process_group.size(),
+            device=device,
+        )
 
     @staticmethod
     def load_weights(
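One detail worth spelling out: the new base-class call is `super(CausalLM, self).__init__(...)`, not `super().__init__(...)`. Starting the MRO lookup *after* `CausalLM` skips `CausalLM.__init__` (which would call `AutoModelForCausalLM.from_pretrained` and load the weights a second time) and goes straight to `Model.__init__`, while `BLOOMSharded` still inherits `CausalLM`'s `forward`, `batch_type`, and `generate_token`. A tiny self-contained sketch of that mechanism, with hypothetical class names:

```python
class Base:
    def __init__(self):
        print("Base.__init__")


class Middle(Base):
    def __init__(self):
        print("Middle.__init__ (expensive: loads weights)")
        super().__init__()


class Leaf(Middle):
    def __init__(self):
        # Skip Middle.__init__ in the MRO and run only Base.__init__,
        # the same trick BLOOMSharded uses to avoid CausalLM's weight loading.
        super(Middle, self).__init__()


Leaf()  # prints only "Base.__init__"
```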
server/text_generation/models/causal_lm.py  (+339, -3)

 import torch
 
+from dataclasses import dataclass
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from typing import Optional, Tuple, List
+from typing import Optional, Tuple, List, Dict, Type
 
 from text_generation.models import Model
+from text_generation.models.types import GeneratedText
+from text_generation.pb import generate_pb2
+from text_generation.utils import NextTokenChooser, StoppingCriteria

Added: the CausalLMBatch dataclass, which takes over the batching logic that previously lived on the generic Batch in types.py:

@dataclass
class CausalLMBatch:
    batch_id: int
    requests: List[generate_pb2.Request]
    all_input_lengths: List[int]
    input_ids: Dict[str, torch.Tensor]
    all_input_ids: List[torch.Tensor]
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    size: int
    max_sequence_length: int

    def to_pb(self):
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
    ) -> "CausalLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        all_input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            all_input_lengths.append(r.input_length)
            next_token_choosers.append(
                NextTokenChooser(
                    temperature=r.parameters.temperature,
                    top_k=r.parameters.top_k,
                    top_p=r.parameters.top_p,
                    do_sample=r.parameters.do_sample,
                )
            )
            stopping_criterias.append(
                StoppingCriteria(
                    eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
                )
            )

        input_ids = tokenizer(
            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
        ).to(device)
        all_input_ids = input_ids["input_ids"].unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_sequence_length=max(all_input_lengths),
        )

    @classmethod
    def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_sequence_length = max(batch.max_sequence_length for batch in batches)

        # Batch attributes
        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
        requests = []
        all_input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            all_input_lengths.extend(batch.all_input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.input_ids["input_ids"].shape[1] > 1:
                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")

            # Initialize tensors
            if i == 0:
                input_ids["input_ids"] = torch.empty(
                    (total_batch_size, 1),
                    dtype=batch.input_ids["input_ids"].dtype,
                    device=batch.input_ids["input_ids"].device,
                )
                input_ids["attention_mask"] = torch.zeros(
                    (total_batch_size, max_sequence_length),
                    dtype=batch.input_ids["attention_mask"].dtype,
                    device=batch.input_ids["attention_mask"].device,
                )

            # input_ids["input_ids"] is always of shape [batch_size, 1]
            # We do not need to pad it
            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]

            # We need to slice the attention mask to remove padding from previous steps
            input_ids["attention_mask"][
                start_index:end_index, -batch.max_sequence_length :
            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]

            for j, past in enumerate(batch.input_ids["past_key_values"]):
                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                # BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
                head_dim, padded_sequence_length = past[0].shape[-2:]
                num_heads = (
                    past[0]
                    .view(batch.size, -1, head_dim, padded_sequence_length)
                    .shape[1]
                )

                # This will run only once per layer
                if j == len(input_ids["past_key_values"]):
                    input_ids["past_key_values"].append([])

                # Decoder past
                for k, t in enumerate(past):
                    # Needed because BLOOM past shapes are not the same for keys and values
                    # Keys:   [batch_size * num_heads, head_dim, seq_length]
                    # Values: [batch_size * num_heads, seq_length, head_dim]
                    head_dim_last = False
                    if t.shape[-2] == head_dim:
                        t = t.view(
                            batch.size, num_heads, head_dim, padded_sequence_length
                        )
                        padded_t_shape = (
                            total_batch_size,
                            num_heads,
                            head_dim,
                            max_sequence_length - 1,
                        )
                    elif t.shape[-1] == head_dim:
                        head_dim_last = True
                        t = t.view(
                            batch.size, num_heads, padded_sequence_length, head_dim
                        )
                        padded_t_shape = (
                            total_batch_size,
                            num_heads,
                            max_sequence_length - 1,
                            head_dim,
                        )
                    else:
                        raise ValueError(f"shape {t.shape} is not valid")

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(input_ids["past_key_values"][j]):
                        input_ids["past_key_values"][j].append(
                            torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
                        )

                    # We slice the past keys and values to remove the padding from previous batches
                    if not head_dim_last:
                        input_ids["past_key_values"][j][k][
                            start_index:end_index,
                            :,
                            :,
                            -(batch.max_sequence_length - 1) :,
                        ] = t[:, :, :, -(batch.max_sequence_length - 1) :]
                    else:
                        input_ids["past_key_values"][j][k][
                            start_index:end_index,
                            :,
                            -(batch.max_sequence_length - 1) :,
                            :,
                        ] = t[:, :, -(batch.max_sequence_length - 1) :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
        )

Changes inside the existing CausalLM class:

@@ -23,7 +225,15 @@ class CausalLM(Model):
             device_map="auto" if torch.cuda.is_available() else None,
         ).eval()
 
-        super(CausalLM, self).__init__(tokenizer=tokenizer, num_heads=self.model.config.num_attention_heads, device=device)
+        super(CausalLM, self).__init__(
+            tokenizer=tokenizer,
+            num_heads=self.model.config.num_attention_heads,
+            device=device,
+        )
+
+    @property
+    def batch_type(self) -> Type[CausalLMBatch]:
+        return CausalLMBatch
 
     def forward(
         self, input_ids, attention_mask, past_key_values: Optional = None

@@ -36,3 +246,129 @@ class CausalLM(Model):
             use_cache=True,
         )
         return outputs.logits, outputs.past_key_values

Added: CausalLM.generate_token, the per-step decoding logic moved out of Model (see model.py below):

    def generate_token(
        self, batch: CausalLMBatch
    ) -> Tuple[List[GeneratedText], Optional[CausalLMBatch]]:
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
        )
        with context_manager():
            logits, past = self.forward(**batch.input_ids)

        # List of indices to cache
        next_batch_keep_indices = []
        # New input_ids for next forward
        next_batch_input_ids = []
        next_batch_all_input_ids = []
        next_all_input_lengths = []

        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Finished requests
        generated_texts: List[GeneratedText] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.all_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_tokens,
        ) in enumerate(iterator):
            # Select next token
            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])

            # Append next token to all tokens
            all_tokens = torch.cat([all_tokens, next_token])

            # Evaluate stopping criteria
            if stopping_criteria(all_tokens):
                # Decode all tokens
                output = self.tokenizer.decode(
                    all_tokens.squeeze(-1), skip_special_tokens=True
                )
                # Add to the list of finished generations with the original request
                generated_texts.append(
                    GeneratedText(request, output, stopping_criteria.current_tokens)
                )
            # add to the next batch
            else:
                next_batch_keep_indices.append(i)
                next_batch_input_ids.append(next_token)
                next_batch_all_input_ids.append(all_tokens)
                next_batch_size += 1
                new_input_length = input_length + 1
                next_all_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generated_texts, None

        # If we finished at least one generation
        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
        if generated_texts:
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
                next_batch_keep_indices
            ]
            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
            next_batch_input_ids["past_key_values"] = [
                [
                    t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
                    for t in layer
                ]
                for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
            next_batch_input_ids["past_key_values"] = past
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        next_batch_input_ids["attention_mask"] = torch.cat(
            [
                next_batch_input_ids["attention_mask"],
                torch.ones((next_batch_size, 1)).to(self.device),
            ],
            dim=1,
        )

        next_batch = CausalLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            all_input_lengths=next_all_input_lengths,
            input_ids=next_batch_input_ids,
            all_input_ids=next_batch_all_input_ids,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
        )
        return generated_texts, next_batch
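The `t.view(-1, self.num_heads, *t.shape[-2:])` call above normalizes BLOOM-style fused past tensors, laid out as `[batch_size * num_heads, ...]`, into `[batch_size, num_heads, ...]` so that whole requests can be selected by row index. A minimal sketch with made-up dimensions:

```python
import torch

batch_size, num_heads, head_dim, seq_len = 3, 4, 8, 5
keep = [0, 2]  # requests that are still generating

# BLOOM-style fused layout: batch and heads share the first dimension.
fused_past = torch.randn(batch_size * num_heads, head_dim, seq_len)

# Same normalization as in generate_token: split the fused dimension, then index rows.
per_request = fused_past.view(-1, num_heads, *fused_past.shape[-2:])[keep]
print(per_request.shape)  # torch.Size([2, 4, 8, 5])
```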
server/text_generation/models/model.py  (+8, -122)

 import torch
 
 from abc import ABC, abstractmethod
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, TypeVar, Type
 from tokenizers import Tokenizer
 
 from text_generation.models.types import Batch, GeneratedText
 
+B = TypeVar("B", bound=Batch)
+
 
 class Model(ABC):
     def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):

@@ -13,127 +15,11 @@ class Model(ABC):
         self.num_heads = num_heads
         self.device = device
 
+    @property
     @abstractmethod
-    def forward(
-        self, input_ids, attention_mask, past_key_values: Optional = None
-    ) -> Tuple[torch.Tensor, List[Tuple]]:
+    def batch_type(self) -> Type[B]:
         raise NotImplementedError
 
-    def generate_token(
-        self, batch: Batch
-    ) -> Tuple[List[GeneratedText], Optional[Batch]]:
-        ...
+    @abstractmethod
+    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
+        raise NotImplementedError

(The elided removed lines were the previous concrete generate_token implementation on the base class; it moves essentially verbatim to CausalLM.generate_token in causal_lm.py above, with the generic Batch replaced by CausalLMBatch.)
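With `batch_type` and `generate_token` now abstract, a new model only has to supply those two members plus its own batch class. A deliberately toy sketch of the contract, not a real model and not subclassing the actual `Model` base; the echo logic is invented purely for illustration:

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Type


@dataclass
class EchoBatch:  # hypothetical Batch-like container
    batch_id: int
    prompts: List[str]


class EchoModel:  # would subclass Model in the real package
    @property
    def batch_type(self) -> Type[EchoBatch]:
        return EchoBatch

    def generate_token(
        self, batch: EchoBatch
    ) -> Tuple[List[str], Optional[EchoBatch]]:
        # "Finish" every request in one step and return no next batch,
        # mirroring the (generated_texts, next_batch) contract.
        return [p + "!" for p in batch.prompts], None


model = EchoModel()
texts, next_batch = model.generate_token(EchoBatch(0, ["hello"]))
assert texts == ["hello!"] and next_batch is None
```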
server/text_generation/models/seq2seq_lm.py  (new file, mode 100644, +488)

import torch

from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import Optional, Tuple, List, Type

from text_generation.models import Model
from text_generation.models.types import GeneratedText
from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria


@dataclass
class Seq2SeqLMBatch:
    batch_id: int
    requests: List[generate_pb2.Request]

    input_ids: torch.Tensor
    attention_mask: torch.Tensor

    decoder_input_ids: torch.Tensor
    decoder_attention_mask: Optional[torch.Tensor]
    encoder_last_hidden_state: Optional[torch.Tensor]

    past_key_values: Optional[List[Tuple]]

    input_lengths: List[int]
    decoder_input_lengths: List[int]

    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]

    size: int
    max_input_length: int
    max_decoder_input_length: int

    def to_pb(self):
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
        )

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
    ) -> "Seq2SeqLMBatch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        input_lengths = []

        decoder_input_ids = []
        decoder_input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            input_lengths.append(r.input_length)
            decoder_input_ids.append(tokenizer.bos_token_id)
            decoder_input_lengths.append(1)
            next_token_choosers.append(
                NextTokenChooser(
                    temperature=r.parameters.temperature,
                    top_k=r.parameters.top_k,
                    top_p=r.parameters.top_p,
                    do_sample=r.parameters.do_sample,
                )
            )
            stopping_criterias.append(
                StoppingCriteria(
                    eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
                )
            )

        tokenized_inputs = tokenizer(
            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
        ).to(device)
        decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            input_ids=tokenized_inputs["input_ids"],
            attention_mask=tokenized_inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=None,
            encoder_last_hidden_state=None,
            past_key_values=None,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=len(pb.requests),
            max_input_length=max(input_lengths),
            max_decoder_input_length=1,
        )

    @classmethod
    def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_input_length = max(batch.max_input_length for batch in batches)
        max_decoder_input_length = max(
            batch.max_decoder_input_length for batch in batches
        )

        # Batch attributes
        requests = []
        input_lengths = []
        decoder_input_lengths = []
        next_token_choosers = []
        stopping_criterias = []

        # Batch tensors
        input_ids = None
        attention_mask = None
        decoder_input_ids = None
        decoder_attention_mask = None
        encoder_last_hidden_state = None
        past_key_values = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            input_lengths.extend(batch.input_lengths)
            decoder_input_lengths.extend(batch.decoder_input_lengths)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.encoder_last_hidden_state is None:
                raise ValueError("Batch encoder_last_hidden_state cannot be None")

            if input_ids is None:
                input_ids = torch.zeros(
                    (total_batch_size, max_input_length),
                    dtype=batch.input_ids.dtype,
                    device=batch.input_ids.device,
                )
            input_ids[
                start_index:end_index, -batch.max_input_length :
            ] = batch.input_ids[:, -batch.max_input_length :]

            if attention_mask is None:
                attention_mask = torch.zeros(
                    (total_batch_size, max_input_length),
                    dtype=batch.attention_mask.dtype,
                    device=batch.attention_mask.device,
                )
            attention_mask[
                start_index:end_index, -batch.max_input_length :
            ] = batch.attention_mask[:, -batch.max_input_length :]

            if decoder_input_ids is None:
                decoder_input_ids = torch.zeros(
                    (total_batch_size, max_decoder_input_length),
                    dtype=batch.decoder_input_ids.dtype,
                    device=batch.decoder_input_ids.device,
                )
            decoder_input_ids[
                start_index:end_index, -batch.max_decoder_input_length :
            ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]

            if decoder_attention_mask is None:
                decoder_attention_mask = torch.zeros(
                    (total_batch_size, max_decoder_input_length),
                    dtype=batch.attention_mask.dtype,
                    device=batch.attention_mask.device,
                )
            if batch.decoder_attention_mask is None:
                decoder_attention_mask[
                    start_index:end_index, -batch.max_decoder_input_length :
                ] = 1
            else:
                decoder_attention_mask[
                    start_index:end_index, -batch.max_decoder_input_length :
                ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]

            if encoder_last_hidden_state is None:
                encoder_last_hidden_state = torch.zeros(
                    (
                        total_batch_size,
                        max_input_length,
                        batch.encoder_last_hidden_state.shape[-1],
                    ),
                    dtype=batch.encoder_last_hidden_state.dtype,
                    device=batch.encoder_last_hidden_state.device,
                )
            encoder_last_hidden_state[
                start_index:end_index, -batch.max_decoder_input_length :, :
            ] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :]

            for j, past in enumerate(batch.past_key_values):
                _, num_heads, _, head_dim = past[0].shape

                # This will run only once per layer
                if j == len(past_key_values):
                    past_key_values.append([])

                # Decoder past
                for k, t in enumerate(past[:2]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        (max_decoder_input_length - 1),
                        head_dim,
                    )

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if k == len(past_key_values[j]):
                        past_key_values[j].append(
                            torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
                        )

                    # We slice the past keys and values to remove the padding from previous batches
                    past_key_values[j][k][
                        start_index:end_index,
                        :,
                        -(batch.max_decoder_input_length - 1) :,
                        :,
                    ] = t[:, :, -(batch.max_decoder_input_length - 1) :, :]

                # encoder past
                for k, t in enumerate(past[2:]):
                    padded_t_shape = (
                        total_batch_size,
                        num_heads,
                        max_input_length,
                        head_dim,
                    )

                    idx = k + 2

                    # Initialize tensors
                    # This will run only once per layer and per past tensor
                    if idx == len(past_key_values[j]):
                        past_key_values[j].append(
                            torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
                        )

                    past_key_values[j][idx][
                        start_index:end_index, :, -batch.max_input_length :, :
                    ] = t[:, :, -batch.max_input_length :, :]

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_last_hidden_state=encoder_last_hidden_state,
            past_key_values=past_key_values,
            input_lengths=input_lengths,
            decoder_input_lengths=decoder_input_lengths,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_input_length=max_input_length,
            max_decoder_input_length=max_decoder_input_length,
        )


class Seq2SeqLM(Model):
    def __init__(self, model_name: str):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
        ).eval()
        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        tokenizer.bos_token_id = self.model.config.decoder_start_token_id

        super(Seq2SeqLM, self).__init__(
            tokenizer=tokenizer,
            num_heads=self.model.config.num_attention_heads,
            device=device,
        )

    @property
    def batch_type(self) -> Type[Seq2SeqLMBatch]:
        return Seq2SeqLMBatch

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional,
        encoder_last_hidden_state: Optional,
        past_key_values: Optional = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)

        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=[encoder_last_hidden_state]
            if encoder_last_hidden_state is not None
            else None,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return (
            outputs.logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )

    def generate_token(
        self, batch: Seq2SeqLMBatch
    ) -> Tuple[List[GeneratedText], Optional[Seq2SeqLMBatch]]:
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
        )
        with context_manager():
            logits, encoder_last_hidden_state, past = self.forward(
                batch.input_ids,
                batch.attention_mask,
                batch.decoder_input_ids,
                batch.decoder_attention_mask,
                batch.encoder_last_hidden_state,
                batch.past_key_values,
            )

        # List of indices to cache
        next_batch_keep_indices = []

        # New input_ids for next forward
        next_batch_input_lengths = []
        next_batch_decoder_input_ids = []
        next_batch_decoder_input_lengths = []

        next_batch_size = 0
        next_batch_max_input_length = 0
        next_batch_max_decoder_input_length = 0

        # Finished requests
        generated_texts: List[GeneratedText] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.input_lengths,
            batch.decoder_input_lengths,
            logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.input_ids,
            batch.decoder_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            decoder_input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            input_tokens,
            decoder_tokens,
        ) in enumerate(iterator):
            all_tokens = torch.cat([input_tokens, decoder_tokens])
            # Select next token
            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])

            # Append next token to decoder tokens
            decoder_tokens = torch.cat([decoder_tokens, next_token.squeeze(1)])

            # Evaluate stopping criteria
            if stopping_criteria(decoder_tokens):
                # Decode all tokens
                output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
                # Add to the list of finished generations with the original request
                generated_texts.append(
                    GeneratedText(request, output, stopping_criteria.current_tokens)
                )
            # add to the next batch
            else:
                next_batch_keep_indices.append(i)
                next_batch_decoder_input_ids.append(decoder_tokens.unsqueeze(0))
                next_batch_size += 1
                new_decoder_input_length = decoder_input_length + 1
                next_batch_input_lengths.append(input_length)
                next_batch_decoder_input_lengths.append(new_decoder_input_length)
                next_batch_max_input_length = max(
                    next_batch_max_input_length, input_length
                )
                next_batch_max_decoder_input_length = max(
                    next_batch_max_decoder_input_length, new_decoder_input_length
                )

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generated_texts, None

        # If we finished at least one generation
        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
        if generated_texts:
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]

            if batch.decoder_attention_mask is not None:
                next_batch_decoder_attention_mask = batch.decoder_attention_mask[
                    next_batch_keep_indices
                ]
            else:
                next_batch_decoder_attention_mask = None

            next_batch_encoder_last_hidden_state = encoder_last_hidden_state[
                next_batch_keep_indices
            ]

            next_batch_past_key_values = [
                [t[next_batch_keep_indices] for t in layer] for layer in past
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_input_ids = batch.input_ids
            next_batch_attention_mask = batch.attention_mask
            next_batch_decoder_attention_mask = batch.decoder_attention_mask
            next_batch_encoder_last_hidden_state = encoder_last_hidden_state
            next_batch_past_key_values = past

            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        if next_batch_decoder_attention_mask is not None:
            next_batch_decoder_attention_mask = torch.cat(
                [
                    next_batch_decoder_attention_mask,
                    torch.ones((next_batch_size, 1)).to(self.device),
                ],
                dim=1,
            )

        next_batch = Seq2SeqLMBatch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            input_ids=next_batch_input_ids,
            attention_mask=next_batch_attention_mask,
            decoder_input_ids=next_batch_decoder_input_ids,
            decoder_attention_mask=next_batch_decoder_attention_mask,
            encoder_last_hidden_state=next_batch_encoder_last_hidden_state,
            past_key_values=next_batch_past_key_values,
            input_lengths=next_batch_input_lengths,
            decoder_input_lengths=next_batch_decoder_input_lengths,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_input_length=next_batch_max_input_length,
            max_decoder_input_length=next_batch_max_decoder_input_length,
        )
        return generated_texts, next_batch
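To show how the pieces above are meant to be driven, a hedged sketch of a decoding loop; it assumes `model` is a constructed `Seq2SeqLM` and `batch` came from `Seq2SeqLMBatch.from_pb(...)`, and the variable names are mine:

```python
# Each generate_token call advances every request in the batch by one token.
# Finished requests come back in `generated_texts`; the rest are re-batched.
finished = []
while batch is not None:
    generated_texts, batch = model.generate_token(batch)
    finished.extend(generated_texts)

for text in finished:
    print(text.output)
```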
server/text_generation/models/types.py  (+14, -219)

 import torch
 
+from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List, Dict
+from typing import List
 
 from transformers import AutoTokenizer
 
 from text_generation.pb import generate_pb2
-from text_generation.utils import NextTokenChooser, StoppingCriteria
 
 
-@dataclass
-class Batch:
-    batch_id: int
-    requests: List[generate_pb2.Request]
-    all_input_lengths: List[int]
-    input_ids: Dict[str, torch.Tensor]
-    all_input_ids: List[torch.Tensor]
-    next_token_choosers: List[NextTokenChooser]
-    stopping_criterias: List[StoppingCriteria]
-    size: int
-    max_sequence_length: int
-
-    def to_pb(self):
-        return generate_pb2.Batch(
-            id=self.batch_id,
-            requests=self.requests,
-            size=self.size,
-            max_sequence_length=self.max_sequence_length,
-        )
-
-    @classmethod
-    def from_pb(
-        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
-    ) -> "Batch":
-        ...
-
-    @classmethod
-    def concatenate(cls, batches: List["Batch"]) -> "Batch":
-        ...
+class Batch(ABC):
+    @abstractmethod
+    def to_pb(self) -> generate_pb2.Batch:
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def from_pb(
+        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+    ) -> "Batch":
+        raise NotImplementedError
+
+    @classmethod
+    @abstractmethod
+    def concatenate(cls, batches: List["Batch"]) -> "Batch":
+        raise NotImplementedError

(The elided from_pb and concatenate bodies removed here were the old concrete implementations. from_pb matches what is now CausalLMBatch.from_pb; the old concatenate additionally carried the Seq2SeqLM-specific bookkeeping, namely max_encoded_sequence_length and the `past[2:]` encoder-past copies, which now lives in Seq2SeqLMBatch.concatenate.)

@@ -241,4 +34,6 @@ class GeneratedText:
     tokens: int
 
     def to_pb(self) -> generate_pb2.GeneratedText:
-        return generate_pb2.GeneratedText(request=self.request, output=self.output, tokens=self.tokens)
+        return generate_pb2.GeneratedText(
+            request=self.request, output=self.output, tokens=self.tokens
+        )
server/text_generation/server.py  (+4, -3)

@@ -9,7 +9,6 @@ from typing import List
 from text_generation.cache import Cache
 from text_generation.models import Model, get_model
-from text_generation.models.types import Batch
 from text_generation.pb import generate_pb2_grpc, generate_pb2

@@ -27,7 +26,9 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ClearCacheResponse()
 
     async def Generate(self, request, context):
-        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
+        batch = self.model.batch_type.from_pb(
+            request.batch, self.model.tokenizer, self.model.device
+        )
 
         generated_texts, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)

@@ -51,7 +52,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batches.append(batch)
 
         if len(batches) > 1:
-            batch = Batch.concatenate(batches)
+            batch = self.model.batch_type.concatenate(batches)
         else:
             batch = batches[0]
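The net effect on the gRPC service is that it never names a concrete batch class; it always goes through `self.model.batch_type`. A small hedged sketch of that dispatch, where the helper name is mine:

```python
def build_batch(model, pb_batch):
    # model.batch_type is CausalLMBatch for decoder-only models and
    # Seq2SeqLMBatch for encoder-decoder models, so the service code
    # stays identical for both families.
    return model.batch_type.from_pb(pb_batch, model.tokenizer, model.device)
```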