OpenDAS / text-generation-inference · Commit 3cf6368c

feat(server): Support all AutoModelForCausalLM on a best effort basis

Authored Oct 28, 2022 by OlivierDehaene
Parent: 09674e6d

Changes: 25 · Showing 5 changed files with 389 additions and 22 deletions (+389 −22)
server/text_generation/models/model.py   +166  −0
server/text_generation/models/types.py   +206  −0
server/text_generation/pb/.gitignore       +0  −0
server/text_generation/server.py          +10  −13
server/text_generation/utils.py            +7  −9
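The commit message says the server can now serve any checkpoint loadable through AutoModelForCausalLM, on a best-effort basis; in the server.py diff below, the hard-coded BLOOM construction is replaced by a call to get_model(model_name, sharded, quantize). The factory itself is on the second page of this diff, so the following is only a hedged sketch of how such a dispatch might look; the import paths and the BLOOM class names are assumptions, not code from this page.

```python
# Hypothetical sketch of the get_model factory that server.py now calls.
# The real implementation is not shown on this page; the import paths and
# the BLOOM/BLOOMSharded classes referenced here are assumptions.
from text_generation.models.bloom import BLOOM, BLOOMSharded  # assumed location
from text_generation.models.model import Model


def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
    # Dedicated implementations first, generic AutoModelForCausalLM fallback last.
    if model_name.startswith("bigscience/bloom"):
        return BLOOMSharded(model_name, quantize) if sharded else BLOOM(model_name)
    # Best-effort path: anything AutoModelForCausalLM can load.
    return Model(model_name)
```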
server/text_generation/models/model.py (new file, mode 0 → 100644)

```python
import torch
import torch.distributed

from typing import List, Tuple, Optional

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from text_generation.models.types import Batch, GeneratedText


class Model:
    def __init__(self, model_name: str):
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            dtype = torch.float16
        else:
            self.device = torch.device("cpu")
            dtype = torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=dtype, device_map="auto"
        ).eval()
        self.num_heads = self.model.config.num_attention_heads

    def forward(
        self, input_ids, attention_mask, past_key_values: Optional = None
    ) -> CausalLMOutputWithPast:
        # Model Forward
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )

    def generate_token(
        self, batch: Batch
    ) -> Tuple[List[GeneratedText], Optional[Batch]]:
        # For some reason, inference_mode does not work well with GLOO which we use on CPU
        context_manager = (
            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
        )
        with context_manager():
            outputs = self.forward(**batch.input_ids)

        # List of indices to cache
        next_batch_keep_indices = []
        next_batch_past_keep_indices = []

        # New input_ids for next forward
        next_batch_input_ids = []
        next_batch_all_input_ids = []
        next_all_input_lengths = []

        next_batch_size = 0
        next_batch_max_sequence_length = 0

        # Finished requests
        generated_texts: List[GeneratedText] = []

        # Zipped iterator
        iterator = zip(
            batch.requests,
            batch.all_input_lengths,
            outputs.logits,
            batch.next_token_choosers,
            batch.stopping_criterias,
            batch.all_input_ids,
        )

        # For each member of the batch
        for i, (
            request,
            input_length,
            logits,
            next_token_chooser,
            stopping_criteria,
            all_tokens,
        ) in enumerate(iterator):
            # Select next token
            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])

            # Append next token to all tokens
            all_tokens = torch.cat([all_tokens, next_token])

            # Evaluate stopping criteria
            if stopping_criteria(all_tokens):
                # Decode all tokens
                output = self.tokenizer.decode(
                    all_tokens.squeeze(-1), skip_special_tokens=True
                )
                # Add to the list of finished generations with the original request
                generated_texts.append(GeneratedText(request, output))
            # add to the next batch
            else:
                next_batch_keep_indices.append(i)
                # past_key_values is of shape [batch_size * num_heads, ...]
                # so we need to take into account the `num_heads` stride here
                next_batch_past_keep_indices.extend(
                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
                )
                next_batch_input_ids.append(next_token)
                next_batch_all_input_ids.append(all_tokens)
                next_batch_size += 1
                new_input_length = input_length + 1
                next_all_input_lengths.append(new_input_length)
                next_batch_max_sequence_length = max(
                    next_batch_max_sequence_length, new_input_length
                )

        # We finished all generations in the batch; there is no next batch
        if not next_batch_keep_indices:
            return generated_texts, None

        # If we finished at least one generation
        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
        if generated_texts:
            # Apply indices to attention mask, past key values and other items that need to be cached
            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
                next_batch_keep_indices
            ]
            next_batch_input_ids["past_key_values"] = [
                (
                    keys[next_batch_past_keep_indices],
                    values[next_batch_past_keep_indices],
                )
                for keys, values in outputs["past_key_values"]
            ]
            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
            next_batch_next_token_choosers = [
                batch.next_token_choosers[i] for i in next_batch_keep_indices
            ]
            next_batch_stopping_criterias = [
                batch.stopping_criterias[i] for i in next_batch_keep_indices
            ]
        else:
            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
            next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
            next_batch_requests = batch.requests
            next_batch_next_token_choosers = batch.next_token_choosers
            next_batch_stopping_criterias = batch.stopping_criterias

        # Update attention_mask with padding as we added a new token to input_ids
        next_batch_input_ids["attention_mask"] = torch.cat(
            [
                next_batch_input_ids["attention_mask"],
                torch.ones((next_batch_size, 1)).to(self.device),
            ],
            dim=1,
        )

        next_batch = Batch(
            batch_id=batch.batch_id,
            requests=next_batch_requests,
            all_input_lengths=next_all_input_lengths,
            input_ids=next_batch_input_ids,
            all_input_ids=next_batch_all_input_ids,
            next_token_choosers=next_batch_next_token_choosers,
            stopping_criterias=next_batch_stopping_criterias,
            size=next_batch_size,
            max_sequence_length=next_batch_max_sequence_length,
        )
        return generated_texts, next_batch
```
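Model.generate_token returns the generations that finished at this step together with an optional Batch to feed into the next call, so a caller drives decoding by looping until no batch is returned. A minimal driver sketch (not part of the commit) follows:

```python
# Minimal driver sketch, assuming a Batch has already been built via
# Batch.from_pb: call generate_token repeatedly, collect finished
# generations, and re-feed the returned next_batch until it is None.
from typing import List, Optional

from text_generation.models.model import Model
from text_generation.models.types import Batch, GeneratedText


def generate_until_done(model: Model, batch: Optional[Batch]) -> List[GeneratedText]:
    finished: List[GeneratedText] = []
    while batch is not None:
        generated_texts, batch = model.generate_token(batch)
        finished.extend(generated_texts)
    return finished
```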
server/text_generation/models/types.py (new file, mode 0 → 100644)

```python
import torch

from dataclasses import dataclass
from typing import List, Dict

from transformers import AutoTokenizer

from text_generation.pb import generate_pb2
from text_generation.utils import NextTokenChooser, StoppingCriteria


@dataclass
class Batch:
    batch_id: int
    requests: List[generate_pb2.Request]
    all_input_lengths: List[int]
    input_ids: Dict[str, torch.Tensor]
    all_input_ids: List[torch.Tensor]
    next_token_choosers: List[NextTokenChooser]
    stopping_criterias: List[StoppingCriteria]
    size: int
    max_sequence_length: int

    def to_pb(self):
        return generate_pb2.Batch(
            id=self.batch_id,
            requests=self.requests,
            size=self.size,
            max_sequence_length=self.max_sequence_length,
        )

    @classmethod
    def from_pb(
        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
    ) -> "Batch":
        inputs = []
        next_token_choosers = []
        stopping_criterias = []
        all_input_lengths = []

        # Parse batch
        for r in pb.requests:
            inputs.append(r.inputs)
            all_input_lengths.append(r.input_length)
            next_token_choosers.append(
                NextTokenChooser(
                    temperature=r.parameters.temperature,
                    top_k=r.parameters.top_k,
                    top_p=r.parameters.top_p,
                    do_sample=r.parameters.do_sample,
                )
            )
            stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))

        input_ids = tokenizer(
            inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
        ).to(device)
        all_input_ids = input_ids["input_ids"].unsqueeze(-1)

        return cls(
            batch_id=pb.id,
            requests=pb.requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=pb.size,
            max_sequence_length=pb.max_sequence_length,
        )

    @classmethod
    def concatenate(cls, batches: List["Batch"]) -> "Batch":
        # Used for padding
        total_batch_size = sum(batch.size for batch in batches)
        max_sequence_length = max(batch.max_sequence_length for batch in batches)

        # Batch attributes
        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
        requests = []
        all_input_lengths = []
        all_input_ids = []
        next_token_choosers = []
        stopping_criterias = []

        # Used for slicing correctly inside the tensors
        # Equivalent to a cumsum on batch sizes
        start_index = 0
        for i, batch in enumerate(batches):
            requests.extend(batch.requests)
            all_input_lengths.extend(batch.all_input_lengths)
            all_input_ids.extend(batch.all_input_ids)
            next_token_choosers.extend(batch.next_token_choosers)
            stopping_criterias.extend(batch.stopping_criterias)

            # Slicing end index for this batch
            end_index = start_index + batch.size

            # We only concatenate batches that did at least one step
            if batch.input_ids["input_ids"].shape[1] > 1:
                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")

            # Initialize tensors
            if i == 0:
                input_ids["input_ids"] = torch.empty(
                    (total_batch_size, 1),
                    dtype=batch.input_ids["input_ids"].dtype,
                    device=batch.input_ids["input_ids"].device,
                )
                input_ids["attention_mask"] = torch.zeros(
                    (total_batch_size, max_sequence_length),
                    dtype=batch.input_ids["attention_mask"].dtype,
                    device=batch.input_ids["attention_mask"].device,
                )

            # input_ids["input_ids"] is always of shape [batch_size, 1]
            # We do not need to pad it
            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]

            # We need to slice the attention mask to remove padding from previous steps
            input_ids["attention_mask"][
                start_index:end_index, -batch.max_sequence_length :
            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]

            for j, past in enumerate(batch.input_ids["past_key_values"]):
                past_keys = past[0]
                past_values = past[1]

                _, head_dim, padded_sequence_length = past_keys.shape

                # Reshape the tensors to make slicing easier
                past_keys = past_keys.view(
                    batch.size, -1, head_dim, padded_sequence_length
                )
                past_values = past_values.view(
                    batch.size, -1, padded_sequence_length, head_dim
                )
                num_heads = past_keys.shape[1]

                # Initialize tensors
                # This will run only once per layer
                if j == len(input_ids["past_key_values"]):
                    padded_past_keys = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            head_dim,
                            max_sequence_length - 1,
                        ),
                        dtype=past_keys.dtype,
                        device=past_keys.device,
                    )
                    padded_past_values = torch.zeros(
                        (
                            total_batch_size,
                            num_heads,
                            max_sequence_length - 1,
                            head_dim,
                        ),
                        dtype=past_values.dtype,
                        device=past_values.device,
                    )
                    input_ids["past_key_values"].append(
                        [padded_past_keys, padded_past_values]
                    )

                # We slice the past keys and values to remove the padding from previous batches
                input_ids["past_key_values"][j][0][
                    start_index:end_index, :, :, -(batch.max_sequence_length - 1) :
                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1) :]

                input_ids["past_key_values"][j][1][
                    start_index:end_index, :, -(batch.max_sequence_length - 1) :, :
                ] = past_values[:, :, -(batch.max_sequence_length - 1) :, :]

                # If we are on the last batch, we need to reshape the tensors
                if (i + 1) == len(batches):
                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
                        j
                    ][0].view(total_batch_size * num_heads, head_dim, -1)
                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
                        j
                    ][1].view(total_batch_size * num_heads, -1, head_dim)

            start_index += batch.size

        return cls(
            batch_id=batches[0].batch_id,
            requests=requests,
            all_input_lengths=all_input_lengths,
            input_ids=input_ids,
            all_input_ids=all_input_ids,
            next_token_choosers=next_token_choosers,
            stopping_criterias=stopping_criterias,
            size=total_batch_size,
            max_sequence_length=max_sequence_length,
        )


@dataclass
class GeneratedText:
    request: generate_pb2.Request
    output: str

    def to_pb(self) -> generate_pb2.GeneratedText:
        return generate_pb2.GeneratedText(request=self.request, output=self.output)
```
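The trickiest part of Batch.concatenate is merging past_key_values: the model stores them as [batch_size * num_heads, ...], so each incoming batch is viewed back to a per-sequence layout, its real (unpadded) positions are copied right-aligned into a buffer sized for the merged batch, and the result is flattened again. The standalone snippet below, with made-up dimensions, illustrates that bookkeeping for the key tensors; it is an illustration, not code from the commit.

```python
# Illustration only (dummy shapes): the reshape-copy-flatten dance that
# Batch.concatenate performs on past keys.
import torch

# One incoming batch: 2 sequences, 4 heads, head_dim 8. Its past keys were
# padded along the sequence axis, but only the last (max_sequence_length - 1)
# positions hold real values.
batch_size, num_heads, head_dim = 2, 4, 8
padded_seq_len, batch_max_seq_len = 8, 5

# The merged batch is bigger and its longest sequence is longer.
total_batch_size, merged_max_seq_len = 3, 7

past_keys = torch.randn(batch_size * num_heads, head_dim, padded_seq_len)

# View back to [batch_size, num_heads, head_dim, padded_seq_len] for easy slicing
per_sequence = past_keys.view(batch_size, num_heads, head_dim, padded_seq_len)

# Copy the real positions, right-aligned, into the merged padded buffer
padded = torch.zeros(total_batch_size, num_heads, head_dim, merged_max_seq_len - 1)
padded[:batch_size, :, :, -(batch_max_seq_len - 1):] = per_sequence[
    :, :, :, -(batch_max_seq_len - 1):
]

# Flatten back to the [batch * heads, head_dim, seq] layout the model expects
merged = padded.view(total_batch_size * num_heads, head_dim, -1)
print(merged.shape)  # torch.Size([12, 8, 6])
```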
server/bloom_inference/pb/.gitignore → server/text_generation/pb/.gitignore (file moved)
server/bloom_inference/server.py → server/text_generation/server.py

```diff
@@ -5,15 +5,16 @@ from grpc import aio
 from grpc_reflection.v1alpha import reflection
 from pathlib import Path
-from typing import Optional, List
+from typing import List

-from bloom_inference.cache import Cache
-from bloom_inference.model import BLOOM, Batch, BLOOMSharded
-from bloom_inference.pb import generate_pb2_grpc, generate_pb2
+from text_generation.cache import Cache
+from text_generation.models import Model, get_model
+from text_generation.models.types import Batch
+from text_generation.pb import generate_pb2_grpc, generate_pb2


 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
-    def __init__(self, model: BLOOM, cache: Cache, server_urls: List[str]):
+    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
         self.cache = cache
         self.model = model
         self.server_urls = server_urls
@@ -78,21 +79,17 @@ def serve(
     ):
         unix_socket_template = "unix://{}-{}"
         if sharded:
-            model = BLOOMSharded(model_name, quantize)
             server_urls = [
                 unix_socket_template.format(uds_path, rank)
-                for rank in range(model.world_size)
+                for rank in range(int(os.environ["WORLD_SIZE"]))
             ]
-            local_url = server_urls[model.rank]
+            local_url = server_urls[int(os.environ["RANK"])]
         else:
-            if quantize:
-                raise ValueError(
-                    "bitsandbytes quantization is only available when running in `sharded` mode."
-                )
-            model = BLOOM(model_name)
             local_url = unix_socket_template.format(uds_path, 0)
             server_urls = [local_url]

+        model = get_model(model_name, sharded, quantize)
+
         server = aio.server()
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
             TextGenerationService(model, Cache(), server_urls), server
```
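In the sharded path the server no longer asks the model object for world_size and rank; it reads the torch.distributed-style WORLD_SIZE and RANK environment variables instead, so the socket URLs can be computed before any model is built. A small illustration with example values:

```python
# Example values only; uds_path is whatever the launcher passes in.
import os

os.environ.setdefault("WORLD_SIZE", "2")
os.environ.setdefault("RANK", "0")

uds_path = "/tmp/text-generation"
unix_socket_template = "unix://{}-{}"

server_urls = [
    unix_socket_template.format(uds_path, rank)
    for rank in range(int(os.environ["WORLD_SIZE"]))
]
local_url = server_urls[int(os.environ["RANK"])]
print(server_urls)  # ['unix:///tmp/text-generation-0', 'unix:///tmp/text-generation-1']
print(local_url)    # unix:///tmp/text-generation-0
```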
server/bloom_inference/utils.py → server/text_generation/utils.py

```diff
@@ -92,19 +92,17 @@ def initialize_torch_distributed():
     return torch.distributed.distributed_c10d._get_default_group(), rank, world_size


-def weight_hub_files(model_name):
+def weight_hub_files(model_name, extension=".safetensors"):
     """Get the safetensors filenames on the hub"""
     api = HfApi()
     info = api.model_info(model_name)
-    filenames = [
-        s.rfilename for s in info.siblings if s.rfilename.endswith(".safetensors")
-    ]
+    filenames = [s.rfilename for s in info.siblings if s.rfilename.endswith(extension)]
     return filenames


-def weight_files(model_name):
+def weight_files(model_name, extension=".safetensors"):
     """Get the local safetensors filenames"""
-    filenames = weight_hub_files(model_name)
+    filenames = weight_hub_files(model_name, extension)
     files = []
     for filename in filenames:
         cache_file = try_to_load_from_cache(model_name, filename=filename)
@@ -112,16 +110,16 @@ def weight_files(model_name):
             raise LocalEntryNotFoundError(
                 f"File {filename} of model {model_name} not found in "
                 f"{os.getenv('HUGGINGFACE_HUB_CACHE', 'the local cache')}. "
-                f"Please run `bloom-inference-server download-weights {model_name}` first."
+                f"Please run `text-generation-server download-weights {model_name}` first."
             )
         files.append(cache_file)

     return files


-def download_weights(model_name):
+def download_weights(model_name, extension=".safetensors"):
     """Download the safetensors files from the hub"""
-    filenames = weight_hub_files(model_name)
+    filenames = weight_hub_files(model_name, extension)
     download_function = partial(
         hf_hub_download, repo_id=model_name, local_files_only=False
```
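The utils change threads an `extension` parameter (defaulting to ".safetensors") through weight_hub_files, weight_files, and download_weights, presumably so the best-effort path can fetch models that only publish other weight formats. A usage sketch with example model names:

```python
# Usage sketch; the model names and the ".bin" extension are examples.
from text_generation.utils import weight_hub_files, download_weights

# Default behaviour is unchanged: list .safetensors files on the hub.
safetensors_files = weight_hub_files("bigscience/bloom-560m")

# New: ask for another weight format explicitly.
pytorch_files = weight_hub_files("gpt2", extension=".bin")
download_weights("gpt2", extension=".bin")
```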