OpenDAS / text-generation-inference

Commit 755fc0e4, authored Nov 03, 2022 by OlivierDehaene
Parent: b3b7ea0d

fix(models): Revert buggy support for AutoModel

Showing 7 changed files with 339 additions and 315 deletions:

- README.md (+2, -3)
- server/Makefile (+5, -5)
- server/text_generation/models/__init__.py (+4, -10)
- server/text_generation/models/bloom.py (+311, -12)
- server/text_generation/models/model.py (+12, -159)
- server/text_generation/models/types.py (+3, -124)
- server/text_generation/server.py (+2, -2)
README.md (view file @ 755fc0e4)

@@ -15,13 +15,11 @@ A Rust and gRPC server for text generation inference.

 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB

-## Officially supported models
+## Supported models

 - BLOOM
 - BLOOM-560m

-Other models are supported on a best-effort basis using `AutoModelForCausalLM.from_pretrained(<model>, torch_dtype=torch.float16, device_map="auto")`.

 ## Load Tests for BLOOM

 See `k6/load_test.js`

@@ -82,6 +80,7 @@ make router-dev

 ## TODO:

+- [ ] Support AutoModelForSeq2SeqLM
 - [ ] Add tests for the `server/model` logic
 - [ ] Backport custom CUDA kernels to Transformers
 - [ ] Install safetensors with pip
\ No newline at end of file
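For context, the best-effort path that this commit drops from the README was just the stock `transformers` auto-loading call quoted in the removed line. Below is a minimal sketch of that kind of generic fallback, assuming a CUDA device is available; the model name and prompt are placeholders and none of this code belongs to the repository:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint: any causal LM on the Hugging Face Hub loads the same way.
model_name = "bigscience/bloom-560m"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# This is the call the removed README sentence referred to.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```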
server/Makefile (view file @ 755fc0e4)

@@ -9,11 +9,11 @@ gen-server:

 install-transformers:
 	# Install specific version of transformers
 	rm transformers || true
-	rm transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 || true
+	rm transformers-7302a24535e8dc5637ea5b4e4572fc971d404098 || true
-	curl -L -O https://github.com/huggingface/transformers/archive/46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
+	curl -L -O https://github.com/OlivierDehaene/transformers/archive/7302a24535e8dc5637ea5b4e4572fc971d404098.zip
-	unzip 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
+	unzip 7302a24535e8dc5637ea5b4e4572fc971d404098.zip
-	rm 46d37bece7d3ffdef97b1ee4a3170c0a0627d921.zip
+	rm 7302a24535e8dc5637ea5b4e4572fc971d404098.zip
-	mv transformers-46d37bece7d3ffdef97b1ee4a3170c0a0627d921 transformers
+	mv transformers-7302a24535e8dc5637ea5b4e4572fc971d404098 transformers
 	cd transformers && python setup.py install

 install-safetensors:
 ...
server/text_generation/models/__init__.py (view file @ 755fc0e4)

 from text_generation.models.model import Model
-from text_generation.models.bloom import BLOOMSharded
+from text_generation.models.bloom import BLOOM, BLOOMSharded

-__all__ = ["Model", "BLOOMSharded"]
+__all__ = ["Model", "BLOOM", "BLOOMSharded"]


 def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
     if model_name.startswith("bigscience/bloom"):
         if sharded:
             return BLOOMSharded(model_name, quantize)
         else:
             if quantize:
                 raise ValueError("quantization is not supported for non-sharded BLOOM")
-            return Model(model_name)
+            return BLOOM(model_name)
     else:
-        if sharded:
-            raise ValueError("sharded is only supported for BLOOM models")
-        if quantize:
-            raise ValueError("Quantization is only supported for BLOOM models")
-        return Model(model_name)
+        raise ValueError(f"model {model_name} is not supported yet")
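After this change `get_model` only dispatches to the two BLOOM implementations and fails fast for everything else, instead of falling back to a generic `AutoModelForCausalLM`-backed `Model`. A small usage sketch of the new dispatch (the model names are only illustrative inputs):

```python
from text_generation.models import get_model

# BLOOM checkpoints dispatch to BLOOM (single device) or BLOOMSharded (tensor parallel).
model = get_model("bigscience/bloom-560m", sharded=False, quantize=False)

# Any other model name now raises instead of silently loading an unsupported model.
try:
    get_model("gpt2", sharded=False, quantize=False)
except ValueError as err:
    print(err)  # model gpt2 is not supported yet
```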
server/text_generation/models/bloom.py (view file @ 755fc0e4)

 import torch
 import torch.distributed

-from typing import List, Optional
+from typing import List, Optional, Tuple, Type

 from accelerate import init_empty_weights
 from safetensors import safe_open
 ...

@@ -11,8 +11,10 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
 )
+from transformers.modeling_outputs import CausalLMOutputWithPast

 from text_generation.models import Model
+from text_generation.models.types import Batch, GeneratedText
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
 ...

@@ -29,9 +31,306 @@ except Exception as e:

 torch.manual_seed(0)

-class BLOOMSharded(Model):
+class BloomBatch(Batch):
+    @classmethod
+    def concatenate(cls, batches: List["Batch"]) -> "BloomBatch":
+        # Used for padding
+        total_batch_size = sum(batch.size for batch in batches)
+        max_sequence_length = max(batch.max_sequence_length for batch in batches)
+
+        # Batch attributes
+        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
+        requests = []
+        all_input_lengths = []
+        all_input_ids = []
+        next_token_choosers = []
+        stopping_criterias = []
+
+        # Used for slicing correctly inside the tensors
+        # Equivalent to a cumsum on batch sizes
+        start_index = 0
+        for i, batch in enumerate(batches):
+            requests.extend(batch.requests)
+            all_input_lengths.extend(batch.all_input_lengths)
+            all_input_ids.extend(batch.all_input_ids)
+            next_token_choosers.extend(batch.next_token_choosers)
+            stopping_criterias.extend(batch.stopping_criterias)
+
+            # Slicing end index for this batch
+            end_index = start_index + batch.size
+
+            # We only concatenate batches that did at least one step
+            if batch.input_ids["input_ids"].shape[1] > 1:
+                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
+
+            # Initialize tensors
+            if i == 0:
+                input_ids["input_ids"] = torch.empty(
+                    (total_batch_size, 1),
+                    dtype=batch.input_ids["input_ids"].dtype,
+                    device=batch.input_ids["input_ids"].device,
+                )
+                input_ids["attention_mask"] = torch.zeros(
+                    (total_batch_size, max_sequence_length),
+                    dtype=batch.input_ids["attention_mask"].dtype,
+                    device=batch.input_ids["attention_mask"].device,
+                )
+
+            # input_ids["input_ids"] is always of shape [batch_size, 1]
+            # We do not need to pad it
+            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
+
+            # We need to slice the attention mask to remove padding from previous steps
+            input_ids["attention_mask"][
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
+
+            for j, past in enumerate(batch.input_ids["past_key_values"]):
+                past_keys = past[0]
+                past_values = past[1]
+
+                _, head_dim, padded_sequence_length = past_keys.shape
+
+                # Reshape the tensors to make slicing easier
+                past_keys = past_keys.view(
+                    batch.size, -1, head_dim, padded_sequence_length
+                )
+                past_values = past_values.view(
+                    batch.size, -1, padded_sequence_length, head_dim
+                )
+                num_heads = past_keys.shape[1]
+
+                # Initialize tensors
+                # This will run only once per layer
+                if j == len(input_ids["past_key_values"]):
+                    padded_past_keys = torch.zeros(
+                        (
+                            total_batch_size,
+                            num_heads,
+                            head_dim,
+                            max_sequence_length - 1,
+                        ),
+                        dtype=past_keys.dtype,
+                        device=past_keys.device,
+                    )
+                    padded_past_values = torch.zeros(
+                        (
+                            total_batch_size,
+                            num_heads,
+                            max_sequence_length - 1,
+                            head_dim,
+                        ),
+                        dtype=past_values.dtype,
+                        device=past_values.device,
+                    )
+                    input_ids["past_key_values"].append(
+                        [padded_past_keys, padded_past_values]
+                    )
+
+                # We slice the past keys and values to remove the padding from previous batches
+                input_ids["past_key_values"][j][0][
+                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
+                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
+                input_ids["past_key_values"][j][1][
+                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
+                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
+
+                # If we are on the last batch, we need to reshape the tensors
+                if (i + 1) == len(batches):
+                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
+                        j
+                    ][0].view(total_batch_size * num_heads, head_dim, -1)
+                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
+                        j
+                    ][1].view(total_batch_size * num_heads, -1, head_dim)
+
+            start_index += batch.size
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            all_input_lengths=all_input_lengths,
+            input_ids=input_ids,
+            all_input_ids=all_input_ids,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            size=total_batch_size,
+            max_sequence_length=max_sequence_length,
+        )
+
+
+class BLOOM(Model):
+    def __init__(self, model_name: str):
+        if not model_name.startswith("bigscience/bloom"):
+            raise ValueError(f"Model {model_name} is not supported")
+
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        else:
+            self.device = torch.device("cpu")
+            dtype = torch.float32
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            torch_dtype=dtype,
+            device_map="auto" if torch.cuda.is_available() else None,
+        ).eval()
+        self.num_heads = self.model.config.num_attention_heads
+
+    @property
+    def batch_type(self) -> Type[BloomBatch]:
+        return BloomBatch
+
+    def forward(
+        self, input_ids, attention_mask, past_key_values: Optional = None
+    ) -> CausalLMOutputWithPast:
+        # Model Forward
+        return self.model.forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+
+    def generate_token(
+        self, batch: BloomBatch
+    ) -> Tuple[List[GeneratedText], Optional[BloomBatch]]:
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        context_manager = (
+            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
+        )
+        with context_manager():
+            outputs = self.forward(**batch.input_ids)
+
+        # List of indices to cache
+        next_batch_keep_indices = []
+        next_batch_past_keep_indices = []
+
+        # New input_ids for next forward
+        next_batch_input_ids = []
+        next_batch_all_input_ids = []
+        next_all_input_lengths = []
+
+        next_batch_size = 0
+        next_batch_max_sequence_length = 0
+
+        # Finished requests
+        generated_texts: List[GeneratedText] = []
+
+        # Zipped iterator
+        iterator = zip(
+            batch.requests,
+            batch.all_input_lengths,
+            outputs.logits,
+            batch.next_token_choosers,
+            batch.stopping_criterias,
+            batch.all_input_ids,
+        )
+
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_tokens,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
+
+            # Append next token to all tokens
+            all_tokens = torch.cat([all_tokens, next_token])
+
+            # Evaluate stopping criteria
+            if stopping_criteria(all_tokens):
+                # Decode all tokens
+                output = self.tokenizer.decode(
+                    all_tokens.squeeze(-1), skip_special_tokens=True
+                )
+                # Add to the list of finished generations with the original request
+                generated_texts.append(GeneratedText(request, output))
+            # add to the next batch
+            else:
+                next_batch_keep_indices.append(i)
+                # past_key_values is of shape [batch_size * num_heads, ...]
+                # so we need to take into account the `num_heads` stride here
+                next_batch_past_keep_indices.extend(
+                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
+                )
+                next_batch_input_ids.append(next_token)
+                next_batch_all_input_ids.append(all_tokens)
+                next_batch_size += 1
+                new_input_length = input_length + 1
+                next_all_input_lengths.append(new_input_length)
+                next_batch_max_sequence_length = max(
+                    next_batch_max_sequence_length, new_input_length
+                )
+
+        # We finished all generations in the batch; there is no next batch
+        if not next_batch_keep_indices:
+            return generated_texts, None
+
+        # If we finished at least one generation
+        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
+        if generated_texts:
+            # Apply indices to attention mask, past key values and other items that need to be cached
+            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
+                next_batch_keep_indices
+            ]
+            next_batch_input_ids["past_key_values"] = [
+                (
+                    keys[next_batch_past_keep_indices],
+                    values[next_batch_past_keep_indices],
+                )
+                for keys, values in outputs["past_key_values"]
+            ]
+            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
+            next_batch_next_token_choosers = [
+                batch.next_token_choosers[i] for i in next_batch_keep_indices
+            ]
+            next_batch_stopping_criterias = [
+                batch.stopping_criterias[i] for i in next_batch_keep_indices
+            ]
+        else:
+            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
+            next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
+            next_batch_requests = batch.requests
+            next_batch_next_token_choosers = batch.next_token_choosers
+            next_batch_stopping_criterias = batch.stopping_criterias
+
+        # Update attention_mask with padding as we added a new token to input_ids
+        next_batch_input_ids["attention_mask"] = torch.cat(
+            [
+                next_batch_input_ids["attention_mask"],
+                torch.ones((next_batch_size, 1)).to(self.device),
+            ],
+            dim=1,
+        )
+
+        next_batch = BloomBatch(
+            batch_id=batch.batch_id,
+            requests=next_batch_requests,
+            all_input_lengths=next_all_input_lengths,
+            input_ids=next_batch_input_ids,
+            all_input_ids=next_batch_all_input_ids,
+            next_token_choosers=next_batch_next_token_choosers,
+            stopping_criterias=next_batch_stopping_criterias,
+            size=next_batch_size,
+            max_sequence_length=next_batch_max_sequence_length,
+        )
+        return generated_texts, next_batch
+
+
+class BLOOMSharded(BLOOM):
     def __init__(self, model_name: str, quantize: bool = False):
         super(Model, self).__init__()
-        if not model_name.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_name} is not supported")
         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
 ...

@@ -80,17 +379,17 @@ class BLOOMSharded(Model):

     @staticmethod
     def load_weights(
         model,
         filenames: List[str],
         quantize: bool,
         device: torch.device,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
                 file, framework="pt", device=str(device) if not quantize else "cpu"
             ) as f:
                 for name in f.keys():
                     full_name = f"transformer.{name}"
 ...

@@ -153,9 +452,9 @@ class BLOOMSharded(Model):
                         )

                         if (
                             type(module)
                             in [TensorParallelRowLinear, TensorParallelColumnLinear]
                             and param_name == "weight"
                         ):
                             tensor = Int8Params(
                                 tensor.transpose(1, 0),
 ...
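Everything above lands the BLOOM-specific batching and decoding logic in one place. The contract it implements is the one the router and gRPC server rely on: each `generate_token` call advances every request in a batch by one token and returns whichever requests finished. Below is a minimal sketch of a loop driving that contract; the helper name is made up, and `model` and `batch` are assumed to come from `get_model` and `BloomBatch.from_pb` respectively:

```python
from typing import List

from text_generation.models import Model
from text_generation.models.types import GeneratedText


def decode_until_finished(model: Model, batch) -> List[GeneratedText]:
    # generate_token returns (finished generations, next batch); the next batch
    # is None once every request has met its stopping criteria.
    finished: List[GeneratedText] = []
    while batch is not None:
        generated_texts, batch = model.generate_token(batch)
        finished.extend(generated_texts)
    return finished
```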
server/text_generation/models/model.py (view file @ 755fc0e4)

-import torch
-import torch.distributed
+from abc import ABC, abstractmethod

-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, TypeVar, Type

-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
-from transformers.modeling_outputs import CausalLMOutputWithPast

 from text_generation.models.types import Batch, GeneratedText

+B = TypeVar("B", bound=Batch)


-class Model:
-    def __init__(self, model_name: str):
-        if torch.cuda.is_available():
-            self.device = torch.device("cuda")
-            dtype = torch.float16
-        else:
-            self.device = torch.device("cpu")
-            dtype = torch.float32
-
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
-        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() else None,
-        ).eval()
-        self.num_heads = self.model.config.num_attention_heads
-
-    def forward(
-        self, input_ids, attention_mask, past_key_values: Optional = None
-    ) -> CausalLMOutputWithPast:
-        # Model Forward
-        return self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-
-    def generate_token(
-        self, batch: Batch
-    ) -> Tuple[List[GeneratedText], Optional[Batch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
-        )
-        with context_manager():
-            outputs = self.forward(**batch.input_ids)
-
-        # List of indices to cache
-        next_batch_keep_indices = []
-        next_batch_past_keep_indices = []
-
-        # New input_ids for next forward
-        next_batch_input_ids = []
-        next_batch_all_input_ids = []
-        next_all_input_lengths = []
-
-        next_batch_size = 0
-        next_batch_max_sequence_length = 0
-
-        # Finished requests
-        generated_texts: List[GeneratedText] = []
-
-        # Zipped iterator
-        iterator = zip(
-            batch.requests,
-            batch.all_input_lengths,
-            outputs.logits,
-            batch.next_token_choosers,
-            batch.stopping_criterias,
-            batch.all_input_ids,
-        )
-
-        # For each member of the batch
-        for i, (
-            request,
-            input_length,
-            logits,
-            next_token_chooser,
-            stopping_criteria,
-            all_tokens,
-        ) in enumerate(iterator):
-            # Select next token
-            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
-
-            # Append next token to all tokens
-            all_tokens = torch.cat([all_tokens, next_token])
-
-            # Evaluate stopping criteria
-            if stopping_criteria(all_tokens):
-                # Decode all tokens
-                output = self.tokenizer.decode(
-                    all_tokens.squeeze(-1), skip_special_tokens=True
-                )
-                # Add to the list of finished generations with the original request
-                generated_texts.append(GeneratedText(request, output))
-            # add to the next batch
-            else:
-                next_batch_keep_indices.append(i)
-                # past_key_values is of shape [batch_size * num_heads, ...]
-                # so we need to take into account the `num_heads` stride here
-                next_batch_past_keep_indices.extend(
-                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
-                )
-                next_batch_input_ids.append(next_token)
-                next_batch_all_input_ids.append(all_tokens)
-                next_batch_size += 1
-                new_input_length = input_length + 1
-                next_all_input_lengths.append(new_input_length)
-                next_batch_max_sequence_length = max(
-                    next_batch_max_sequence_length, new_input_length
-                )
-
-        # We finished all generations in the batch; there is no next batch
-        if not next_batch_keep_indices:
-            return generated_texts, None
-
-        # If we finished at least one generation
-        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
-        if generated_texts:
-            # Apply indices to attention mask, past key values and other items that need to be cached
-            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
-                next_batch_keep_indices
-            ]
-            next_batch_input_ids["past_key_values"] = [
-                (
-                    keys[next_batch_past_keep_indices],
-                    values[next_batch_past_keep_indices],
-                )
-                for keys, values in outputs["past_key_values"]
-            ]
-            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
-            next_batch_next_token_choosers = [
-                batch.next_token_choosers[i] for i in next_batch_keep_indices
-            ]
-            next_batch_stopping_criterias = [
-                batch.stopping_criterias[i] for i in next_batch_keep_indices
-            ]
-        else:
-            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
-            next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
-            next_batch_requests = batch.requests
-            next_batch_next_token_choosers = batch.next_token_choosers
-            next_batch_stopping_criterias = batch.stopping_criterias
-
-        # Update attention_mask with padding as we added a new token to input_ids
-        next_batch_input_ids["attention_mask"] = torch.cat(
-            [
-                next_batch_input_ids["attention_mask"],
-                torch.ones((next_batch_size, 1)).to(self.device),
-            ],
-            dim=1,
-        )
-
-        next_batch = Batch(
-            batch_id=batch.batch_id,
-            requests=next_batch_requests,
-            all_input_lengths=next_all_input_lengths,
-            input_ids=next_batch_input_ids,
-            all_input_ids=next_batch_all_input_ids,
-            next_token_choosers=next_batch_next_token_choosers,
-            stopping_criterias=next_batch_stopping_criterias,
-            size=next_batch_size,
-            max_sequence_length=next_batch_max_sequence_length,
-        )
-        return generated_texts, next_batch
+class Model(ABC):
+    @property
+    @abstractmethod
+    def batch_type(self) -> Type[B]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def generate_token(
+        self, batch: B
+    ) -> Tuple[List[GeneratedText], Optional[B]]:
+        raise NotImplementedError
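`Model` is now a pure interface: a backend supplies a `Batch` subclass plus the `batch_type` property and `generate_token`. A skeletal sketch of what a hypothetical additional backend would have to provide; all names below are invented for illustration and the bodies are deliberately left as stubs:

```python
from typing import List, Optional, Tuple, Type

from text_generation.models import Model
from text_generation.models.types import Batch, GeneratedText


class MyBatch(Batch):
    @classmethod
    def concatenate(cls, batches: List["Batch"]) -> "MyBatch":
        # Model-specific padding and past-key-value merging would go here.
        raise NotImplementedError


class MyModel(Model):
    @property
    def batch_type(self) -> Type[MyBatch]:
        return MyBatch

    def generate_token(
        self, batch: MyBatch
    ) -> Tuple[List[GeneratedText], Optional[MyBatch]]:
        # One forward pass: return finished generations and the batch to keep decoding.
        raise NotImplementedError
```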
server/text_generation/models/types.py (view file @ 755fc0e4)

 import torch

+from abc import abstractmethod
 from dataclasses import dataclass
 from typing import List, Dict
 ...

@@ -70,131 +71,9 @@ class Batch:
         )

     @classmethod
+    @abstractmethod
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
-        # Used for padding
-        total_batch_size = sum(batch.size for batch in batches)
-        max_sequence_length = max(batch.max_sequence_length for batch in batches)
-
-        # Batch attributes
-        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
-        requests = []
-        all_input_lengths = []
-        all_input_ids = []
-        next_token_choosers = []
-        stopping_criterias = []
-
-        # Used for slicing correctly inside the tensors
-        # Equivalent to a cumsum on batch sizes
-        start_index = 0
-        for i, batch in enumerate(batches):
-            requests.extend(batch.requests)
-            all_input_lengths.extend(batch.all_input_lengths)
-            all_input_ids.extend(batch.all_input_ids)
-            next_token_choosers.extend(batch.next_token_choosers)
-            stopping_criterias.extend(batch.stopping_criterias)
-
-            # Slicing end index for this batch
-            end_index = start_index + batch.size
-
-            # We only concatenate batches that did at least one step
-            if batch.input_ids["input_ids"].shape[1] > 1:
-                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
-
-            # Initialize tensors
-            if i == 0:
-                input_ids["input_ids"] = torch.empty(
-                    (total_batch_size, 1),
-                    dtype=batch.input_ids["input_ids"].dtype,
-                    device=batch.input_ids["input_ids"].device,
-                )
-                input_ids["attention_mask"] = torch.zeros(
-                    (total_batch_size, max_sequence_length),
-                    dtype=batch.input_ids["attention_mask"].dtype,
-                    device=batch.input_ids["attention_mask"].device,
-                )
-
-            # input_ids["input_ids"] is always of shape [batch_size, 1]
-            # We do not need to pad it
-            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
-
-            # We need to slice the attention mask to remove padding from previous steps
-            input_ids["attention_mask"][
-                start_index:end_index, -batch.max_sequence_length :
-            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
-
-            for j, past in enumerate(batch.input_ids["past_key_values"]):
-                past_keys = past[0]
-                past_values = past[1]
-
-                _, head_dim, padded_sequence_length = past_keys.shape
-
-                # Reshape the tensors to make slicing easier
-                past_keys = past_keys.view(
-                    batch.size, -1, head_dim, padded_sequence_length
-                )
-                past_values = past_values.view(
-                    batch.size, -1, padded_sequence_length, head_dim
-                )
-                num_heads = past_keys.shape[1]
-
-                # Initialize tensors
-                # This will run only once per layer
-                if j == len(input_ids["past_key_values"]):
-                    padded_past_keys = torch.zeros(
-                        (
-                            total_batch_size,
-                            num_heads,
-                            head_dim,
-                            max_sequence_length - 1,
-                        ),
-                        dtype=past_keys.dtype,
-                        device=past_keys.device,
-                    )
-                    padded_past_values = torch.zeros(
-                        (
-                            total_batch_size,
-                            num_heads,
-                            max_sequence_length - 1,
-                            head_dim,
-                        ),
-                        dtype=past_values.dtype,
-                        device=past_values.device,
-                    )
-                    input_ids["past_key_values"].append(
-                        [padded_past_keys, padded_past_values]
-                    )
-
-                # We slice the past keys and values to remove the padding from previous batches
-                input_ids["past_key_values"][j][0][
-                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
-                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]
-                input_ids["past_key_values"][j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]
-
-                # If we are on the last batch, we need to reshape the tensors
-                if (i + 1) == len(batches):
-                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
-                        j
-                    ][0].view(total_batch_size * num_heads, head_dim, -1)
-                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
-                        j
-                    ][1].view(total_batch_size * num_heads, -1, head_dim)
-
-            start_index += batch.size
-
-        return cls(
-            batch_id=batches[0].batch_id,
-            requests=requests,
-            all_input_lengths=all_input_lengths,
-            input_ids=input_ids,
-            all_input_ids=all_input_ids,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-            size=total_batch_size,
-            max_sequence_length=max_sequence_length,
-        )
+        raise NotImplementedError

 @dataclass
 ...
server/text_generation/server.py (view file @ 755fc0e4)

@@ -27,7 +27,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

         return generate_pb2.ClearCacheResponse()

     async def Generate(self, request, context):
-        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
+        batch = self.model.batch_type.from_pb(request.batch, self.model.tokenizer, self.model.device)
         generated_texts, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
 ...

@@ -51,7 +51,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):

         batches.append(batch)

         if len(batches) > 1:
-            batch = Batch.concatenate(batches)
+            batch = self.model.batch_type.concatenate(batches)
         else:
             batch = batches[0]
 ...
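These two lines are what make the new abstraction pay off on the serving side: the gRPC service no longer names `Batch` directly, it asks the model for its batch class. A condensed sketch of that dispatch, assuming the request, cache, and model objects used by this server (the standalone function and its name are illustrative only, not part of the commit):

```python
from text_generation.models import Model


async def handle_generate(model: Model, cache, request):
    # The model decides which Batch subclass to build (e.g. BloomBatch for BLOOM).
    batch = model.batch_type.from_pb(request.batch, model.tokenizer, model.device)
    # One generation step; anything still in flight is cached for the next call.
    generated_texts, next_batch = model.generate_token(batch)
    cache.set(next_batch)
    return generated_texts
```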