OpenDAS / text-generation-inference · Commits · a2985036

Unverified commit a2985036, authored Dec 08, 2022 by OlivierDehaene, committed by GitHub on Dec 08, 2022

feat(server): Add model tests (#6)

parent 31d76e23

Showing 16 changed files with 1105 additions and 29 deletions (+1105 -29)
README.md                                    +1    -6
router/src/batcher.rs                        +1    -1
server/Makefile                              +3    -2
server/poetry.lock                           +98   -1
server/pyproject.toml                        +1    -0
server/tests/conftest.py                     +36   -0
server/tests/models/test_bloom.py            +279  -0
server/tests/models/test_causal_lm.py        +296  -0
server/tests/models/test_seq2seq_lm.py       +306  -0
server/tests/test_utils.py                   +34   -0
server/text_generation/models/__init__.py    +3    -3
server/text_generation/models/bloom.py       +22   -2
server/text_generation/models/causal_lm.py   +17   -9
server/text_generation/models/galactica.py   +1    -1
server/text_generation/models/seq2seq_lm.py  +5    -2
server/text_generation/utils.py              +2    -2
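The new test suite lives under server/tests and is driven by the pytest dev dependency added in server/pyproject.toml below. As a rough sketch (not part of this diff), it could be run from the repository root with a small wrapper like the one below; the path, flags, and the very existence of such a runner script are assumptions.

# Hypothetical convenience runner for the new test suite; not part of this commit.
# Assumes the server package and its dev dependencies are installed
# (for example via `poetry install` inside server/) and that the Hugging Face Hub
# is reachable, since the fixtures download bigscience/bloom-560m, gpt2 and
# bigscience/mt0-small.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["server/tests", "-v"]))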
README.md (view file @ a2985036)

@@ -88,8 +88,3 @@ curl 127.0.0.1:3000/generate \
 make server-dev
 make router-dev
-```
-
-## TODO:
-
-- [ ] Add tests for the `server/model` logic
-- [ ] Backport custom CUDA kernels to Transformers
\ No newline at end of file
+```
\ No newline at end of file
router/src/batcher.rs (view file @ a2985036)

...
@@ -70,7 +70,7 @@ impl Batcher {
         // Notify the background task that we have a new entry in the database that needs
         // to be batched
-        self.shared.batching_task.notify_waiters();
+        self.shared.batching_task.notify_one();
         // Await on the response from the background task
         // We can safely unwrap as the background task will never drop the sender
...
server/Makefile (view file @ a2985036)

...
@@ -8,8 +8,9 @@ gen-server:
 install-transformers:
 	# Install specific version of transformers with custom cuda kernels
-	rm transformers || true
-	rm transformers-text_generation_inference || true
+	pip uninstall transformers -y || true
+	rm -rf transformers || true
+	rm -rf transformers-text_generation_inference || true
 	curl -L -O https://github.com/OlivierDehaene/transformers/archive/refs/heads/text_generation_inference.zip
 	unzip text_generation_inference.zip
 	rm text_generation_inference.zip
...
server/poetry.lock (view file @ a2985036)

...
@@ -22,6 +22,20 @@ test_prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"]
test_trackers = ["comet-ml", "tensorboard", "wandb"]
testing = ["datasets", "deepspeed (<0.7.0)", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scipy", "sklearn", "tqdm", "transformers"]
[[package]]
name = "attrs"
version = "22.1.0"
description = "Classes Without Boilerplate"
category = "dev"
optional = false
python-versions = ">=3.5"
[package.extras]
dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"]
docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"]
tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"]
tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"]
[[package]]
name = "bitsandbytes"
version = "0.35.1"
...
...
@@ -49,6 +63,17 @@ category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
[[package]]
name = "exceptiongroup"
version = "1.0.4"
description = "Backport of PEP 654 (exception groups)"
category = "dev"
optional = false
python-versions = ">=3.7"
[package.extras]
test = ["pytest (>=6)"]
[[package]]
name = "grpcio"
version = "1.50.0"
...
...
@@ -88,6 +113,14 @@ grpcio = ">=1.50.0"
protobuf = ">=4.21.6,<5.0dev"
setuptools = "*"
[[package]]
name = "iniconfig"
version = "1.1.1"
description = "iniconfig: brain-dead simple config-ini parsing"
category = "dev"
optional = false
python-versions = "*"
[[package]]
name = "numpy"
version = "1.23.4"
...
...
@@ -107,6 +140,18 @@ python-versions = ">=3.6"
[package.dependencies]
pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
[[package]]
name = "pluggy"
version = "1.0.0"
description = "plugin and hook calling mechanisms for python"
category = "dev"
optional = false
python-versions = ">=3.6"
[package.extras]
dev = ["pre-commit", "tox"]
testing = ["pytest", "pytest-benchmark"]
[[package]]
name = "protobuf"
version = "4.21.8"
...
...
@@ -137,6 +182,26 @@ python-versions = ">=3.6.8"
[package.extras]
diagrams = ["jinja2", "railroad-diagrams"]
[[package]]
name = "pytest"
version = "7.2.0"
description = "pytest: simple powerful testing with Python"
category = "dev"
optional = false
python-versions = ">=3.7"
[package.dependencies]
attrs = ">=19.2.0"
colorama = {version = "*", markers = "sys_platform == \"win32\""}
exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""}
iniconfig = "*"
packaging = "*"
pluggy = ">=0.12,<2.0"
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
[package.extras]
testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "xmlschema"]
[[package]]
name = "PyYAML"
version = "6.0"
...
...
@@ -178,6 +243,14 @@ category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
[[package]]
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
category = "dev"
optional = false
python-versions = ">=3.7"
[[package]]
name = "torch"
version = "1.12.1"
...
...
@@ -220,13 +293,17 @@ bnb = ["bitsandbytes"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
-content-hash = "3266187ef14fe8f9e29b3b6530d07781ea952aa670c0fe0de34be43efa231a67"
+content-hash = "51693654531e3229ac64bee250932ace20a60e8d45af074ae7b860ed32b25ef8"
[metadata.files]
accelerate = [
{file = "accelerate-0.12.0-py3-none-any.whl", hash = "sha256:7742ca5c9f15dd1e0a283305599c196e260af4717a561d1f544aeab27d828af6"},
{file = "accelerate-0.12.0.tar.gz", hash = "sha256:e8b119c94fac31877620d5f9de311164ec81fa9dc9e175f0d0d4f50fc8d79473"},
]
attrs = [
{file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"},
{file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"},
]
bitsandbytes = [
{file = "bitsandbytes-0.35.1-py3-none-any.whl", hash = "sha256:4506a9e3778359a743938aa5592d8d043fa91d1df66cd01ba8cc6486e64dea45"},
{file = "bitsandbytes-0.35.1.tar.gz", hash = "sha256:63a6f59c87b713a731a685e43d68c19789ee6381e62196cafab293b87eca5d46"},
...
...
@@ -239,6 +316,10 @@ colorama = [
{file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"},
{file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"},
]
exceptiongroup = [
{file = "exceptiongroup-1.0.4-py3-none-any.whl", hash = "sha256:542adf9dea4055530d6e1279602fa5cb11dab2395fa650b8674eaec35fc4a828"},
{file = "exceptiongroup-1.0.4.tar.gz", hash = "sha256:bd14967b79cd9bdb54d97323216f8fdf533e278df937aa2a90089e7d6e06e5ec"},
]
grpcio = [
{file = "grpcio-1.50.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:906f4d1beb83b3496be91684c47a5d870ee628715227d5d7c54b04a8de802974"},
{file = "grpcio-1.50.0-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:2d9fd6e38b16c4d286a01e1776fdf6c7a4123d99ae8d6b3f0b4a03a34bf6ce45"},
...
...
@@ -337,6 +418,10 @@ grpcio-tools = [
{file = "grpcio_tools-1.50.0-cp39-cp39-win32.whl", hash = "sha256:e1a8f9a57bbcc2e633aaf327e39830527f3c1f7add18c7580f3058fe9a0fa780"},
{file = "grpcio_tools-1.50.0-cp39-cp39-win_amd64.whl", hash = "sha256:b7eb7a84d9171c0ae1550833f4a6ca52372bed9db0fa10f8c9dbe6ca65f97a8c"},
]
iniconfig = [
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
]
numpy = [
{file = "numpy-1.23.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:95d79ada05005f6f4f337d3bb9de8a7774f259341c70bc88047a1f7b96a4bcb2"},
{file = "numpy-1.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:926db372bc4ac1edf81cfb6c59e2a881606b409ddc0d0920b988174b2e2a767f"},
...
...
@@ -371,6 +456,10 @@ packaging = [
{file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"},
{file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
]
pluggy = [
{file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
{file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
]
protobuf = [
{file = "protobuf-4.21.8-cp310-abi3-win32.whl", hash = "sha256:c252c55ee15175aa1b21b7b9896e6add5162d066d5202e75c39f96136f08cce3"},
{file = "protobuf-4.21.8-cp310-abi3-win_amd64.whl", hash = "sha256:809ca0b225d3df42655a12f311dd0f4148a943c51f1ad63c38343e457492b689"},
...
...
@@ -429,6 +518,10 @@ pyparsing = [
{file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
{file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"},
]
pytest = [
{file = "pytest-7.2.0-py3-none-any.whl", hash = "sha256:892f933d339f068883b6fd5a459f03d85bfcb355e4981e146d2c7616c21fef71"},
{file = "pytest-7.2.0.tar.gz", hash = "sha256:c4014eb40e10f11f355ad4e3c2fb2c6c6d1919c73f3b5a433de4708202cade59"},
]
PyYAML = [
{file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"},
{file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"},
...
...
@@ -512,6 +605,10 @@ six = [
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]
tomli = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
]
torch = [
{file = "torch-1.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c038662db894a23e49e385df13d47b2a777ffd56d9bcd5b832593fab0a7e286"},
{file = "torch-1.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:4e1b9c14cf13fd2ab8d769529050629a0e68a6fc5cb8e84b4a3cc1dd8c4fe541"},
...
...
server/pyproject.toml (view file @ a2985036)

...
@@ -22,6 +22,7 @@ bnb = ["bitsandbytes"]

 [tool.poetry.group.dev.dependencies]
 grpcio-tools = "^1.49.1"
+pytest = "^7.2.0"

 [build-system]
 requires = ["poetry-core>=1.0.0"]
...
server/tests/conftest.py (new file, 0 → 100644, view file @ a2985036)

import pytest

from transformers import AutoTokenizer

from text_generation.pb import generate_pb2


@pytest.fixture
def default_pb_parameters():
    return generate_pb2.LogitsWarperParameters(
        temperature=1.0,
        top_k=0,
        top_p=1.0,
        do_sample=False,
    )


@pytest.fixture(scope="session")
def bloom_560m_tokenizer():
    return AutoTokenizer.from_pretrained("bigscience/bloom-560m", padding_side="left")


@pytest.fixture(scope="session")
def gpt2_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
    tokenizer.pad_token_id = 50256
    return tokenizer


@pytest.fixture(scope="session")
def mt0_small_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(
        "bigscience/mt0-small", padding_side="left"
    )
    tokenizer.bos_token_id = 0
    return tokenizer
server/tests/models/test_bloom.py (new file, 0 → 100644, view file @ a2985036)

import pytest
import torch

from copy import copy

from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLMBatch
from text_generation.models.bloom import BloomCausalLMBatch, BLOOM


@pytest.fixture
def default_pb_request(default_pb_parameters):
    return generate_pb2.Request(
        id=0,
        inputs="Test",
        input_length=1,
        parameters=default_pb_parameters,
        max_new_tokens=10,
    )


@pytest.fixture
def default_pb_batch(default_pb_request):
    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)


@pytest.fixture
def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
    return BloomCausalLMBatch.from_pb(
        default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
    )


@pytest.fixture
def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer):
    req_0 = copy(default_pb_request)
    req_1 = default_pb_request
    req_1.id = 1
    req_1.max_new_tokens = 5

    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
    return BloomCausalLMBatch.from_pb(
        batch_pb, bloom_560m_tokenizer, torch.device("cpu")
    )


@pytest.fixture(scope="session")
def default_bloom():
    return BLOOM("bigscience/bloom-560m")


def test_batch_from_pb(default_pb_batch, default_bloom_batch):
    batch = default_bloom_batch

    assert batch.batch_id == default_pb_batch.id
    assert batch.requests == default_pb_batch.requests

    assert len(batch.input_ids) == default_pb_batch.size
    assert len(batch.input_ids[0]) == 8
    assert batch.input_ids[0][-1] == 10264
    assert torch.all(batch.input_ids[0][:-1] == 3)

    assert batch.attention_mask[0][-1] == 1
    assert torch.all(batch.attention_mask[0][:-1] == 0)

    assert batch.past_key_values is None

    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])

    assert batch.input_lengths == [1]

    assert batch.size == default_pb_batch.size
    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size

    assert batch.max_sequence_length == batch.input_lengths[0]


def test_batch_concatenate_no_prefill(default_bloom_batch):
    with pytest.raises(ValueError):
        BloomCausalLMBatch.concatenate([default_bloom_batch, default_bloom_batch])


def test_causal_lm_batch_type(default_bloom):
    assert default_bloom.batch_type == BloomCausalLMBatch


def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
    generated_texts, next_batch = default_bloom.generate_token(default_bloom_batch)

    assert generated_texts == []
    assert isinstance(next_batch, CausalLMBatch)
    assert not next_batch.keys_head_dim_last

    assert len(next_batch.all_input_ids) == next_batch.size
    assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
    assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
    assert torch.all(next_batch.all_input_ids[0][:-2] == 3)

    assert torch.all(next_batch.attention_mask[0][-2:] == 1)
    assert torch.all(next_batch.attention_mask[0][:-2] == 0)

    assert next_batch.input_ids.shape == (next_batch.size, 1)
    assert next_batch.input_ids[0, 0] == 10264

    assert next_batch.input_lengths == [2]
    assert next_batch.max_sequence_length == next_batch.input_lengths[0]

    assert next_batch.past_key_values is not None
    assert all([p[0].shape == (16, 64, 8) for p in next_batch.past_key_values])
    assert all([p[1].shape == (16, 8, 64) for p in next_batch.past_key_values])


def test_causal_lm_generate_token_completion(default_bloom, default_bloom_batch):
    next_batch = default_bloom_batch
    for _ in range(default_bloom_batch.stopping_criterias[0].max_new_tokens - 1):
        generated_texts, next_batch = default_bloom.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_bloom.generate_token(next_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
    assert generated_texts[0].request == default_bloom_batch.requests[0]
    assert (
        generated_texts[0].tokens
        == default_bloom_batch.stopping_criterias[0].max_new_tokens
    )


def test_causal_lm_generate_token_completion_multi(
    default_bloom, default_multi_requests_bloom_batch
):
    next_batch = default_multi_requests_bloom_batch

    for i in range(
        default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 1
    ):
        generated_texts, next_batch = default_bloom.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_bloom.generate_token(next_batch)
    assert next_batch is not None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "TestTestTestTestTestTest"
    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
    assert (
        generated_texts[0].tokens
        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
    )

    for _ in range(
        default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
        - 1
    ):
        generated_texts, next_batch = default_bloom.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_bloom.generate_token(next_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
    assert (
        generated_texts[0].tokens
        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
    )


def test_batch_concatenate(
    default_bloom, default_bloom_batch, default_multi_requests_bloom_batch
):
    next_batch_0 = default_bloom_batch
    _, next_batch_0 = default_bloom.generate_token(next_batch_0)
    _, next_batch_0 = default_bloom.generate_token(next_batch_0)

    next_batch_1 = default_multi_requests_bloom_batch
    _, next_batch_1 = default_bloom.generate_token(next_batch_1)

    next_batch = BloomCausalLMBatch.concatenate([next_batch_0, next_batch_1])

    assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
    assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
    assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])

    assert torch.all(next_batch.attention_mask[0] == 1)
    assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
    assert torch.all(next_batch.attention_mask[1:, :-2] == 0)

    assert next_batch.batch_id == 0
    assert torch.all(next_batch.input_ids == 10264)

    assert next_batch.input_lengths == [3, 2, 2]
    assert next_batch.max_sequence_length == 3

    assert next_batch.requests[0] == next_batch_0.requests[0]
    assert next_batch.requests[1:] == next_batch_1.requests

    assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
    assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

    assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
    assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias

    assert next_batch.past_key_values is not None
    assert all([p[0].shape == (3, 16, 64, 2) for p in next_batch.past_key_values])
    assert all([p[1].shape == (3, 16, 2, 64) for p in next_batch.past_key_values])

    for i, past in enumerate(next_batch.past_key_values):
        assert torch.equal(next_batch_0.past_key_values[i][0][:, :, -2:], past[0][0])
        assert torch.equal(
            next_batch_1.past_key_values[i][0][:, :, -1:],
            past[0][1:, :, :, -1].reshape(-1, 64, 1),
        )

        assert torch.equal(next_batch_0.past_key_values[i][1][:, -2:, :], past[1][0])
        assert torch.equal(
            next_batch_1.past_key_values[i][1][:, -1:, :],
            past[1][1:, :, -1, :].reshape(-1, 1, 64),
        )

    for _ in range(
        default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens - 2
    ):
        generated_texts, next_batch = default_bloom.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_bloom.generate_token(next_batch)
    assert next_batch is not None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "TestTestTestTestTestTest"
    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[1]
    assert (
        generated_texts[0].tokens
        == default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
    )

    for _ in range(
        default_bloom_batch.stopping_criterias[0].max_new_tokens
        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
        - 2
    ):
        generated_texts, next_batch = default_bloom.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_bloom.generate_token(next_batch)
    assert next_batch is not None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
    assert generated_texts[0].request == default_bloom_batch.requests[0]
    assert (
        generated_texts[0].tokens
        == default_bloom_batch.stopping_criterias[0].max_new_tokens
    )

    for _ in range(
        default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
        - default_bloom_batch.stopping_criterias[0].max_new_tokens
        - default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
        - 4
    ):
        generated_texts, next_batch = default_bloom.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_bloom.generate_token(next_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "TestTestTestTestTestTestTestTestTestTestTest"
    assert generated_texts[0].request == default_multi_requests_bloom_batch.requests[0]
    assert (
        generated_texts[0].tokens
        == default_multi_requests_bloom_batch.stopping_criterias[0].max_new_tokens
    )
server/tests/models/test_causal_lm.py (new file, 0 → 100644, view file @ a2985036)

import pytest
import torch

from copy import copy

from text_generation.pb import generate_pb2
from text_generation.models.causal_lm import CausalLM, CausalLMBatch


@pytest.fixture
def default_pb_request(default_pb_parameters):
    return generate_pb2.Request(
        id=0,
        inputs="Test",
        input_length=1,
        parameters=default_pb_parameters,
        max_new_tokens=10,
    )


@pytest.fixture
def default_pb_batch(default_pb_request):
    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)


@pytest.fixture
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
    return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))


@pytest.fixture
def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
    req_0 = copy(default_pb_request)
    req_1 = default_pb_request
    req_1.id = 1
    req_1.max_new_tokens = 5

    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
    return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))


@pytest.fixture(scope="session")
def default_causal_lm():
    return CausalLM("gpt2")


def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
    batch = default_causal_lm_batch

    assert batch.batch_id == default_pb_batch.id
    assert batch.requests == default_pb_batch.requests

    assert len(batch.input_ids) == default_pb_batch.size
    assert len(batch.input_ids[0]) == 8
    assert batch.input_ids[0][-1] == 14402
    assert torch.all(batch.input_ids[0][:-1] == 50256)

    assert batch.attention_mask[0][-1] == 1
    assert torch.all(batch.attention_mask[0][:-1] == 0)

    assert batch.past_key_values is None

    assert torch.equal(batch.input_ids, batch.all_input_ids[:, :, 0])

    assert batch.input_lengths == [1]

    assert batch.size == default_pb_batch.size
    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size

    assert batch.max_sequence_length == batch.input_lengths[0]


def test_batch_concatenate_no_prefill(default_causal_lm_batch):
    with pytest.raises(ValueError):
        CausalLMBatch.concatenate([default_causal_lm_batch, default_causal_lm_batch])


def test_causal_lm_batch_type(default_causal_lm):
    assert default_causal_lm.batch_type == CausalLMBatch


def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
    generated_texts, next_batch = default_causal_lm.generate_token(
        default_causal_lm_batch
    )

    assert generated_texts == []
    assert isinstance(next_batch, CausalLMBatch)

    assert len(next_batch.all_input_ids) == next_batch.size
    assert len(next_batch.all_input_ids[0]) == len(next_batch.attention_mask[0]) == 9
    assert next_batch.all_input_ids[0][-1] == 6208
    assert next_batch.all_input_ids[0][-2] == 14402
    assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)

    assert torch.all(next_batch.attention_mask[0][-2:] == 1)
    assert torch.all(next_batch.attention_mask[0][:-2] == 0)

    assert next_batch.input_ids.shape == (next_batch.size, 1)
    assert next_batch.input_ids[0, 0] == 6208

    assert next_batch.input_lengths == [2]
    assert next_batch.max_sequence_length == next_batch.input_lengths[0]

    assert next_batch.past_key_values is not None
    assert all([p[0].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])
    assert all([p[1].shape == (1, 12, 8, 64) for p in next_batch.past_key_values])


def test_causal_lm_generate_token_completion(
    default_causal_lm, default_causal_lm_batch
):
    next_batch = default_causal_lm_batch
    for _ in range(default_causal_lm_batch.stopping_criterias[0].max_new_tokens - 1):
        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
    assert (
        generated_texts[0].output
        == "Test Test Test Test Test Test Test Test Test Test Test"
    )
    assert generated_texts[0].request == default_causal_lm_batch.requests[0]
    assert (
        generated_texts[0].tokens
        == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
    )


def test_causal_lm_generate_token_completion_multi(
    default_causal_lm, default_multi_requests_causal_lm_batch
):
    next_batch = default_multi_requests_causal_lm_batch

    for i in range(
        default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 1
    ):
        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "Test Test Test Test Test Test"
    assert (
        generated_texts[0].request
        == default_multi_requests_causal_lm_batch.requests[1]
    )
    assert (
        generated_texts[0].tokens
        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
    )

    for _ in range(
        default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
        - 1
    ):
        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
    assert (
        generated_texts[0].output
        == "Test Test Test Test Test Test Test Test Test Test Test"
    )
    assert (
        generated_texts[0].request
        == default_multi_requests_causal_lm_batch.requests[0]
    )
    assert (
        generated_texts[0].tokens
        == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
    )


def test_batch_concatenate(
    default_causal_lm, default_causal_lm_batch, default_multi_requests_causal_lm_batch
):
    next_batch_0 = default_causal_lm_batch
    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)
    _, next_batch_0 = default_causal_lm.generate_token(next_batch_0)

    next_batch_1 = default_multi_requests_causal_lm_batch
    _, next_batch_1 = default_causal_lm.generate_token(next_batch_1)

    next_batch = CausalLMBatch.concatenate([next_batch_0, next_batch_1])

    assert torch.equal(next_batch.all_input_ids[0], next_batch_0.all_input_ids[0])
    assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
    assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])

    assert torch.all(next_batch.attention_mask[0] == 1)
    assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
    assert torch.all(next_batch.attention_mask[1:, :-2] == 0)

    assert next_batch.batch_id == 0
    assert torch.all(next_batch.input_ids == 6208)

    assert next_batch.input_lengths == [3, 2, 2]
    assert next_batch.max_sequence_length == 3

    assert next_batch.requests[0] == next_batch_0.requests[0]
    assert next_batch.requests[1:] == next_batch_1.requests

    assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
    assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

    assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
    assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias

    assert next_batch.past_key_values is not None
    assert all([p[0].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])
    assert all([p[1].shape == (3, 12, 2, 64) for p in next_batch.past_key_values])

    for i, past in enumerate(next_batch.past_key_values):
        assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:], past[0][0])
        assert torch.equal(
            next_batch_1.past_key_values[i][0][:, :, -1:], past[0][1:, :, -1:, :]
        )

        assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:], past[1][0])
        assert torch.equal(
            next_batch_1.past_key_values[i][1][:, :, -1:], past[1][1:, :, -1:, :]
        )

    for _ in range(
        default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens - 2
    ):
        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "Test Test Test Test Test Test"
    assert (
        generated_texts[0].request
        == default_multi_requests_causal_lm_batch.requests[1]
    )
    assert (
        generated_texts[0].tokens
        == default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
    )

    for _ in range(
        default_causal_lm_batch.stopping_criterias[0].max_new_tokens
        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
        - 2
    ):
        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generated_texts) == 1
    assert (
        generated_texts[0].output
        == "Test Test Test Test Test Test Test Test Test Test Test"
    )
    assert generated_texts[0].request == default_causal_lm_batch.requests[0]
    assert (
        generated_texts[0].tokens
        == default_causal_lm_batch.stopping_criterias[0].max_new_tokens
    )

    for _ in range(
        default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
        - default_causal_lm_batch.stopping_criterias[0].max_new_tokens
        - default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
        - 4
    ):
        generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_causal_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
    assert (
        generated_texts[0].output
        == "Test Test Test Test Test Test Test Test Test Test Test"
    )
    assert (
        generated_texts[0].request
        == default_multi_requests_causal_lm_batch.requests[0]
    )
    assert (
        generated_texts[0].tokens
        == default_multi_requests_causal_lm_batch.stopping_criterias[0].max_new_tokens
    )
server/tests/models/test_seq2seq_lm.py (new file, 0 → 100644, view file @ a2985036)

import pytest
import torch

from copy import copy

from text_generation.pb import generate_pb2
from text_generation.models.seq2seq_lm import Seq2SeqLM, Seq2SeqLMBatch


@pytest.fixture
def default_pb_request(default_pb_parameters):
    return generate_pb2.Request(
        id=0,
        inputs="Test",
        input_length=2,
        parameters=default_pb_parameters,
        max_new_tokens=10,
    )


@pytest.fixture
def default_pb_batch(default_pb_request):
    return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1)


@pytest.fixture
def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
    return Seq2SeqLMBatch.from_pb(
        default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
    )


@pytest.fixture
def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokenizer):
    req_0 = copy(default_pb_request)
    req_1 = default_pb_request
    req_1.id = 1
    req_1.max_new_tokens = 5

    batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
    return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))


@pytest.fixture(scope="session")
def default_seq2seq_lm():
    return Seq2SeqLM("bigscience/mt0-small")


def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
    batch = default_seq2seq_lm_batch

    assert batch.batch_id == default_pb_batch.id
    assert batch.requests == default_pb_batch.requests

    assert batch.input_ids.shape == (default_pb_batch.size, 8)
    assert batch.input_ids[0][-2] == 4268
    assert batch.input_ids[0][-1] == 1
    assert torch.all(batch.input_ids[0][:-2] == 0)

    assert torch.all(batch.attention_mask[0][-2:] == 1)
    assert torch.all(batch.attention_mask[0][:-2] == 0)

    assert batch.decoder_input_ids.shape == (default_pb_batch.size, 1)
    assert batch.decoder_attention_mask is None
    assert batch.encoder_last_hidden_state is None

    assert batch.past_key_values is None

    assert batch.input_lengths == [2]
    assert batch.decoder_input_lengths == [1]

    assert batch.size == default_pb_batch.size
    assert len(batch.next_token_choosers) == len(batch.stopping_criterias) == batch.size

    assert batch.max_input_length == batch.input_lengths[0]
    assert batch.max_decoder_input_length == batch.decoder_input_lengths[0]


def test_batch_concatenate_no_prefill(default_seq2seq_lm_batch):
    with pytest.raises(ValueError):
        Seq2SeqLMBatch.concatenate([default_seq2seq_lm_batch, default_seq2seq_lm_batch])


def test_seq2seq_lm_batch_type(default_seq2seq_lm):
    assert default_seq2seq_lm.batch_type == Seq2SeqLMBatch


def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch):
    generated_texts, next_batch = default_seq2seq_lm.generate_token(
        default_seq2seq_lm_batch
    )

    assert generated_texts == []
    assert isinstance(next_batch, Seq2SeqLMBatch)

    assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
    assert torch.equal(
        next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
    )
    assert next_batch.input_lengths == default_seq2seq_lm_batch.input_lengths
    assert next_batch.max_input_length == default_seq2seq_lm_batch.max_input_length
    assert (
        next_batch.next_token_choosers == default_seq2seq_lm_batch.next_token_choosers
    )
    assert next_batch.stopping_criterias == default_seq2seq_lm_batch.stopping_criterias

    assert next_batch.decoder_input_ids.shape == (next_batch.size, 2)
    assert next_batch.decoder_input_ids[0, 0] == 0
    assert next_batch.decoder_input_ids[0, 1] == 259
    assert next_batch.decoder_attention_mask is None
    assert next_batch.encoder_last_hidden_state.shape == (1, 8, 512)

    assert next_batch.decoder_input_lengths == [2]
    assert next_batch.max_decoder_input_length == 2

    assert next_batch.past_key_values is not None
    assert all(
        [p[0].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
    )
    assert all(
        [p[1].shape == (next_batch.size, 6, 1, 64) for p in next_batch.past_key_values]
    )
    assert all(
        [p[2].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
    )
    assert all(
        [p[3].shape == (next_batch.size, 6, 8, 64) for p in next_batch.past_key_values]
    )


def test_seq2seq_lm_generate_token_completion(
    default_seq2seq_lm, default_seq2seq_lm_batch
):
    next_batch = default_seq2seq_lm_batch
    for _ in range(6):
        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "a few weeks"
    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
    assert generated_texts[0].tokens == 7


def test_seq2seq_lm_generate_token_completion_multi(
    default_seq2seq_lm, default_multi_requests_seq2seq_lm_batch
):
    next_batch = default_multi_requests_seq2seq_lm_batch

    for i in range(4):
        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "a few "
    assert (
        generated_texts[0].request
        == default_multi_requests_seq2seq_lm_batch.requests[1]
    )
    assert generated_texts[0].tokens == 5

    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert generated_texts == []

    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "a few weeks"
    assert (
        generated_texts[0].request
        == default_multi_requests_seq2seq_lm_batch.requests[0]
    )
    assert generated_texts[0].tokens == 7


def test_batch_concatenate(
    default_seq2seq_lm,
    default_seq2seq_lm_batch,
    default_multi_requests_seq2seq_lm_batch,
):
    next_batch_0 = default_seq2seq_lm_batch
    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)
    _, next_batch_0 = default_seq2seq_lm.generate_token(next_batch_0)

    next_batch_1 = default_multi_requests_seq2seq_lm_batch
    _, next_batch_1 = default_seq2seq_lm.generate_token(next_batch_1)

    next_batch = Seq2SeqLMBatch.concatenate([next_batch_0, next_batch_1])

    assert next_batch.batch_id == 0

    assert torch.all(next_batch.input_ids[:, 0] == 4268)
    assert torch.all(next_batch.input_ids[:, 1] == 1)

    assert torch.all(next_batch.attention_mask == 1)

    assert torch.equal(
        next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
    )
    assert torch.all(next_batch.decoder_input_ids[1:, 0] == 0)
    assert torch.equal(
        next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
    )

    assert torch.all(next_batch.decoder_attention_mask[0] == 1)
    assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
    assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1)

    assert torch.equal(
        next_batch.encoder_last_hidden_state[0],
        next_batch_0.encoder_last_hidden_state[0, -2:],
    )
    assert torch.equal(
        next_batch.encoder_last_hidden_state[1:],
        next_batch_1.encoder_last_hidden_state[:, -2:],
    )

    assert next_batch.input_lengths == [2, 2, 2]
    assert next_batch.decoder_input_lengths == [3, 2, 2]
    assert next_batch.max_input_length == 2
    assert next_batch.max_decoder_input_length == 3

    assert next_batch.requests[0] == next_batch_0.requests[0]
    assert next_batch.requests[1:] == next_batch_1.requests

    assert next_batch.next_token_choosers[0] == next_batch_0.next_token_choosers[0]
    assert next_batch.next_token_choosers[1:] == next_batch_1.next_token_choosers

    assert next_batch.stopping_criterias[0] == next_batch_0.stopping_criterias[0]
    assert next_batch.stopping_criterias[1:] == next_batch_1.stopping_criterias

    assert next_batch.past_key_values is not None
    assert all(
        [p[0].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
    )
    assert all(
        [p[1].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
    )
    assert all(
        [p[2].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
    )
    assert all(
        [p[3].shape == (next_batch.size, 6, 2, 64) for p in next_batch.past_key_values]
    )

    for i, past in enumerate(next_batch.past_key_values):
        assert torch.equal(next_batch_0.past_key_values[i][0][0, :, -2:, :], past[0][0])
        assert torch.equal(
            next_batch_1.past_key_values[i][0][:, :, -1:, :], past[0][1:, :, -1:, :]
        )

        assert torch.equal(next_batch_0.past_key_values[i][1][0, :, -2:, :], past[1][0])
        assert torch.equal(
            next_batch_1.past_key_values[i][1][:, :, -1:, :], past[1][1:, :, -1:, :]
        )

        assert torch.equal(next_batch_0.past_key_values[i][2][0, :, -2:, :], past[2][0])
        assert torch.equal(
            next_batch_1.past_key_values[i][2][:, :, -2:, :], past[2][1:]
        )

        assert torch.equal(next_batch_0.past_key_values[i][3][0, :, -2:, :], past[3][0])
        assert torch.equal(
            next_batch_1.past_key_values[i][3][:, :, -2:, :], past[3][1:]
        )

    for _ in range(3):
        generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
        assert generated_texts == []

    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "a few "
    assert (
        generated_texts[0].request
        == default_multi_requests_seq2seq_lm_batch.requests[1]
    )
    assert generated_texts[0].tokens == 5

    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is not None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "a few weeks"
    assert generated_texts[0].request == default_seq2seq_lm_batch.requests[0]
    assert generated_texts[0].tokens == 7

    generated_texts, next_batch = default_seq2seq_lm.generate_token(next_batch)
    assert next_batch is None

    assert len(generated_texts) == 1
    assert generated_texts[0].output == "a few weeks"
    assert (
        generated_texts[0].request
        == default_multi_requests_seq2seq_lm_batch.requests[0]
    )
    assert generated_texts[0].tokens == 7
server/tests/test_utils.py (new file, 0 → 100644, view file @ a2985036)

import pytest

from text_generation.utils import (
    weight_hub_files,
    download_weights,
    weight_files,
    LocalEntryNotFoundError,
)


def test_weight_hub_files():
    filenames = weight_hub_files("bigscience/bloom-560m")
    assert filenames == ["model.safetensors"]


def test_weight_hub_files_llm():
    filenames = weight_hub_files("bigscience/bloom")
    assert filenames == [f"model_{i:05d}-of-00072.safetensors" for i in range(1, 73)]


def test_weight_hub_files_empty():
    filenames = weight_hub_files("bigscience/bloom", ".errors")
    assert filenames == []


def test_download_weights():
    files = download_weights("bigscience/bloom-560m")
    local_files = weight_files("bigscience/bloom-560m")
    assert files == local_files


def test_weight_files_error():
    with pytest.raises(LocalEntryNotFoundError):
        weight_files("bert-base-uncased")
server/text_generation/models/__init__.py (view file @ a2985036)

 from text_generation.models.model import Model
 from text_generation.models.causal_lm import CausalLM
-from text_generation.models.bloom import BLOOMSharded
+from text_generation.models.bloom import BLOOM, BLOOMSharded
 from text_generation.models.seq2seq_lm import Seq2SeqLM
 from text_generation.models.galactica import Galactica, GalacticaSharded

-__all__ = ["Model", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]
+__all__ = ["Model", "BLOOM", "BLOOMSharded", "CausalLM", "Seq2SeqLM"]


 def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
...
@@ -12,7 +12,7 @@ def get_model(model_name: str, sharded: bool, quantize: bool) -> Model:
         if sharded:
             return BLOOMSharded(model_name, quantize=quantize)
         else:
-            return CausalLM(model_name, quantize=quantize)
+            return BLOOM(model_name, quantize=quantize)
     elif model_name.startswith("facebook/galactica"):
         if sharded:
             return GalacticaSharded(model_name, quantize=quantize)
...
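With the __init__.py change above, unsharded bigscience/bloom* checkpoints are served by the new BLOOM class instead of the generic CausalLM. A simplified, illustrative sketch of the resulting dispatch is shown here (the function name is hypothetical and the Galactica and generic fallbacks visible in the real get_model are elided):

# Illustrative only: condensed from the get_model diff above, not a verbatim copy.
from text_generation.models import BLOOM, BLOOMSharded, CausalLM


def get_model_sketch(model_name: str, sharded: bool, quantize: bool):
    if model_name.startswith("bigscience/bloom"):
        if sharded:
            return BLOOMSharded(model_name, quantize=quantize)
        # New in this commit: plain BLOOM, so BloomCausalLMBatch is used for batching.
        return BLOOM(model_name, quantize=quantize)
    # facebook/galactica and the Seq2SeqLM branches are elided in this sketch.
    return CausalLM(model_name, quantize=quantize)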
server/text_generation/models/bloom.py (view file @ a2985036)

 import torch
 import torch.distributed

-from typing import List, Optional
+from typing import List, Optional, Type

 from accelerate import init_empty_weights
 from safetensors import safe_open
...
@@ -13,6 +13,8 @@ from transformers.models.bloom.parallel_layers import (
 )

 from text_generation.models import CausalLM
+from text_generation.models.causal_lm import CausalLMBatch
+from text_generation.pb import generate_pb2
 from text_generation.utils import (
     initialize_torch_distributed,
     weight_files,
...
@@ -29,7 +31,25 @@ except Exception as e:
     torch.manual_seed(0)


-class BLOOMSharded(CausalLM):
+class BloomCausalLMBatch(CausalLMBatch):
+    @classmethod
+    def from_pb(
+        cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
+    ) -> "CausalLMBatch":
+        batch = super(BloomCausalLMBatch, cls).from_pb(
+            pb=pb, tokenizer=tokenizer, device=device
+        )
+        batch.keys_head_dim_last = False
+        return batch
+
+
+class BLOOM(CausalLM):
+    @property
+    def batch_type(self) -> Type[CausalLMBatch]:
+        return BloomCausalLMBatch
+
+
+class BLOOMSharded(BLOOM):
     def __init__(self, model_name: str, quantize: bool = False):
         if not model_name.startswith("bigscience/bloom"):
             raise ValueError(f"Model {model_name} is not supported")
...
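For context on the keys_head_dim_last = False override: BLOOM's past-key cache keeps the head dimension before the sequence dimension, while GPT-2-style caches keep it last. The shapes below are taken from the new tests above (bloom-560m: 16 heads, head_dim 64, sequence length 8; gpt2: batch 1, 12 heads); the tensor construction itself is only an illustration, not code from this commit:

import torch

# GPT-2-style cache entry: keys and values are [batch, heads, seq, head_dim],
# so head_dim is the last axis (keys_head_dim_last=True, the default).
gpt2_key = torch.zeros(1, 12, 8, 64)
gpt2_value = torch.zeros(1, 12, 8, 64)

# BLOOM fuses batch and heads and transposes the keys:
# keys are [batch * heads, head_dim, seq], values are [batch * heads, seq, head_dim].
bloom_key = torch.zeros(16, 64, 8)
bloom_value = torch.zeros(16, 8, 64)

# This is why BloomCausalLMBatch flips the flag, and why CausalLMBatch.concatenate
# (see the causal_lm.py diff below) pads keys with seq_length last only when
# keys_head_dim_last is False.
assert gpt2_key.shape[-1] == 64 and bloom_key.shape[-1] == 8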
server/text_generation/models/causal_lm.py (view file @ a2985036)

...
@@ -34,6 +34,9 @@ class CausalLMBatch:
     size: int
     max_sequence_length: int

+    # Past metadata
+    keys_head_dim_last: bool = True
+
     def to_pb(self):
         return generate_pb2.Batch(
             id=self.batch_id,
...
@@ -165,20 +168,16 @@ class CausalLMBatch:
                     head_dim,
                 )

-                # seq_length is last for BLOOM
-                if past_keys.shape[-2] == head_dim:
-                    past_keys_head_dim_last = False
-                    padded_past_keys_shape = (
-                        total_batch_size,
-                        num_heads,
-                        head_dim,
-                        max_sequence_length - 1,
-                    )
-                elif past_keys.shape[-1] == head_dim:
-                    past_keys_head_dim_last = True
-                    padded_past_keys_shape = padded_past_values_shape
-                else:
-                    raise ValueError(f"past_keys shape {past_keys.shape} is not valid")
+                if batch.keys_head_dim_last:
+                    padded_past_keys_shape = padded_past_values_shape
+                else:
+                    # seq_length is last for BLOOM
+                    padded_past_keys_shape = (
+                        total_batch_size,
+                        num_heads,
+                        head_dim,
+                        max_sequence_length - 1,
+                    )

                 # This will run only once per layer
                 if j == len(past_key_values):
...
@@ -195,7 +194,7 @@ class CausalLMBatch:
                 past_key_values.append((padded_past_keys, padded_past_values))

             # We slice the past keys and values to remove the padding from previous batches
-            if past_keys_head_dim_last:
+            if batch.keys_head_dim_last:
                 past_key_values[j][0][
                     start_index:end_index,
                     :,
...
@@ -228,6 +227,7 @@ class CausalLMBatch:
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
             max_sequence_length=max_sequence_length,
+            keys_head_dim_last=batches[0].keys_head_dim_last,
         )
...
@@ -237,6 +237,9 @@ class CausalLM(Model):
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
             device = torch.device("cpu")
             dtype = torch.float32
...
@@ -247,7 +250,11 @@ class CausalLM(Model):
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize,
         ).eval()
-        tokenizer.pad_token_id = self.model.config.pad_token_id
+        tokenizer.pad_token_id = (
+            self.model.config.pad_token_id
+            if self.model.config.pad_token_id is not None
+            else self.model.config.eos_token_id
+        )

         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
...
@@ -397,5 +404,6 @@ class CausalLM(Model):
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
             max_sequence_length=next_batch_max_sequence_length,
+            keys_head_dim_last=batch.keys_head_dim_last,
         )

         return generated_texts, next_batch
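The tokenizer change in CausalLM.__init__ above falls back to the model's eos_token_id when the config defines no pad_token_id, which is the case for GPT-2. The same pattern in isolation (illustrative; it mirrors the diff rather than adding new behaviour):

from transformers import AutoConfig, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
config = AutoConfig.from_pretrained("gpt2")

# GPT-2 ships without a pad token, so padding reuses the EOS token id.
tokenizer.pad_token_id = (
    config.pad_token_id if config.pad_token_id is not None else config.eos_token_id
)
assert tokenizer.pad_token_id == 50256  # matches the gpt2_tokenizer fixture above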
server/text_generation/models/galactica.py (view file @ a2985036)

...
@@ -83,7 +83,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
     @classmethod
     def from_pb(
         cls, pb: generate_pb2.Batch, tokenizer: AutoTokenizer, device: torch.device
-    ) -> "CausalLMBatch":
+    ) -> "GalacticaCausalLMBatch":
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
...
server/text_generation/models/seq2seq_lm.py (view file @ a2985036)

...
@@ -221,8 +221,8 @@ class Seq2SeqLMBatch:

             # Copy to correct indices
             encoder_last_hidden_state[
-                start_index:end_index, -batch.max_decoder_input_length :, :
-            ] = batch.encoder_last_hidden_state[:, -batch.max_decoder_input_length :, :]
+                start_index:end_index, -batch.max_input_length :, :
+            ] = batch.encoder_last_hidden_state[:, -batch.max_input_length :, :]

             # Iterate over attention layers
             for j, past in enumerate(batch.past_key_values):
...
@@ -305,6 +305,9 @@ class Seq2SeqLM(Model):
             device = torch.device("cuda")
             dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
         else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
             device = torch.device("cpu")
             dtype = torch.float32
...
server/text_generation/utils.py (view file @ a2985036)

...
@@ -137,8 +137,8 @@ def download_weights(model_name, extension=".safetensors"):
             executor.submit(download_function, filename=filename) for filename in filenames
         ]
         files = [
-            file
-            for file in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
+            future.result()
+            for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures))
         ]

     return files
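The utils.py fix collects future.result() instead of the completed Future objects, so download_weights returns file paths and re-raises any download error from the worker threads. A minimal standalone sketch of the same pattern (the tqdm progress bar from the real code is omitted, and fake_download is a stand-in, not project code):

import concurrent.futures


def fake_download(filename: str) -> str:
    # Stand-in for the real Hub download; returns a local path.
    return f"/tmp/{filename}"


filenames = ["model_00001-of-00072.safetensors", "model_00002-of-00072.safetensors"]

with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = [executor.submit(fake_download, filename=filename) for filename in filenames]
    # .result() yields each worker's return value and re-raises its exception,
    # whereas iterating as_completed() alone would only return Future objects.
    files = [
        future.result()
        for future in concurrent.futures.as_completed(futures)
    ]

assert sorted(files) == sorted(f"/tmp/{f}" for f in filenames)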