OpenDAS / text-generation-inference

Commit 4236e41b authored Nov 07, 2022 by OlivierDehaene

feat(server): Improved doc

parent cea6051e

Showing 9 changed files with 192 additions and 98 deletions (+192 -98)
Dockerfile                                    +1  -6
README.md                                     +7  -7
launcher/src/main.rs                          +4  -0
server/Makefile                               +2  -13
server/poetry.lock                            +46 -1
server/pyproject.toml                         +1  -0
server/text_generation/models/__init__.py     +4  -8
server/text_generation/models/causal_lm.py    +81 -53
server/text_generation/models/seq2seq_lm.py   +46 -10

Dockerfile
@@ -28,6 +28,7 @@ ENV LANG=C.UTF-8 \
     MODEL_NAME=bigscience/bloom \
     QUANTIZE=false \
     NUM_GPUS=8 \
+    SAFETENSORS_FAST_GPU=1 \
     CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
     NCCL_ASYNC_ERROR_HANDLING=1 \
     CUDA_HOME=/usr/local/cuda \
@@ -55,12 +56,6 @@ RUN cd server && make install-torch
 # Install specific version of transformers
 RUN cd server && make install-transformers

-# Install specific version of safetensors
-# FIXME: This is a temporary fix while we wait for a new release
-RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
-ENV PATH="/root/.cargo/bin:${PATH}"
-RUN cd server && make install-safetensors
-
 # Install server
 COPY proto proto
 COPY server server
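
Both Dockerfile changes serve the same goal: the source build of safetensors is dropped in favour of the pip release declared in server/pyproject.toml below, and SAFETENSORS_FAST_GPU=1 switches the library to its accelerated GPU weight-loading path. A minimal standalone sketch of the library calls involved (file name and tensors are invented for illustration, not taken from this repo):

```python
import os

# Same flag the Dockerfile (and the launcher, below) sets for safetensors' fast GPU load path.
os.environ["SAFETENSORS_FAST_GPU"] = "1"

import torch
from safetensors.torch import save_file, load_file

# Round-trip a tiny state dict through the safetensors format.
save_file({"weight": torch.zeros((2, 2))}, "example.safetensors")
state_dict = load_file("example.safetensors")  # dict: tensor name -> torch.Tensor
```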

README.md
@@ -6,7 +6,8 @@
 </div>

-A Rust and gRPC server for text generation inference.
+A Rust and gRPC server for text generation inference. Used in production at
+[HuggingFace](https://huggingface.co) to power Bloom, BloomZ and MT0-XXL api-inference widgets.

 ## Features
@@ -15,11 +16,11 @@ A Rust and gRPC server for text generation inference.
 - [Safetensors](https://github.com/huggingface/safetensors) weight loading
 - 45ms per token generation for BLOOM with 8xA100 80GB

-## Officialy supported models
+## Officially supported models

-- BLOOM
-- BLOOMZ
-- BLOOM-560m
+- [BLOOM](https://huggingface.co/bigscience/bloom)
+- [BLOOMZ](https://huggingface.co/bigscience/bloomz)
+- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)

 Other models are supported on a best effort basis using:
@@ -90,5 +91,4 @@ make router-dev
 ## TODO:

 - [ ] Add tests for the `server/model` logic
-- [ ] Backport custom CUDA kernels to Transformers
-- [ ] Install safetensors with pip
\ No newline at end of file
+- [ ] Backport custom CUDA kernels to Transformers
\ No newline at end of file

launcher/src/main.rs
@@ -295,6 +295,10 @@ fn shard_manager(
             "MASTER_PORT".parse().unwrap(),
             master_port.to_string().parse().unwrap(),
         ),
+        (
+            "SAFETENSORS_FAST_GPU".parse().unwrap(),
+            "1".to_string().parse().unwrap(),
+        ),
     ];

     // If the HUGGINGFACE_HUB_CACHE env var is set, pass it to the shard
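
The launcher spawns each Python shard as a subprocess with this environment list, so the tuple added above is simply inherited by the shard; an illustrative Python-side sketch (not code from the server):

```python
import os

# The shard process inherits SAFETENSORS_FAST_GPU=1 from the launcher's env list.
if os.environ.get("SAFETENSORS_FAST_GPU") == "1":
    print("fast GPU weight loading enabled")
```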

server/Makefile
@@ -16,24 +16,13 @@ install-transformers:
 	mv transformers-7302a24535e8dc5637ea5b4e4572fc971d404098 transformers
 	cd transformers && python setup.py install

-install-safetensors:
-	# Install specific version of safetensors
-	pip install setuptools_rust
-	rm safetensors || true
-	rm safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 || true
-	curl -L -O https://github.com/huggingface/safetensors/archive/634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
-	unzip 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
-	rm 634deccbcbad5eaf417935281f8b3be7ebca69c5.zip
-	mv safetensors-634deccbcbad5eaf417935281f8b3be7ebca69c5 safetensors
-	cd safetensors/bindings/python && python setup.py develop
-
 install-torch:
 	# Install specific version of torch
 	pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --no-cache-dir

-install: gen-server install-torch install-transformers install-safetensors
+install: gen-server install-torch install-transformers
 	pip install pip --upgrade
 	pip install -e . --no-cache-dir

 run-dev:
-	python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
\ No newline at end of file
+	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation/cli.py serve bigscience/bloom-560m --sharded
\ No newline at end of file

server/poetry.lock
@@ -145,6 +145,18 @@ category = "main"
 optional = false
 python-versions = ">=3.6"

+[[package]]
+name = "safetensors"
+version = "0.2.4"
+description = "Fast and Safe Tensor serialization"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.extras]
+dev = ["black (==22.3)", "flake8 (>=3.8.3)", "huggingface-hub", "isort (>=5.5.4)", "numpy", "pytest", "setuptools-rust"]
+testing = ["black (==22.3)", "flake8 (>=3.8.3)", "huggingface-hub", "isort (>=5.5.4)", "numpy", "pytest", "setuptools-rust"]
+
 [[package]]
 name = "setuptools"
 version = "65.5.0"
@@ -208,7 +220,7 @@ bnb = ["bitsandbytes"]
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.9"
-content-hash = "224b1e379d6105fe911bff4563946a90dfa6ff5918cf2e7be59f8d4f7c5cd7cf"
+content-hash = "3266187ef14fe8f9e29b3b6530d07781ea952aa670c0fe0de34be43efa231a67"

 [metadata.files]
 accelerate = [
@@ -459,6 +471,39 @@ PyYAML = [
 {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"},
 {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"},
 ]
+safetensors = [
{file = "safetensors-0.2.4-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:79c4a7610d7699c64d8531c43f758ded4990ebaa7b0887c2078640e6de44e726"},
{file = "safetensors-0.2.4-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:ef425a4ddd29612fe733a6eeca6ad8f3ee3939f530a032114974aac4c4667b89"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77758f8ba4de6e20bf394dd964854a926dee2efee82eaa95e6c0893e2a7d960c"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fb956e9090cce515649f00b491b5ddc0f9c3d989139016a8d69f9dcf57e8d3d9"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3e31b02d27249bd519f05ec9d189097c59fc6851c59daa1a86ef347659e33ac3"},
{file = "safetensors-0.2.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c2fead03a1497042efea4358574f3d7acf501b0c82e54d605f393f2b4e2aafe"},
{file = "safetensors-0.2.4-cp310-cp310-win32.whl", hash = "sha256:dce6ed3c7d13aafa574737eb3309c928adcb6781e879b41f0861be83b439cf3e"},
{file = "safetensors-0.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:1dfe727325a1342767c6725dc2cc1f00463eb40a1f5df37c338d8e03957e27ce"},
{file = "safetensors-0.2.4-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:c066bc7b90a582a01ec468fef61a7581b5c726bf12c50491cb6ea5db215ea5e0"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca6ed53dad5d7d0e67eb676528ff2ad345cac3a34010e4dc1e3736972de294a5"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ada03b44acbb036cfabe7066a8df4ad9b1ac05bb585a6b6c0f285f08e016381d"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:58a0902708daa7ec2b2293b46e85df61f4fa359ddfe648e7ac025a79e6f59627"},
{file = "safetensors-0.2.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e0a4e38f7cbb4bfc513588e52f349b906c941e74fbbe192f2b19fc34221d448"},
{file = "safetensors-0.2.4-cp37-cp37m-win32.whl", hash = "sha256:4f8695b77dd847203258f035f8468f8b701c90621cb6b457e109f8d89c27f16c"},
{file = "safetensors-0.2.4-cp37-cp37m-win_amd64.whl", hash = "sha256:16b08f33c753c7da64b3999beea7c30d58204a0820961e33881d05a331e3f5c0"},
{file = "safetensors-0.2.4-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:a381606804f23db9eede51135f5fbd1f75dda02100415ee150fd39eb1cd6be4c"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7aceae84d0c7233d83923029aaf8d184848561e0211ec98c5317327b3db025d6"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:da48fc929485cbd9ee22621e388764a7cef27b0205e73aee2ad75aadd7d67662"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2619b88f934c4de6b59de90c9dc00eae2d0e30f254a1daebd6eb232ac1f9a7a7"},
{file = "safetensors-0.2.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1f78b987ae1f6b71da8ea110164e4cab2ee31b53835d2a66279df89c5d73f0e"},
{file = "safetensors-0.2.4-cp38-cp38-win32.whl", hash = "sha256:34b3e60b5130fb0fe07114705e51d30aa2c7eae4c1d1e77d6f260fa4ade70ede"},
{file = "safetensors-0.2.4-cp38-cp38-win_amd64.whl", hash = "sha256:debaa4fa98a7af44ba6dcb6945efee77b8480284c2cb05918ab97cf511c40826"},
{file = "safetensors-0.2.4-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:90baaafc0c872a736124b341db54b0bdd61765cbf3a61418371066a37905b18d"},
{file = "safetensors-0.2.4-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b4bf7e23191d6a3ff00de141512869fc776e8ff159c872cb44af018cb04d45eb"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf11a3aba8796e548ceb0a65f34dcd334dcf0c4c891dccabe18a8b53918ae8ab"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:95c31935ea71d63a38c546654136d7f0dbf1e7aeb6564dbc2201bc1fe9b34e4c"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ef31776e2e081d6f075408eed34a0fbd524cbd19e50268bef02c238b209213b7"},
{file = "safetensors-0.2.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d06bb1d68148f6d6934352124d8cbfcf0db092f969db7187e348bd5cbf183db5"},
{file = "safetensors-0.2.4-cp39-cp39-win32.whl", hash = "sha256:5d546152b9a5bd58eae97c2ddefba394404d37ddedec305f7639c9b6054513e5"},
{file = "safetensors-0.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:553ecfd895d379c1e03a7c9241f7343b3af66573436969ed7eb95df81dfbe9af"},
{file = "safetensors-0.2.4.tar.gz", hash = "sha256:35c0719a898f1f1292464f4cd9370bb6c2698032f1db4d677489f078b66b5a75"},
]
setuptools = [
{file = "setuptools-65.5.0-py3-none-any.whl", hash = "sha256:f62ea9da9ed6289bfe868cd6845968a2c854d1427f8548d52cae02a42b4f0356"},
{file = "setuptools-65.5.0.tar.gz", hash = "sha256:512e5536220e38146176efb833d4a62aa726b7bbff82cfbc8ba9eaa3996e0b17"},

server/pyproject.toml
@@ -15,6 +15,7 @@ typer = "^0.6.1"
 grpcio-reflection = "^1.49.1"
 accelerate = "^0.12.0"
 bitsandbytes = "^0.35.1"
+safetensors = "^0.2.4"

 [tool.poetry.extras]
 bnb = ["bitsandbytes"]

server/text_generation/models/__init__.py
@@ -9,17 +9,13 @@ __all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
 def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
     if model_name.startswith("bigscience/bloom"):
         if sharded:
-            return BLOOMSharded(model_name, quantize)
+            return BLOOMSharded(model_name, quantize=quantize)
         else:
-            if quantize:
-                raise ValueError("quantization is not supported for non-sharded BLOOM")
-            return CausalLM(model_name)
+            return CausalLM(model_name, quantize=quantize)
     else:
         if sharded:
             raise ValueError("sharded is not supported for AutoModel")
-        if quantize:
-            raise ValueError("quantize is not supported for AutoModel")
         try:
-            return CausalLM(model_name)
+            return CausalLM(model_name, quantize=quantize)
         except Exception as e:
-            return Seq2SeqLM(model_name)
+            return Seq2SeqLM(model_name, quantize=quantize)
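
A short usage sketch of the dispatch above, with placeholder model names; the only behavioural change in this hunk is that `quantize` is now forwarded to every model constructor instead of being rejected for non-sharded models:

```python
from text_generation.models import get_model

# BLOOM-family names route to BLOOMSharded (sharded) or CausalLM, with quantize forwarded
model = get_model("bigscience/bloom", sharded=True, quantize=True)

# Any other name is tried as a causal LM first, then falls back to a seq2seq LM
model = get_model("bigscience/mt0-xxl", sharded=False, quantize=False)
```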

server/text_generation/models/causal_lm.py
@@ -2,7 +2,7 @@ import torch
 from dataclasses import dataclass
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from typing import Optional, Tuple, List, Dict, Type
+from typing import Optional, Tuple, List, Type

 from text_generation.models import Model
 from text_generation.models.types import GeneratedText
@@ -14,11 +14,23 @@ from text_generation.utils import NextTokenChooser, StoppingCriteria
 class CausalLMBatch:
     batch_id: int
     requests: List[generate_pb2.Request]
-    all_input_lengths: List[int]
-    input_ids: Dict[str, torch.Tensor]
-    all_input_ids: List[torch.Tensor]
+
+    # Decoder values
+    input_ids: torch.Tensor
+    attention_mask: torch.Tensor
+    past_key_values: Optional[List[Tuple]]
+
+    # All tokens
+    all_input_ids: List[torch.Tensor]
+
+    # Lengths of all generations present in the batch
+    input_lengths: List[int]
+
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]
+
+    # Metadata used for padding
     size: int
     max_sequence_length: int
@@ -36,12 +48,12 @@ class CausalLMBatch:
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        all_input_lengths = []
+        input_lengths = []

         # Parse batch
         for r in pb.requests:
             inputs.append(r.inputs)
-            all_input_lengths.append(r.input_length)
+            input_lengths.append(r.input_length)
             next_token_choosers.append(
                 NextTokenChooser(
                     temperature=r.parameters.temperature,
@@ -56,21 +68,23 @@ class CausalLMBatch:
             )
         )

-        input_ids = tokenizer(
+        tokenized_inputs = tokenizer(
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
-        all_input_ids = input_ids["input_ids"].unsqueeze(-1)
+        all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
-            all_input_lengths=all_input_lengths,
-            input_ids=input_ids,
+            input_ids=tokenized_inputs["input_ids"],
+            attention_mask=tokenized_inputs["attention_mask"],
+            past_key_values=None,
             all_input_ids=all_input_ids,
+            input_lengths=input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max(all_input_lengths),
+            max_sequence_length=max(input_lengths),
         )

     @classmethod
@@ -80,19 +94,23 @@ class CausalLMBatch:
         max_sequence_length = max(batch.max_sequence_length for batch in batches)

         # Batch attributes
-        input_ids = {"input_ids": None, "attention_mask": None, "past_key_values": []}
         requests = []
-        all_input_lengths = []
+        input_lengths = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []

+        # Batch tensors
+        input_ids = None
+        attention_mask = None
+        past_key_values = []
+
         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes
         start_index = 0
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
-            all_input_lengths.extend(batch.all_input_lengths)
+            input_lengths.extend(batch.input_lengths)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -101,32 +119,35 @@ class CausalLMBatch:
             end_index = start_index + batch.size

             # We only concatenate batches that did at least one step
-            if batch.input_ids["input_ids"].shape[1] > 1:
+            if batch.input_ids.shape[1] > 1:
                 raise ValueError("Batch input_ids should be of shape (batch_size, 1)")

-            # Initialize tensors
-            if i == 0:
-                input_ids["input_ids"] = torch.empty(
+            # Create empty tensor
+            # input_ids is always of shape [batch_size, 1]
+            # We do not need to pad it
+            if input_ids is None:
+                input_ids = torch.empty(
                     (total_batch_size, 1),
-                    dtype=batch.input_ids["input_ids"].dtype,
-                    device=batch.input_ids["input_ids"].device,
+                    dtype=batch.input_ids.dtype,
+                    device=batch.input_ids.device,
                 )
-                input_ids["attention_mask"] = torch.zeros(
+            # Copy to correct indices
+            input_ids[start_index:end_index] = batch.input_ids
+
+            # Create padded tensor
+            if attention_mask is None:
+                attention_mask = torch.zeros(
                     (total_batch_size, max_sequence_length),
-                    dtype=batch.input_ids["attention_mask"].dtype,
-                    device=batch.input_ids["attention_mask"].device,
+                    dtype=batch.attention_mask.dtype,
+                    device=batch.attention_mask.device,
                 )

-            # input_ids["input_ids"] is always of shape [batch_size, 1]
-            # We do not need to pad it
-            input_ids["input_ids"][start_index:end_index] = batch.input_ids["input_ids"]
-
             # We need to slice the attention mask to remove padding from previous steps
-            input_ids["attention_mask"][
+            attention_mask[
                 start_index:end_index, -batch.max_sequence_length :
-            ] = batch.input_ids["attention_mask"][:, -batch.max_sequence_length :]
+            ] = batch.attention_mask[:, -batch.max_sequence_length :]

-            for j, past in enumerate(batch.input_ids["past_key_values"]):
+            for j, past in enumerate(batch.past_key_values):
                 # Shenanigans to get dimensions because BLOOM outputs a past with a different shape
                 # BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
                 head_dim, padded_sequence_length = past[0].shape[-2:]
@@ -137,8 +158,8 @@ class CausalLMBatch:
             )

             # This will run only once per layer
-            if j == len(input_ids["past_key_values"]):
-                input_ids["past_key_values"].append([])
+            if j == len(past_key_values):
+                past_key_values.append([])

             # Decoder past
             for k, t in enumerate(past):
@@ -172,21 +193,21 @@
                 # Initialize tensors
                 # This will run only once per layer and per past tensor
-                if k == len(input_ids["past_key_values"][j]):
-                    input_ids["past_key_values"][j].append(
+                if k == len(past_key_values[j]):
+                    past_key_values[j].append(
                         torch.zeros(padded_t_shape, dtype=t.dtype, device=t.device)
                     )

                 # We slice the past keys and values to remove the padding from previous batches
                 if not head_dim_last:
-                    input_ids["past_key_values"][j][k][
+                    past_key_values[j][k][
                         start_index:end_index,
                         :,
                         :,
                         -(batch.max_sequence_length - 1) :,
                     ] = t[:, :, :, -(batch.max_sequence_length - 1) :]
                 else:
-                    input_ids["past_key_values"][j][k][
+                    past_key_values[j][k][
                         start_index:end_index,
                         :,
                         -(batch.max_sequence_length - 1) :,
@@ -198,9 +219,11 @@ class CausalLMBatch:
         return cls(
             batch_id=batches[0].batch_id,
             requests=requests,
-            all_input_lengths=all_input_lengths,
             input_ids=input_ids,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
             all_input_ids=all_input_ids,
+            input_lengths=input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
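
The concatenate refactor above keeps the same padding scheme but works on named tensors instead of a dict. A toy, self-contained sketch of the right-aligned copy it performs on attention masks (shapes invented for illustration, not the real class):

```python
import torch

total_batch_size = 3      # 2 sequences from batch A + 1 from batch B
max_sequence_length = 5   # longest sequence across both batches

attention_mask = torch.zeros((total_batch_size, max_sequence_length))

# Batch A: two left-padded sequences whose longest length is 3
a_attention_mask = torch.tensor([[0.0, 1.0, 1.0], [1.0, 1.0, 1.0]])
attention_mask[0:2, -3:] = a_attention_mask[:, -3:]

# Batch B: one sequence of length 5
b_attention_mask = torch.ones((1, 5))
attention_mask[2:3, -5:] = b_attention_mask[:, -5:]

# Mirrors `attention_mask[start_index:end_index, -batch.max_sequence_length:] = ...` above
print(attention_mask)
```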

@@ -209,7 +232,7 @@ class CausalLMBatch:
 class CausalLM(Model):
-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -223,6 +246,7 @@ class CausalLM(Model):
             model_name,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
+            load_in_8bit=quantize,
         ).eval()
         super(CausalLM, self).__init__(
@@ -255,16 +279,19 @@
             torch.no_grad if self.device.type == "cpu" else torch.inference_mode
         )
         with context_manager():
-            logits, past = self.forward(**batch.input_ids)
+            logits, past = self.forward(
+                batch.input_ids, batch.attention_mask, batch.past_key_values
+            )

         # List of indices to cache
         next_batch_keep_indices = []

-        # New input_ids for next forward
+        # New values for next forward
+        next_batch_input_lengths = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []
-        next_all_input_lengths = []

         # Metadata
         next_batch_size = 0
         next_batch_max_sequence_length = 0
@@ -274,7 +301,7 @@ class CausalLM(Model):
         # Zipped iterator
         iterator = zip(
             batch.requests,
-            batch.all_input_lengths,
+            batch.input_lengths,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -313,7 +340,7 @@ class CausalLM(Model):
                 next_batch_all_input_ids.append(all_tokens)
                 next_batch_size += 1
                 new_input_length = input_length + 1
-                next_all_input_lengths.append(new_input_length)
+                next_batch_input_lengths.append(new_input_length)
                 next_batch_max_sequence_length = max(
                     next_batch_max_sequence_length, new_input_length
                 )
@@ -322,15 +349,14 @@ class CausalLM(Model):
         if not next_batch_keep_indices:
             return generated_texts, None

-        # If we finished at least one generation
-        next_batch_input_ids = {"input_ids": torch.cat(next_batch_input_ids, dim=0)}
+        next_batch_input_ids = torch.cat(next_batch_input_ids, dim=0)
+        # If we finished at least one generation, we need to evict the indices of the generations that finished
+        # from the values of the next batch
         if generated_texts:
             # Apply indices to attention mask, past key values and other items that need to be cached
-            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"][
-                next_batch_keep_indices
-            ]
+            next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
             # Force past to be of dim [batch_size, num_heads, ...] for easy indexing
-            next_batch_input_ids["past_key_values"] = [
+            next_batch_past_key_values = [
                 [
                     t.view(-1, self.num_heads, *t.shape[-2:])[next_batch_keep_indices]
                     for t in layer
@@ -345,16 +371,16 @@ class CausalLM(Model):
                 batch.stopping_criterias[i] for i in next_batch_keep_indices
             ]
         else:
-            next_batch_input_ids["attention_mask"] = batch.input_ids["attention_mask"]
-            next_batch_input_ids["past_key_values"] = past
+            next_batch_attention_mask = batch.attention_mask
+            next_batch_past_key_values = past
             next_batch_requests = batch.requests
             next_batch_next_token_choosers = batch.next_token_choosers
             next_batch_stopping_criterias = batch.stopping_criterias

         # Update attention_mask with padding as we added a new token to input_ids
-        next_batch_input_ids["attention_mask"] = torch.cat(
+        next_batch_attention_mask = torch.cat(
             [
-                next_batch_input_ids["attention_mask"],
+                next_batch_attention_mask,
                 torch.ones((next_batch_size, 1)).to(self.device),
             ],
             dim=1,
@@ -363,9 +389,11 @@ class CausalLM(Model):
         next_batch = CausalLMBatch(
             batch_id=batch.batch_id,
             requests=next_batch_requests,
-            all_input_lengths=next_all_input_lengths,
             input_ids=next_batch_input_ids,
+            attention_mask=next_batch_attention_mask,
+            past_key_values=next_batch_past_key_values,
             all_input_ids=next_batch_all_input_ids,
+            input_lengths=next_batch_input_lengths,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
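
The generate_token changes above amount to passing the decoder state as explicit tensors (input_ids of shape [batch_size, 1], attention_mask, past_key_values) instead of a dict. A standalone sketch of that incremental-decoding pattern with a plain transformers model ("gpt2" is only a small example checkpoint, not one this server targets):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

inputs = tokenizer("Hello", return_tensors="pt")
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]

with torch.inference_mode():
    # Prefill: full prompt, no past
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)

    # Decode step: only the new token plus the cached past_key_values,
    # with the attention mask extended by one position
    attention_mask = torch.cat(
        [attention_mask, torch.ones((1, 1), dtype=attention_mask.dtype)], dim=1
    )
    outputs = model(
        input_ids=next_token,
        attention_mask=attention_mask,
        past_key_values=outputs.past_key_values,
        use_cache=True,
    )
```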

server/text_generation/models/seq2seq_lm.py
@@ -15,26 +15,33 @@ class Seq2SeqLMBatch:
     batch_id: int
     requests: List[generate_pb2.Request]

+    # Encoder values
     input_ids: torch.Tensor
     attention_mask: torch.Tensor

+    # Decoder values
     decoder_input_ids: torch.Tensor
     decoder_attention_mask: Optional[torch.Tensor]
     encoder_last_hidden_state: Optional[torch.Tensor]

+    # Seq2SeqLM keeps track of both encoder and decoder attention keys and values
     past_key_values: Optional[List[Tuple]]

+    # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]

+    # Generation helpers
     next_token_choosers: List[NextTokenChooser]
     stopping_criterias: List[StoppingCriteria]

+    # Metadata used for padding
     size: int
     max_input_length: int
     max_decoder_input_length: int

     def to_pb(self):
+        """Convert a Seq2SeqLMBatch to a text_generation.v1.Batch protobuf"""
         return generate_pb2.Batch(
             id=self.batch_id,
             requests=self.requests,
@@ -45,6 +52,7 @@ class Seq2SeqLMBatch:
     def from_pb(
         cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
     ) -> "Seq2SeqLMBatch":
+        """Convert a text_generation.v1.Batch protobuf to a Seq2SeqLMBatch"""
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
@@ -57,6 +65,7 @@ class Seq2SeqLMBatch:
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
+            # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
             next_token_choosers.append(
@@ -73,9 +82,11 @@ class Seq2SeqLMBatch:
             )
         )

+        # Tokenize batch
         tokenized_inputs = tokenizer(
             inputs, return_tensors="pt", padding=True, pad_to_multiple_of=8
         ).to(device)
+        # Convert decoder_input_ids to torch tensor of size [batch_size, 1]
         decoder_input_ids = torch.tensor(decoder_input_ids).to(device).unsqueeze(-1)

         return cls(
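
from_pb above seeds each decoder sequence with a single start token, giving decoder_input_ids the shape [batch_size, 1]; a toy sketch with placeholder values:

```python
import torch

bos_token_id = 0   # placeholder; the real value comes from the tokenizer / model config
batch_size = 3

# One start token per request, then [batch_size] -> [batch_size, 1]
decoder_input_ids = torch.tensor([bos_token_id] * batch_size).unsqueeze(-1)
assert decoder_input_ids.shape == (batch_size, 1)
```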

@@ -98,6 +109,8 @@ class Seq2SeqLMBatch:
     @classmethod
     def concatenate(cls, batches: List["Seq2SeqLMBatch"]) -> "Seq2SeqLMBatch":
+        """Concatenate multiple batches together by padding internal torch tensors"""
+        # Used for padding
         total_batch_size = sum(batch.size for batch in batches)
         max_input_length = max(batch.max_input_length for batch in batches)
@@ -112,6 +125,7 @@ class Seq2SeqLMBatch:
         next_token_choosers = []
         stopping_criterias = []

+        # Batch tensors
         input_ids = None
         attention_mask = None
         decoder_input_ids = None
@@ -122,7 +136,9 @@ class Seq2SeqLMBatch:
         # Used for slicing correctly inside the tensors
         # Equivalent to a cumsum on batch sizes
         start_index = 0
+
         for i, batch in enumerate(batches):
+            # Extend all list attributes
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
@@ -136,51 +152,62 @@ class Seq2SeqLMBatch:
             if batch.encoder_last_hidden_state is None:
                 raise ValueError("Batch encoder_last_hidden_state cannot be None")

+            # Create padded tensor
             if input_ids is None:
                 input_ids = torch.zeros(
                     (total_batch_size, max_input_length),
                     dtype=batch.input_ids.dtype,
                     device=batch.input_ids.device,
                 )
+            # Copy to correct indices
             input_ids[
                 start_index:end_index, -batch.max_input_length :
             ] = batch.input_ids[:, -batch.max_input_length :]

+            # Create padded tensor
             if attention_mask is None:
                 attention_mask = torch.zeros(
                     (total_batch_size, max_input_length),
                     dtype=batch.attention_mask.dtype,
                     device=batch.attention_mask.device,
                 )
+            # Copy to correct indices
             attention_mask[
                 start_index:end_index, -batch.max_input_length :
             ] = batch.attention_mask[:, -batch.max_input_length :]

+            # Create padded tensor
             if decoder_input_ids is None:
                 decoder_input_ids = torch.zeros(
                     (total_batch_size, max_decoder_input_length),
                     dtype=batch.decoder_input_ids.dtype,
                     device=batch.decoder_input_ids.device,
                 )
+            # Copy to correct indices
             decoder_input_ids[
                 start_index:end_index, -batch.max_decoder_input_length :
             ] = batch.decoder_input_ids[:, -batch.max_decoder_input_length :]

+            # Create padded tensor
             if decoder_attention_mask is None:
                 decoder_attention_mask = torch.zeros(
                     (total_batch_size, max_decoder_input_length),
-                    dtype=batch.attention_mask.dtype,
-                    device=batch.attention_mask.device,
+                    dtype=batch.attention_mask.dtype,  # As decoder_attention_mask might not exist,
+                    device=batch.attention_mask.device,  # we use `batch.attention_mask` for device here
                 )

+            # If the decoder mask does not exist yet, all generations started at the same time and we never concatenated
+            # this batch. All generations are of length `batch.max_decoder_input_length`.
             if batch.decoder_attention_mask is None:
                 decoder_attention_mask[
                     start_index:end_index, -batch.max_decoder_input_length :
                 ] = 1
+            # If it exists, we need to index
             else:
                 decoder_attention_mask[
                     start_index:end_index, -batch.max_decoder_input_length :
                 ] = batch.decoder_attention_mask[:, -batch.max_decoder_input_length :]

+            # Create padded tensor
             if encoder_last_hidden_state is None:
                 encoder_last_hidden_state = torch.zeros(
                     (
@@ -192,10 +219,12 @@ class Seq2SeqLMBatch:
                     device=batch.encoder_last_hidden_state.device,
                 )

+            # Copy to correct indices
             encoder_last_hidden_state[
                 start_index:end_index, -batch.max_decoder_input_length :, :
             ] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :]

+            # Iterate over attention layers
             for j, past in enumerate(batch.past_key_values):
                 _, num_heads, _, head_dim = past[0].shape

@@ -271,7 +300,7 @@
 class Seq2SeqLM(Model):
-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, quantize=False):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
@@ -283,6 +312,7 @@ class Seq2SeqLM(Model):
             model_name,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
+            load_in_8bit=quantize,
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
@@ -314,14 +344,17 @@ class Seq2SeqLM(Model):
         if past_key_values is not None:
             decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)

+        # Wrap `encoder_last_hidden_state` because for some reason, Transformers does a `encoder_last_hidden_state[0]`
+        # internally...
+        if encoder_last_hidden_state is not None:
+            encoder_last_hidden_state = [encoder_last_hidden_state]
+
         outputs = self.model.forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
-            encoder_outputs=[encoder_last_hidden_state]
-            if encoder_last_hidden_state is not None
-            else None,
+            encoder_outputs=encoder_last_hidden_state,
             past_key_values=past_key_values,
             use_cache=True,
         )
@@ -351,11 +384,12 @@ class Seq2SeqLM(Model):
         # List of indices to cache
         next_batch_keep_indices = []

-        # New input_ids for next forward
+        # New values for next forward
         next_batch_input_lengths = []
         next_batch_decoder_input_ids = []
         next_batch_decoder_input_lengths = []

+        # Metadata
         next_batch_size = 0
         next_batch_max_input_length = 0
         next_batch_max_decoder_input_length = 0
@@ -395,7 +429,7 @@ class Seq2SeqLM(Model):
             # Evaluate stopping criteria
             if stopping_criteria(decoder_tokens):
-                # Decode all tokens
+                # Decode tokens
                 output = self.tokenizer.decode(decoder_tokens, skip_special_tokens=True)
                 # Add to the list of finished generations with the original request
                 generated_texts.append(
@@ -420,9 +454,11 @@ class Seq2SeqLM(Model):
         if not next_batch_keep_indices:
             return generated_texts, None

-        # If we finished at least one generation
+        next_batch_decoder_input_ids = torch.cat(next_batch_decoder_input_ids)
+        # If we finished at least one generation, we need to evict the indices of the generations that finished
+        # from the values of the next batch
         if generated_texts:
             # Apply indices to attention mask, past key values and other items that need to be cached
             next_batch_input_ids = batch.input_ids[next_batch_keep_indices]
             next_batch_attention_mask = batch.attention_mask[next_batch_keep_indices]
@@ -458,7 +494,7 @@ class Seq2SeqLM(Model):
         next_batch_next_token_choosers = batch.next_token_choosers
         next_batch_stopping_criterias = batch.stopping_criterias

-        # Update attention_mask with padding as we added a new token to input_ids
+        # Update decoder_attention_mask with padding as we added a new token to input_ids
         if next_batch_decoder_attention_mask is not None:
             next_batch_decoder_attention_mask = torch.cat(
                 [
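
Seq2SeqLM.forward above wraps encoder_last_hidden_state in a list before handing it to Transformers and, once a cache exists, only feeds the most recent decoder token. A self-contained sketch of that loop with a generic seq2seq checkpoint ("t5-small" is only an example, not a model this commit touches):

```python
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").eval()

inputs = tokenizer("translate English to German: Hello", return_tensors="pt")
# Decoder sequence only contains the start token at first, shape [batch_size, 1]
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

with torch.inference_mode():
    # First step: run the encoder and the first decoder step together
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        decoder_input_ids=decoder_input_ids,
        use_cache=True,
    )
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)

    # Later steps: reuse the cached encoder state and past_key_values,
    # feeding only the newest decoder token (as in Seq2SeqLM.forward above)
    outputs = model(
        attention_mask=inputs["attention_mask"],
        decoder_input_ids=next_token,
        encoder_outputs=[outputs.encoder_last_hidden_state],
        past_key_values=outputs.past_key_values,
        use_cache=True,
    )
```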