text-generation-inference · Commit c5665f5c
Authored Nov 04, 2022 by OlivierDehaene

feat(server): Support generic AutoModelForCausalLM

Parent: 755fc0e4
Showing 11 changed files with 372 additions and 332 deletions.
README.md                                    +1    -0
proto/generate.proto                         +2    -0
router/src/batcher.rs                        +2    -0
router/src/server.rs                         +1    -1
server/text_generation/models/__init__.py    +9    -4
server/text_generation/models/bloom.py       +9    -306
server/text_generation/models/causal_lm.py   +38   -0
server/text_generation/models/model.py       +129  -9
server/text_generation/models/types.py       +165  -6
server/text_generation/server.py             +2    -2
server/text_generation/utils.py              +14   -4
README.md

```diff
@@ -18,6 +18,7 @@ A Rust and gRPC server for text generation inference.
 ## Supported models
 - BLOOM
 - BLOOMZ
+- BLOOM-560m

 ## Load Tests for BLOOM
```
proto/generate.proto

```diff
@@ -63,6 +63,8 @@ message GeneratedText {
     Request request = 1;
     /// Output
     string output = 2;
+    /// Number of generated tokens
+    uint32 tokens = 3;
 }

 message GenerateRequest {
```
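For orientation, here is a minimal sketch of how the new field surfaces on the Python side, assuming the gRPC stubs generated from this proto are importable as `generate_pb2` (the import path and the example values are illustrative, not part of this commit):

```python
# Hypothetical usage sketch; assumes `generate_pb2` was generated from proto/generate.proto.
from text_generation.pb import generate_pb2  # import path assumed

# `output` (field 2) existed before; `tokens` (field 3) is the new counter.
generated = generate_pb2.GeneratedText(output="Hello world", tokens=2)
print(generated.output, generated.tokens)
```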
router/src/batcher.rs

```diff
@@ -190,6 +190,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
             .expect("ID not found in db. This is a bug.");

         let response = InferResponse {
             output: output.output,
+            tokens: output.tokens,
             queued: entry.time,
             start: entry.batch_time.unwrap(), // unwrap is always valid
             end: Instant::now(),
@@ -202,6 +203,7 @@ fn send_generated(finished: Vec<GeneratedText>, db: &Db) {
 #[derive(Debug)]
 pub(crate) struct InferResponse {
     pub(crate) output: String,
+    pub(crate) tokens: u32,
     pub(crate) queued: Instant,
     pub(crate) start: Instant,
     pub(crate) end: Instant,
```
router/src/server.rs

```diff
@@ -116,7 +116,7 @@ async fn generate(
     let validation_time = response.queued - start_time;
     let queue_time = response.start - response.queued;
     let inference_time = response.end - response.start;
-    let time_per_token = inference_time / req.parameters.max_new_tokens;
+    let time_per_token = inference_time / response.tokens;

     // Headers
     let mut headers = HeaderMap::new();
```
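As an aside, a minimal sketch of the arithmetic this one-line change affects (the durations and counts below are made up for illustration): dividing by the number of tokens actually generated, rather than by the requested `max_new_tokens`, gives the true per-token latency when generation stops early.

```python
from datetime import timedelta

# Illustrative values only, not taken from the commit.
inference_time = timedelta(milliseconds=900)
max_new_tokens = 20   # what the request asked for
tokens = 9            # what was actually generated (e.g. EOS reached early)

print(inference_time / max_new_tokens)  # 0:00:00.045000 -- underestimates latency
print(inference_time / tokens)          # 0:00:00.100000 -- matches reality
```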
server/text_generation/models/__init__.py

```diff
 from text_generation.models.model import Model
-from text_generation.models.bloom import BLOOM, BLOOMSharded
+from text_generation.models.bloom import BLOOMSharded
+from text_generation.models.causal_lm import CausalLM

-__all__ = ["Model", "BLOOM", "BLOOMSharded"]
+__all__ = ["Model", "BLOOMSharded", "CausalLM"]


 def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
@@ -11,6 +12,10 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
         else:
             if quantize:
                 raise ValueError("quantization is not supported for non-sharded BLOOM")
-            return BLOOM(model_name)
+            return CausalLM(model_name)
     else:
-        raise ValueError(f"model {model_name} is not supported yet")
+        if sharded:
+            raise ValueError("sharded is not supported for AutoModel")
+        if quantize:
+            raise ValueError("quantize is not supported for AutoModel")
+        return CausalLM(model_name)
```
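A quick sketch of the new dispatch behaviour; "gpt2" is only an example checkpoint name, not something this commit references:

```python
# Hypothetical usage sketch of get_model after this change.
from text_generation.models import get_model

# Any non-BLOOM causal LM now falls through to the generic CausalLM wrapper,
# provided neither sharding nor quantization is requested.
model = get_model("gpt2", sharded=False, quantize=False)
print(type(model).__name__)  # expected: CausalLM
```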
server/text_generation/models/bloom.py

```diff
 import torch
 import torch.distributed

-from typing import List, Optional, Tuple, Type
+from typing import List, Optional

 from accelerate import init_empty_weights
 from safetensors import safe_open
@@ -11,10 +11,8 @@ from transformers.models.bloom.parallel_layers import (
     TensorParallelEmbedding,
     TensorParallelRowLinear,
 )
-from transformers.modeling_outputs import CausalLMOutputWithPast

 from text_generation.models import Model
-from text_generation.models.types import Batch, GeneratedText
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
@@ -31,322 +29,26 @@ except Exception as e:
     torch.manual_seed(0)


-class BloomBatch(Batch):
-    @classmethod
-    def concatenate(cls, batches: List["Batch"]) -> "BloomBatch":
-        # Used for padding
-        total_batch_size = sum(batch.size for batch in batches)
-        max_sequence_length = max(batch.max_sequence_length for batch in batches)
-
-        # Batch attributes
-        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
-        requests = []
-        all_input_lengths = []
-        all_input_ids = []
-        next_token_choosers = []
-        stopping_criterias = []
-
-        # Used for slicing correctly inside the tensors
-        # Equivalent to a cumsum on batch sizes
-        start_index = 0
-        for i, batch in enumerate(batches):
-            requests.extend(batch.requests)
-            all_input_lengths.extend(batch.all_input_lengths)
-            all_input_ids.extend(batch.all_input_ids)
-            next_token_choosers.extend(batch.next_token_choosers)
-            stopping_criterias.extend(batch.stopping_criterias)
-
-            # Slicing end index for this batch
-            end_index = start_index + batch.size
-
-            # We only concatenate batches that did at least one step
-            if batch.input_ids["input_ids"].shape[1] > 1:
-                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
-
-            # Initialize tensors
-            if i == 0:
-                input_ids["input_ids"] = torch.empty(
-                    (total_batch_size, 1),
-                    dtype=batch.input_ids["input_ids"].dtype,
-                    device=batch.input_ids["input_ids"].device,
-                )
-                input_ids["attention_mask"] = torch.zeros(
-                    (total_batch_size, max_sequence_length),
-                    dtype=batch.input_ids["attention_mask"].dtype,
-                    device=batch.input_ids["attention_mask"].device,
-                )
-
-            # input_ids["input_ids"] is always of shape [batch_size, 1]
-            # We do not need to pad it
-            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
-
-            # We need to slice the attention mask to remove padding from previous steps
-            input_ids["attention_mask"][
-                start_index:end_index, -batch.max_sequence_length :
-            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
-
-            for j, past in enumerate(batch.input_ids["past_key_values"]):
-                past_keys = past[0]
-                past_values = past[1]
-
-                _, head_dim, padded_sequence_length = past_keys.shape
-
-                # Reshape the tensors to make slicing easier
-                past_keys = past_keys.view(
-                    batch.size, -1, head_dim, padded_sequence_length
-                )
-                past_values = past_values.view(
-                    batch.size, -1, padded_sequence_length, head_dim
-                )
-                num_heads = past_keys.shape[1]
-
-                # Initialize tensors
-                # This will run only once per layer
-                if j == len(input_ids["past_key_values"]):
-                    padded_past_keys = torch.zeros(
-                        (
-                            total_batch_size,
-                            num_heads,
-                            head_dim,
-                            max_sequence_length - 1,
-                        ),
-                        dtype=past_keys.dtype,
-                        device=past_keys.device,
-                    )
-                    padded_past_values = torch.zeros(
-                        (
-                            total_batch_size,
-                            num_heads,
-                            max_sequence_length - 1,
-                            head_dim,
-                        ),
-                        dtype=past_values.dtype,
-                        device=past_values.device,
-                    )
-                    input_ids["past_key_values"].append(
-                        [padded_past_keys, padded_past_values]
-                    )
-
-                # We slice the past keys and values to remove the padding from previous batches
-                input_ids["past_key_values"][j][0][
-                    start_index:end_index, :, :, -(batch.max_sequence_length - 1):
-                ] = past_keys[:, :, :, -(batch.max_sequence_length - 1):]
-                input_ids["past_key_values"][j][1][
-                    start_index:end_index, :, -(batch.max_sequence_length - 1):, :
-                ] = past_values[:, :, -(batch.max_sequence_length - 1):, :]
-
-                # If we are on the last batch, we need to reshape the tensors
-                if (i + 1) == len(batches):
-                    input_ids["past_key_values"][j][0] = input_ids["past_key_values"][
-                        j
-                    ][0].view(total_batch_size * num_heads, head_dim, -1)
-                    input_ids["past_key_values"][j][1] = input_ids["past_key_values"][
-                        j
-                    ][1].view(total_batch_size * num_heads, -1, head_dim)
-
-            start_index += batch.size
-
-        return cls(
-            batch_id=batches[0].batch_id,
-            requests=requests,
-            all_input_lengths=all_input_lengths,
-            input_ids=input_ids,
-            all_input_ids=all_input_ids,
-            next_token_choosers=next_token_choosers,
-            stopping_criterias=stopping_criterias,
-            size=total_batch_size,
-            max_sequence_length=max_sequence_length,
-        )
-
-
-class BLOOM(Model):
-    def __init__(self, model_name: str):
-        if not model_name.startswith("bigscience/bloom"):
-            raise ValueError(f"Model {model_name} is not supported")
-
-        if torch.cuda.is_available():
-            self.device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-        else:
-            self.device = torch.device("cpu")
-            dtype = torch.float32
-
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
-        self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=dtype,
-            device_map="auto" if torch.cuda.is_available() else None,
-        ).eval()
-        self.num_heads = self.model.config.num_attention_heads
-
-    @property
-    def batch_type(self) -> Type[BloomBatch]:
-        return BloomBatch
-
-    def forward(
-        self, input_ids, attention_mask, past_key_values: Optional = None
-    ) -> CausalLMOutputWithPast:
-        # Model Forward
-        return self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-
-    def generate_token(
-        self, batch: BloomBatch
-    ) -> Tuple[List[GeneratedText], Optional[BloomBatch]]:
-        # For some reason, inference_mode does not work well with GLOO which we use on CPU
-        context_manager = (
-            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
-        )
-        with context_manager():
-            outputs = self.forward(**batch.input_ids)
-
-        # List of indices to cache
-        next_batch_keep_indices = []
-        next_batch_past_keep_indices = []
-
-        # New input_ids for next forward
-        next_batch_input_ids = []
-        next_batch_all_input_ids = []
-        next_all_input_lengths = []
-
-        next_batch_size = 0
-        next_batch_max_sequence_length = 0
-
-        # Finished requests
-        generated_texts: List[GeneratedText] = []
-
-        # Zipped iterator
-        iterator = zip(
-            batch.requests,
-            batch.all_input_lengths,
-            outputs.logits,
-            batch.next_token_choosers,
-            batch.stopping_criterias,
-            batch.all_input_ids,
-        )
-
-        # For each member of the batch
-        for i, (
-            request,
-            input_length,
-            logits,
-            next_token_chooser,
-            stopping_criteria,
-            all_tokens,
-        ) in enumerate(iterator):
-            # Select next token
-            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
-
-            # Append next token to all tokens
-            all_tokens = torch.cat([all_tokens, next_token])
-
-            # Evaluate stopping criteria
-            if stopping_criteria(all_tokens):
-                # Decode all tokens
-                output = self.tokenizer.decode(
-                    all_tokens.squeeze(-1), skip_special_tokens=True
-                )
-                # Add to the list of finished generations with the original request
-                generated_texts.append(GeneratedText(request, output))
-            # add to the next batch
-            else:
-                next_batch_keep_indices.append(i)
-                # past_key_values is of shape [batch_size * num_heads, ...]
-                # so we need to take into account the `num_heads` stride here
-                next_batch_past_keep_indices.extend(
-                    [j for j in range(i * self.num_heads, (i + 1) * self.num_heads)]
-                )
-                next_batch_input_ids.append(next_token)
-                next_batch_all_input_ids.append(all_tokens)
-                next_batch_size += 1
-                new_input_length = input_length + 1
-                next_all_input_lengths.append(new_input_length)
-                next_batch_max_sequence_length = max(
-                    next_batch_max_sequence_length, new_input_length
-                )
-
-        # We finished all generations in the batch; there is no next batch
-        if not next_batch_keep_indices:
-            return generated_texts, None
-
-        # If we finished at least one generation
-        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
-        if generated_texts:
-            # Apply indices to attention mask, past key values and other items that need to be cached
-            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
-                next_batch_keep_indices
-            ]
-            next_batch_input_ids["past_key_values"] = [
-                (
-                    keys[next_batch_past_keep_indices],
-                    values[next_batch_past_keep_indices],
-                )
-                for keys, values in outputs["past_key_values"]
-            ]
-            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
-            next_batch_next_token_choosers = [
-                batch.next_token_choosers[i] for i in next_batch_keep_indices
-            ]
-            next_batch_stopping_criterias = [
-                batch.stopping_criterias[i] for i in next_batch_keep_indices
-            ]
-        else:
-            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
-            next_batch_input_ids["past_key_values"] = outputs["past_key_values"]
-            next_batch_requests = batch.requests
-            next_batch_next_token_choosers = batch.next_token_choosers
-            next_batch_stopping_criterias = batch.stopping_criterias
-
-        # Update attention_mask with padding as we added a new token to input_ids
-        next_batch_input_ids["attention_mask"] = torch.cat(
-            [
-                next_batch_input_ids["attention_mask"],
-                torch.ones((next_batch_size, 1)).to(self.device),
-            ],
-            dim=1,
-        )
-
-        next_batch = BloomBatch(
-            batch_id=batch.batch_id,
-            requests=next_batch_requests,
-            all_input_lengths=next_all_input_lengths,
-            input_ids=next_batch_input_ids,
-            all_input_ids=next_batch_all_input_ids,
-            next_token_choosers=next_batch_next_token_choosers,
-            stopping_criterias=next_batch_stopping_criterias,
-            size=next_batch_size,
-            max_sequence_length=next_batch_max_sequence_length,
-        )
-        return generated_texts, next_batch
-
-
-class BLOOMSharded(BLOOM):
+class BLOOMSharded(Model):
     def __init__(self, model_name: str, quantize: bool = False):
-        super(Model, self).__init__()
         if not model_name.startswith("bigscience/bloom"):
            raise ValueError(f"Model {model_name} is not supported")

         self.process_group, self.rank, self.world_size = initialize_torch_distributed()
         self.master = self.rank == 0
         if torch.cuda.is_available():
-            self.device = torch.device(f"cuda:{self.rank}")
+            device = torch.device(f"cuda:{self.rank}")
             dtype = torch.float16
         else:
-            self.device = torch.device("cpu")
+            device = torch.device("cpu")
             dtype = torch.float32

-        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

         config = AutoConfig.from_pretrained(
             model_name, slow_but_exact=False, tp_parallel=True
         )
         config.pad_token_id = 3
-        self.num_heads = config.n_head // self.process_group.size()

         # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
         # in PyTorch 1.12 and later.
@@ -370,12 +72,14 @@ class BLOOMSharded(BLOOM):
             model,
             filenames,
             quantize=quantize,
-            device=self.device,
+            device=device,
             rank=self.rank,
             world_size=self.world_size,
         )
         self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
+        super(BLOOMSharded, self).__init__(
+            tokenizer=tokenizer, num_heads=config.n_head // self.process_group.size(), device=device
+        )

     @staticmethod
     def load_weights(
@@ -526,5 +230,4 @@ class BLOOMSharded(BLOOM):
         torch.distributed.all_gather(logits, logits_shard, group=self.process_group)
         logits = torch.cat(logits, dim=1).view(batch_size, 1, vocab_size)

-        outputs.logits = logits
-        return outputs
+        return logits, outputs.past_key_values
```
server/text_generation/models/causal_lm.py (new file, mode 100644)

```python
import torch

from transformers import AutoTokenizer, AutoModelForCausalLM
from typing import Optional, Tuple, List

from text_generation.models import Model


class CausalLM(Model):
    def __init__(self, model_name: str):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto" if torch.cuda.is_available() else None,
        ).eval()

        super(CausalLM, self).__init__(
            tokenizer=tokenizer,
            num_heads=self.model.config.num_attention_heads,
            device=device,
        )

    def forward(
        self, input_ids, attention_mask, past_key_values: Optional = None
    ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
        # Model Forward
        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=True,
        )
        return outputs.logits, outputs.past_key_values
```
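A small usage sketch of the wrapper above; the checkpoint name and prompt are illustrative only and not taken from this commit:

```python
# Hypothetical usage sketch; "gpt2" is an example checkpoint.
from text_generation.models.causal_lm import CausalLM

model = CausalLM("gpt2")
encoded = model.tokenizer(["Hello"], return_tensors="pt").to(model.device)
logits, past = model.forward(encoded["input_ids"], encoded["attention_mask"])
print(logits.shape)  # (batch, sequence_length, vocab_size)
print(len(past))     # one (key, value) pair per transformer layer
```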
server/text_generation/models/model.py

```diff
+import torch
+
 from abc import ABC, abstractmethod
-from typing import List, Tuple, Optional, TypeVar, Type
+from typing import List, Tuple, Optional
+
+from tokenizers import Tokenizer

 from text_generation.models.types import Batch, GeneratedText

-B = TypeVar("B", bound=Batch)
-

 class Model(ABC):
-    @property
-    @abstractmethod
-    def batch_type(self) -> Type[B]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
-        raise NotImplementedError
+    def __init__(self, tokenizer: Tokenizer, num_heads: int, device: torch.device):
+        self.tokenizer = tokenizer
+        self.num_heads = num_heads
+        self.device = device
+
+    def forward(
+        self, input_ids, attention_mask, past_key_values: Optional = None
+    ) -> Tuple[torch.Tensor, List[Tuple]]:
+        raise NotImplementedError
+
+    def generate_token(
+        self, batch: Batch
+    ) -> Tuple[List[GeneratedText], Optional[Batch]]:
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        context_manager = (
+            torch.no_grad if self.device.type == "cpu" else torch.inference_mode
+        )
+        with context_manager():
+            logits, past = self.forward(**batch.input_ids)
+
+        # List of indices to cache
+        next_batch_keep_indices = []
+
+        # New input_ids for next forward
+        next_batch_input_ids = []
+        next_batch_all_input_ids = []
+        next_all_input_lengths = []
+
+        next_batch_size = 0
+        next_batch_max_sequence_length = 0
+
+        # Finished requests
+        generated_texts: List[GeneratedText] = []
+
+        # Zipped iterator
+        iterator = zip(
+            batch.requests,
+            batch.all_input_lengths,
+            logits,
+            batch.next_token_choosers,
+            batch.stopping_criterias,
+            batch.all_input_ids,
+        )
+
+        # For each member of the batch
+        for i, (
+            request,
+            input_length,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_tokens,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token = next_token_chooser(all_tokens, logits.unsqueeze(0)[:, -1])
+
+            # Append next token to all tokens
+            all_tokens = torch.cat([all_tokens, next_token])
+
+            # Evaluate stopping criteria
+            if stopping_criteria(all_tokens):
+                # Decode all tokens
+                output = self.tokenizer.decode(
+                    all_tokens.squeeze(-1), skip_special_tokens=True
+                )
+                # Add to the list of finished generations with the original request
+                generated_texts.append(
+                    GeneratedText(request, output, stopping_criteria.current_tokens)
+                )
+            # add to the next batch
+            else:
+                next_batch_keep_indices.append(i)
+                next_batch_input_ids.append(next_token)
+                next_batch_all_input_ids.append(all_tokens)
+                next_batch_size += 1
+                new_input_length = input_length + 1
+                next_all_input_lengths.append(new_input_length)
+                next_batch_max_sequence_length = max(
+                    next_batch_max_sequence_length, new_input_length
+                )
+
+        # We finished all generations in the batch; there is no next batch
+        if not next_batch_keep_indices:
+            return generated_texts, None
+
+        # If we finished at least one generation
+        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
+        if generated_texts:
+            # Apply indices to attention mask, past key values and other items that need to be cached
+            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
+                next_batch_keep_indices
+            ]
+            # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
+            next_batch_input_ids["past_key_values"] = [
+                [
+                    t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
+                    for t in layer
+                ]
+                for layer in past
+            ]
+            next_batch_requests = [batch.requests[i] for i in next_batch_keep_indices]
+            next_batch_next_token_choosers = [
+                batch.next_token_choosers[i] for i in next_batch_keep_indices
+            ]
+            next_batch_stopping_criterias = [
+                batch.stopping_criterias[i] for i in next_batch_keep_indices
+            ]
+        else:
+            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
+            next_batch_input_ids["past_key_values"] = past
+            next_batch_requests = batch.requests
+            next_batch_next_token_choosers = batch.next_token_choosers
+            next_batch_stopping_criterias = batch.stopping_criterias
+
+        # Update attention_mask with padding as we added a new token to input_ids
+        next_batch_input_ids["attention_mask"] = torch.cat(
+            [
+                next_batch_input_ids["attention_mask"],
+                torch.ones((next_batch_size, 1)).to(self.device),
+            ],
+            dim=1,
+        )
+
+        next_batch = Batch(
+            batch_id=batch.batch_id,
+            requests=next_batch_requests,
+            all_input_lengths=next_all_input_lengths,
+            input_ids=next_batch_input_ids,
+            all_input_ids=next_batch_all_input_ids,
+            next_token_choosers=next_batch_next_token_choosers,
+            stopping_criterias=next_batch_stopping_criterias,
+            size=next_batch_size,
+            max_sequence_length=next_batch_max_sequence_length,
+        )
+        return generated_texts, next_batch
```
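To make the caching step above concrete, here is a tiny standalone sketch of how `next_batch_keep_indices` filters the still-running sequences out of a left-padded attention mask; the shapes and values are made up for illustration and are not part of the commit:

```python
# Standalone illustration of the keep-indices filtering; values are arbitrary.
import torch

attention_mask = torch.tensor(
    [
        [0, 1, 1, 1],  # request 0 (still generating)
        [1, 1, 1, 1],  # request 1 (finished this step)
        [0, 0, 1, 1],  # request 2 (still generating)
    ]
)
next_batch_keep_indices = [0, 2]  # request 1 produced a GeneratedText and is dropped

kept = attention_mask[next_batch_keep_indices]
# A column of ones is then appended, mirroring the new token added to input_ids.
kept = torch.cat([kept, torch.ones((kept.shape[0], 1), dtype=kept.dtype)], dim=1)
print(kept)
```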
server/text_generation/models/types.py

```diff
 import torch

 from abc import abstractmethod
 from dataclasses import dataclass
 from typing import List, Dict
@@ -51,7 +50,11 @@ class Batch:
                     do_sample=r.parameters.do_sample,
                 )
             )
-            stopping_criterias.append(StoppingCriteria(max_new_tokens=r.max_new_tokens))
+            stopping_criterias.append(
+                StoppingCriteria(
+                    eos_token_id=tokenizer.eos_token_id, max_new_tokens=r.max_new_tokens
+                )
+            )

         input_ids = tokenizer(
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
@@ -71,15 +74,171 @@ class Batch:
         )

     @classmethod
-    @abstractmethod
     def concatenate(cls, batches: List["Batch"]) -> "Batch":
-        raise NotImplementedError
+        # Used for padding
+        total_batch_size = sum(batch.size for batch in batches)
+        max_sequence_length = max(batch.max_sequence_length for batch in batches)
+        # Only needed for Seq2SeqLM
+        max_encoded_sequence_length = None
+
+        # Batch attributes
+        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
+        requests = []
+        all_input_lengths = []
+        all_input_ids = []
+        next_token_choosers = []
+        stopping_criterias = []
+
+        # Used for slicing correctly inside the tensors
+        # Equivalent to a cumsum on batch sizes
+        start_index = 0
+        for i, batch in enumerate(batches):
+            requests.extend(batch.requests)
+            all_input_lengths.extend(batch.all_input_lengths)
+            all_input_ids.extend(batch.all_input_ids)
+            next_token_choosers.extend(batch.next_token_choosers)
+            stopping_criterias.extend(batch.stopping_criterias)
+
+            # Slicing end index for this batch
+            end_index = start_index + batch.size
+
+            # We only concatenate batches that did at least one step
+            if batch.input_ids["input_ids"].shape[1] > 1:
+                raise ValueError("Batch input_ids should be of shape (batch_size, 1)")
+
+            # Initialize tensors
+            if i == 0:
+                input_ids["input_ids"] = torch.empty(
+                    (total_batch_size, 1),
+                    dtype=batch.input_ids["input_ids"].dtype,
+                    device=batch.input_ids["input_ids"].device,
+                )
+                input_ids["attention_mask"] = torch.zeros(
+                    (total_batch_size, max_sequence_length),
+                    dtype=batch.input_ids["attention_mask"].dtype,
+                    device=batch.input_ids["attention_mask"].device,
+                )
+
+            # input_ids["input_ids"] is always of shape [batch_size, 1]
+            # We do not need to pad it
+            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
+
+            # We need to slice the attention mask to remove padding from previous steps
+            input_ids["attention_mask"][
+                start_index:end_index, -batch.max_sequence_length :
+            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
+
+            for j, past in enumerate(batch.input_ids["past_key_values"]):
+                # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
+                # BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
+                head_dim, padded_sequence_length = past[0].shape[-2:]
+                num_heads = (
+                    past[0]
+                    .view(batch.size, -1, head_dim, padded_sequence_length)
+                    .shape[1]
+                )
+
+                # This will run only once per layer
+                if j == len(input_ids["past_key_values"]):
+                    input_ids["past_key_values"].append([])
+
+                # Decoder past
+                for k, t in enumerate(past[:2]):
+                    # Needed because BLOOM past shapes are not the same for keys and values
+                    # Keys:   [batch_size * num_heads, head_dim, seq_length]
+                    # Values: [batch_size * num_heads, seq_length, head_dim]
+                    head_dim_last = False
+                    if t.shape[-2] == head_dim:
+                        t = t.view(
+                            batch.size, num_heads, head_dim, padded_sequence_length
+                        )
+                        padded_t_shape = (
+                            total_batch_size,
+                            num_heads,
+                            head_dim,
+                            max_sequence_length - 1,
+                        )
+                    elif t.shape[-1] == head_dim:
+                        head_dim_last = True
+                        t = t.view(
+                            batch.size, num_heads, padded_sequence_length, head_dim
+                        )
+                        padded_t_shape = (
+                            total_batch_size,
+                            num_heads,
+                            max_sequence_length - 1,
+                            head_dim,
+                        )
+                    else:
+                        raise ValueError(f"shape {t.shape} is not valid")
+
+                    # Initialize tensors
+                    # This will run only once per layer and per past tensor
+                    if k == len(input_ids["past_key_values"][j]):
+                        input_ids["past_key_values"][j].append(
+                            torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
+                        )
+
+                    # We slice the past keys and values to remove the padding from previous batches
+                    if not head_dim_last:
+                        input_ids["past_key_values"][j][k][
+                            start_index:end_index,
+                            :,
+                            :,
+                            -(batch.max_sequence_length - 1):,
+                        ] = t[:, :, :, -(batch.max_sequence_length - 1):]
+                    else:
+                        input_ids["past_key_values"][j][k][
+                            start_index:end_index,
+                            :,
+                            -(batch.max_sequence_length - 1):,
+                            :,
+                        ] = t[:, :, -(batch.max_sequence_length - 1):, :]
+
+                # Seq2SeqLM specific past (encoder past)
+                for k, t in enumerate(past[2:]):
+                    if max_encoded_sequence_length is None:
+                        max_encoded_sequence_length = max(
+                            max(batch.all_input_lengths) for batch in batches
+                        )
+                    batch_max_encoded_sequence_length = max(batch.all_input_lengths)
+
+                    padded_t_shape = (
+                        total_batch_size,
+                        num_heads,
+                        max_encoded_sequence_length,
+                        head_dim,
+                    )
+
+                    idx = k + 2
+
+                    # Initialize tensors
+                    # This will run only once per layer and per past tensor
+                    if idx == len(input_ids["past_key_values"][j]):
+                        input_ids["past_key_values"][j].append(
+                            torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
+                        )
+
+                    input_ids["past_key_values"][j][idx][
+                        start_index:end_index, :, -batch_max_encoded_sequence_length:, :
+                    ] = t[:, :, -batch_max_encoded_sequence_length:, :]
+
+            start_index += batch.size
+
+        return cls(
+            batch_id=batches[0].batch_id,
+            requests=requests,
+            all_input_lengths=all_input_lengths,
+            input_ids=input_ids,
+            all_input_ids=all_input_ids,
+            next_token_choosers=next_token_choosers,
+            stopping_criterias=stopping_criterias,
+            size=total_batch_size,
+            max_sequence_length=max_sequence_length,
+        )


 @dataclass
 class GeneratedText:
     request: generate_pb2.Request
     output: str
+    tokens: int

     def to_pb(self) -> generate_pb2.GeneratedText:
-        return generate_pb2.GeneratedText(request=self.request, output=self.output)
+        return generate_pb2.GeneratedText(
+            request=self.request, output=self.output, tokens=self.tokens
+        )
```
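The BLOOM-specific reshaping handled above can be illustrated in isolation; the sizes below are arbitrary and chosen only to show the two layouts the generic `concatenate` has to reconcile:

```python
# Standalone illustration of the two past_key_values layouts; sizes are arbitrary.
import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5

# BLOOM-style past keys:    [batch_size * num_heads, head_dim, seq_length]
bloom_keys = torch.zeros(batch_size * num_heads, head_dim, seq_length)
# Standard AutoModel past:  [batch_size, num_heads, seq_length, head_dim]
standard_keys = torch.zeros(batch_size, num_heads, seq_length, head_dim)

# The concatenate logic views everything as [batch_size, num_heads, ...] before slicing.
print(bloom_keys.view(batch_size, -1, head_dim, seq_length).shape)  # (2, 4, 8, 5)
print(standard_keys.shape)                                          # (2, 4, 5, 8)
```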
server/text_generation/server.py

```diff
@@ -27,7 +27,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.ClearCacheResponse()

     async def Generate(self, request, context):
-        batch = self.model.batch_type.from_pb(request.batch, self.model.tokenizer, self.model.device)
+        batch = Batch.from_pb(request.batch, self.model.tokenizer, self.model.device)
         generated_texts, next_batch = self.model.generate_token(batch)
         self.cache.set(next_batch)
@@ -51,7 +51,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batches.append(batch)

         if len(batches) > 1:
-            batch = self.model.batch_type.concatenate(batches)
+            batch = Batch.concatenate(batches)
         else:
             batch = batches[0]
```
server/text_generation/utils.py

```diff
@@ -58,7 +58,8 @@ class NextTokenChooser:
 class StoppingCriteria:
-    def __init__(self, max_new_tokens=20):
+    def __init__(self, eos_token_id, max_new_tokens=20):
+        self.eos_token_id = eos_token_id
         self.max_new_tokens = max_new_tokens
         self.current_tokens = 0
@@ -66,6 +67,8 @@ class StoppingCriteria:
         self.current_tokens += 1
         if self.current_tokens >= self.max_new_tokens:
             return True
+        if self.eos_token_id is not None and all_ids[-1] == self.eos_token_id:
+            return True
         return False
@@ -124,11 +127,18 @@ def download_weights(model_name, extension=".safetensors"):
     filenames = weight_hub_files(model_name, extension)
     download_function = partial(
-        hf_hub_download, repo_id=model_name, local_files_only=False
+        hf_hub_download,
+        repo_id=model_name,
+        local_files_only=False,
     )

     executor = ThreadPoolExecutor(max_workers=5)
-    futures = [executor.submit(download_function, filename=filename) for filename in filenames]
-    files = [file for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures))]
+    futures = [
+        executor.submit(download_function, filename=filename)
+        for filename in filenames
+    ]
+    files = [
+        file
+        for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
+    ]

     return files
```
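A small sketch of the updated stopping rule; the token ids and EOS id are example values, and the `__call__` signature is assumed from how `generate_token` invokes `stopping_criteria(all_tokens)`:

```python
# Hypothetical usage sketch of the extended StoppingCriteria.
import torch
from text_generation.utils import StoppingCriteria

criteria = StoppingCriteria(eos_token_id=2, max_new_tokens=4)
print(criteria(torch.tensor([[5], [7]])))       # False: 1 token generated, no EOS yet
print(criteria(torch.tensor([[5], [7], [2]])))  # True: the last generated id is the EOS id
```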