text-generation-inference / commit 299217c9 (unverified)
Authored by OlivierDehaene on Apr 11, 2023; committed via GitHub on Apr 11, 2023

feat(server): add flash attention llama (#144)
parent 99879600

Showing 13 changed files with 1175 additions and 40 deletions (+1175 -40)
README.md (+4, -6)
launcher/tests/mt0_base.json (+1, -1)
server/poetry.lock (+103, -1)
server/pyproject.toml (+2, -0)
server/tests/models/test_seq2seq_lm.py (+1, -1)
server/text_generation_server/models/__init__.py (+21, -6)
server/text_generation_server/models/causal_lm.py (+27, -3)
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (+619, -0)
server/text_generation_server/models/flash_causal_lm.py (+28, -2)
server/text_generation_server/models/flash_llama.py (+303, -0)
server/text_generation_server/models/galactica.py (+6, -2)
server/text_generation_server/models/model.py (+33, -15)
server/text_generation_server/models/seq2seq_lm.py (+27, -3)
README.md
@@ -51,16 +51,14 @@ to power LLMs api-inference widgets.
 - Log probabilities
 - Production ready (distributed tracing with Open Telemetry, Prometheus metrics)
 
-## Officially supported architectures
+## Optimized architectures
 
 - [BLOOM](https://huggingface.co/bigscience/bloom)
-- [BLOOMZ](https://huggingface.co/bigscience/bloomz)
-- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - [Galactica](https://huggingface.co/facebook/galactica-120b)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b)
-- [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl)
-- [FLAN-UL2](https://huggingface.co/google/flan-ul2)
+- [GPT-Neox](https://huggingface.co/EleutherAI/gpt-neox-20b)
+- [FLAN-T5](https://huggingface.co/google/flan-t5-xxl)
+- [Llama](https://github.com/facebookresearch/llama)
 
 Other architectures are supported on a best effort basis using:
launcher/tests/mt0_base.json
@@ -14,7 +14,7 @@
     "tokens": [
       {
         "id": 259,
-        "text": " ",
+        "text": "",
         "logprob": -1.3656927,
         "special": false
       },
server/poetry.lock
@@ -517,6 +517,14 @@ tensorflow = ["tensorflow"]
 testing = ["h5py", "huggingface-hub", "numpy", "pytest", "pytest-benchmark", "setuptools-rust"]
 torch = ["torch"]
 
+[[package]]
+name = "sentencepiece"
+version = "0.1.97"
+description = "SentencePiece python wrapper"
+category = "main"
+optional = false
+python-versions = "*"
+
 [[package]]
 name = "setuptools"
 version = "67.4.0"

@@ -530,6 +538,19 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "pygments-g
 testing = ["build[virtualenv]", "filelock (>=3.4.0)", "flake8 (<5)", "flake8-2020", "ini2toml[lite] (>=0.9)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pip (>=19.1)", "pip-run (>=8.8)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf", "pytest-timeout", "pytest-xdist", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
 testing-integration = ["build[virtualenv]", "filelock (>=3.4.0)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "pytest", "pytest-enabler", "pytest-xdist", "tomli", "virtualenv (>=13.0.0)", "wheel"]
 
+[[package]]
+name = "tokenizers"
+version = "0.13.3"
+description = "Fast and Customizable Tokenizers"
+category = "main"
+optional = false
+python-versions = "*"
+
+[package.extras]
+dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
+docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
+testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
+
 [[package]]
 name = "tomli"
 version = "2.0.1"

@@ -630,7 +651,7 @@ bnb = ["bitsandbytes"]
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.9"
-content-hash = "521dc9f3c283dc56f7d2e2f96759919ff27ab49ffd3ae7cd26317b209e7fa98d"
+content-hash = "1c57379c7b9349d2a860b50b3ab125737a0f6f94f4303d7cb55248cb86ff8b8e"
 
 [metadata.files]
 accelerate = [

@@ -1116,10 +1137,91 @@ safetensors = [
{file = "safetensors-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:ba3dc236a2344b7feadc9868307f42ba5e4804c9d68a80a35aac831349b31f6f"},
{file = "safetensors-0.2.8.tar.gz", hash = "sha256:2720b20a6a38c799dca79bd76caeeac2f7df585a9d4f7d59fa7e28eff9ccb27f"},
]
sentencepiece = [
{file = "sentencepiece-0.1.97-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:6f249c8f1852893be86eae66b19d522c5fb30bbad4fe2d1b07f06fdc86e1907e"},
{file = "sentencepiece-0.1.97-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:09e1bc53178de70c557a9ba4fece07364b4416ce3d36570726b3372b68aea135"},
{file = "sentencepiece-0.1.97-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:667193c57fb48b238be7e3d7636cfc8da56cb5bac5559d8f0b647334e1175be8"},
{file = "sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2780531985af79c6163f63d4f200fec8a28b70b6768d2c19f70d01568a4524e8"},
{file = "sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:205050670c53ef9015e2a98cce3934bfbcf0aafaa14caa0c618dd5667bc217ee"},
{file = "sentencepiece-0.1.97-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28b183dadef8e8b6b4645c1c20692d7be0a13ecc3ec1a07b3885c8905516675f"},
{file = "sentencepiece-0.1.97-cp310-cp310-win32.whl", hash = "sha256:ee3c9dbd558d8d85bb1617087b86df6ea2b856a528669630ce6cedeb4353b823"},
{file = "sentencepiece-0.1.97-cp310-cp310-win_amd64.whl", hash = "sha256:f7dc55379e2f7dee86537180283db2e5f8418c6825fdd2fe436c724eb5604c05"},
{file = "sentencepiece-0.1.97-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:ba1b4154f9144c5a7528b00aff5cffaa1a896a1c6ca53ca78b6e74cd2dae5244"},
{file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac3d90aee5581e55d029d124ac11b6ae2fbae0817863b664b2f2302e966ababb"},
{file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c27400f1ac46518a01c87cb7703650e4e48728649feb115d2e3f1102a946a42"},
{file = "sentencepiece-0.1.97-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6e12a166eba75994ca749aadc4a5056b91b31405f805d6de6e8914cc9741c60"},
{file = "sentencepiece-0.1.97-cp36-cp36m-win32.whl", hash = "sha256:ed85dff5c0a9b3dd1a414c7e1119f2a19e863fc3f81da525bf7f885ebc883de0"},
{file = "sentencepiece-0.1.97-cp36-cp36m-win_amd64.whl", hash = "sha256:91a19ab6f40ffbae6d6127119953d2c6a85e93d734953dbc8629fde0d21ace66"},
{file = "sentencepiece-0.1.97-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:bae580e4a35a9314ff49561ac7c06574fe6afc71b821ed6bb00534e571458156"},
{file = "sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ad7262e7530c683b186672b5dd0082f82719a50a500a8cfbc4bbd7cde5bff8c"},
{file = "sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:620cee35279720016735a7c7103cddbd9b84fe5e2f098bd5e673834d69fee2b8"},
{file = "sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:93b921b59914c0ec6697e8c6d5e6b44d99d1298fb1a0af56980a79ade0540c19"},
{file = "sentencepiece-0.1.97-cp37-cp37m-win32.whl", hash = "sha256:9b9a4c44a31d5f47616e9568dcf31e029b0bfa776e0a252c0b59247881598b09"},
{file = "sentencepiece-0.1.97-cp37-cp37m-win_amd64.whl", hash = "sha256:f31533cdacced56219e239d3459a003ece35116920dd64b2309d4ad047b77644"},
{file = "sentencepiece-0.1.97-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:7d643c01d1cad13b9206a276bbe5bc1a468e3d7cf6a26bde7783f945277f859d"},
{file = "sentencepiece-0.1.97-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:542f1985b1ee279a92bef7740ec0781452372028ce01e15aa88df3228b197ba3"},
{file = "sentencepiece-0.1.97-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:93701da21fea906dd244bf88cdbe640385a89c45d3c1812b76dbadf8782cdbcd"},
{file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a51514047b964047b7fadb480d88a5e0f72c02f6ca1ba96258fbbc6e79274a94"},
{file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e3ae2e9b7a5b6f2aa64ec9240b0c185dabe597d0e787dc4344acfbaef1ffe0b2"},
{file = "sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:923ee4af16dbae1f2ab358ed09f8a0eb89e40a8198a8b343bf54181482342721"},
{file = "sentencepiece-0.1.97-cp38-cp38-win32.whl", hash = "sha256:fa6f2b88850b5fae3a05053658824cf9f147c8e3c3b40eb64539a976c83d8a24"},
{file = "sentencepiece-0.1.97-cp38-cp38-win_amd64.whl", hash = "sha256:5137ff0d0b1cc574751d178650ef800ff8d90bf21eb9f71e9567d4a0548940a5"},
{file = "sentencepiece-0.1.97-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:f92876271a10494671431ad955bff2d6f8ea59baaf957f5ae5946aff56dfcb90"},
{file = "sentencepiece-0.1.97-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:35c227b6d55e473033db7e0ecc51b1e99e6ed7607cc08602fb5768132543c81d"},
{file = "sentencepiece-0.1.97-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1706a8a8188f7b3d4b7922db9bb00c64c4e16ee68ab4caaae79f55b3e18748c7"},
{file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ce61efc1862ccb18856c4aabbd930e13d5bfbb4b09b4f111081ac53a9dc62275"},
{file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a78c03800ef9f02d320e0159f5768b15357f3e9ebea545c9c4ba7928ba8ba254"},
{file = "sentencepiece-0.1.97-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753b8088fd685ee787d9f54c84275ab347de558c7c4ebc6accb4c35bf7776f20"},
{file = "sentencepiece-0.1.97-cp39-cp39-win32.whl", hash = "sha256:24306fd86031c17a1a6ae92671e76a350390a3140a65620bc2843dad7db24e2a"},
{file = "sentencepiece-0.1.97-cp39-cp39-win_amd64.whl", hash = "sha256:c6641d0b7acec61fde5881ea6ebe098c169557ac9aa3bdabdf124eab5a5592bb"},
{file = "sentencepiece-0.1.97.tar.gz", hash = "sha256:c901305e0a710bbcd296f66d79e96f744e6e175b29812bd5178318437d4e1f6c"},
]
setuptools = [
{file = "setuptools-67.4.0-py3-none-any.whl", hash = "sha256:f106dee1b506dee5102cc3f3e9e68137bbad6d47b616be7991714b0c62204251"},
{file = "setuptools-67.4.0.tar.gz", hash = "sha256:e5fd0a713141a4a105412233c63dc4e17ba0090c8e8334594ac790ec97792330"},
]
tokenizers = [
{file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"},
{file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"},
{file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"},
{file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"},
{file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"},
{file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"},
{file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"},
{file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"},
{file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"},
{file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"},
{file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"},
{file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"},
{file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"},
{file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"},
{file = "tokenizers-0.13.3-cp311-cp311-win32.whl", hash = "sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"},
{file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"},
{file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"},
{file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"},
{file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"},
{file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"},
{file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"},
{file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"},
{file = "tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"},
{file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"},
{file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"},
{file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"},
{file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"},
{file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"},
{file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"},
{file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"},
{file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"},
{file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"},
{file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"},
{file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"},
{file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"},
{file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"},
{file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"},
{file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"},
{file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"},
{file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"},
]
tomli = [
{file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
{file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
server/pyproject.toml
@@ -23,6 +23,8 @@ opentelemetry-api = "^1.15.0"
 opentelemetry-exporter-otlp = "^1.15.0"
 opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
+sentencepiece = "^0.1.97"
+tokenizers = "0.13.3"
 
 [tool.poetry.extras]
 bnb = ["bitsandbytes"]
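
The two dependencies added here back the Llama tokenizer introduced by this commit: LlamaTokenizer is sentencepiece-based, and the pinned tokenizers release is the fast-tokenizer backend used by transformers. A minimal sketch of what the server now relies on at load time (the checkpoint path below is a hypothetical placeholder, not something shipped with this change):

    # Minimal sketch: LlamaTokenizer needs the sentencepiece package at runtime.
    # "path/to/llama" is a hypothetical local checkpoint directory.
    from transformers.models.llama import LlamaTokenizer

    tokenizer = LlamaTokenizer.from_pretrained(
        "path/to/llama",
        padding_side="left",
        truncation_side="left",
    )
    print(tokenizer("flash attention llama").input_ids)
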
server/tests/models/test_seq2seq_lm.py
@@ -148,7 +148,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
     assert all([generation.token_id.item() == 259 for generation in generations])
-    assert all([generation.token_text == " " for generation in generations])
+    assert all([generation.token_text == "" for generation in generations])
     assert generations[0].request_id == 0
server/text_generation_server/models/__init__.py
@@ -19,13 +19,11 @@ from text_generation_server.models.t5 import T5Sharded
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
     from text_generation_server.models.flash_santacoder import FlashSantacoder
+    from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
 
-    FLASH_ATTENTION = (
-        torch.cuda.is_available() and int(os.environ.get("FLASH_ATTENTION", 0)) == 1
-    )
+    FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
-    if int(os.environ.get("FLASH_ATTENTION", 0)) == 1:
-        logger.exception("Could not import Flash Attention models")
+    logger.exception("Could not import Flash Attention enabled models")
     FLASH_ATTENTION = False
 
 __all__ = [

@@ -47,6 +45,12 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashSantacoder)
+    __all__.append(FlashLlama)
+    __all__.append(FlashLlamaSharded)
+
+FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention CUDA kernels to be installed.\n" \
+    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) " \
+    "or install flash attention with `cd server && make install install-flash-attention`"
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.

@@ -60,7 +64,7 @@ torch.set_grad_enabled(False)
 def get_model(
     model_id: str, revision: Optional[str], sharded: bool, quantize: bool
 ) -> Model:
     if "facebook/galactica" in model_id:
         if sharded:

@@ -92,6 +96,17 @@ def get_model(
         neox_cls = FlashNeoX if FLASH_ATTENTION else CausalLM
         return neox_cls(model_id, revision, quantize=quantize)
 
+    if model_type == "llama":
+        if sharded:
+            if FLASH_ATTENTION:
+                return FlashLlamaSharded(model_id, revision, quantize=quantize)
+            raise NotImplementedError(
+                FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Llama")
+            )
+        else:
+            llama_cls = FlashLlama if FLASH_ATTENTION else CausalLM
+            return llama_cls(model_id, revision, quantize=quantize)
+
     if model_type == "t5":
         if sharded:
             return T5Sharded(model_id, revision, quantize=quantize)
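
As reconstructed above, flash attention is now enabled whenever the CUDA kernels import successfully and a GPU is present, rather than requiring an explicit FLASH_ATTENTION=1 opt-in, and get_model dispatches "llama" checkpoints to FlashLlama or FlashLlamaSharded in that case. A standalone sketch of the same guard-and-dispatch pattern (illustrative, not the server module itself):

    # Standalone sketch of the import guard and dispatch used above: try the flash
    # kernels, fall back to the regular CausalLM path when they are unavailable.
    import torch

    try:
        import flash_attn_cuda  # CUDA kernels from the flash-attn build

        FLASH_ATTENTION = torch.cuda.is_available()
    except ImportError:
        FLASH_ATTENTION = False


    def llama_class_name(sharded: bool) -> str:
        # Mirrors get_model(): sharded Llama requires the kernels, single-GPU
        # Llama silently falls back to CausalLM without them.
        if sharded:
            if not FLASH_ATTENTION:
                raise NotImplementedError("Sharded Llama requires Flash Attention CUDA kernels")
            return "FlashLlamaSharded"
        return "FlashLlama" if FLASH_ATTENTION else "CausalLM"


    print(llama_class_name(sharded=False))
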
server/text_generation_server/models/causal_lm.py
@@ -34,6 +34,8 @@ class CausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
+    offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]

@@ -64,12 +66,16 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
+        offsets = []
+        token_offsets = []
 
         # Parse batch
         max_truncation = 0
         padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
+            offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer

@@ -113,6 +119,8 @@ class CausalLMBatch(Batch):
             past_key_values=None,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths.tolist(),
+            offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,

@@ -135,6 +143,8 @@ class CausalLMBatch(Batch):
         # Batch attributes
         requests = []
         input_lengths = []
+        offsets = []
+        token_offsets = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []

@@ -151,6 +161,8 @@ class CausalLMBatch(Batch):
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
+            offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)

@@ -264,6 +276,8 @@ class CausalLMBatch(Batch):
             past_key_values=past_key_values,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,

@@ -289,7 +303,7 @@ class CausalLM(Model):
             dtype = torch.float32
 
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         self.model = AutoModelForCausalLM.from_pretrained(
             model_id,

@@ -350,6 +364,8 @@ class CausalLM(Model):
         # New values for next forward
         next_batch_input_lengths = []
+        next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_input_ids = []
         next_batch_all_input_ids = []

@@ -364,6 +380,8 @@ class CausalLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
+            batch.offsets,
+            batch.token_offsets,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,

@@ -374,6 +392,8 @@ class CausalLM(Model):
         for i, (
             request,
             input_length,
+            offset,
+            token_offset,
             logits,
             next_token_chooser,
             stopping_criteria,

@@ -391,8 +411,8 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.decode_token(
-                next_token_id_squeezed,
+            next_token_text, offset, token_offset = self.decode_token(
+                all_input_ids[:, 0], offset, token_offset
             )
 
             # Evaluate stopping criteria

@@ -423,6 +443,8 @@ class CausalLM(Model):
             next_batch_all_input_ids.append(all_input_ids)
             next_batch_size += 1
             next_batch_input_lengths.append(new_input_length)
+            next_batch_offsets.append(offset)
+            next_batch_token_offsets.append(token_offset)
             next_batch_max_input_length = max(
                 next_batch_max_input_length, new_input_length
             )

@@ -506,6 +528,8 @@ class CausalLM(Model):
             past_key_values=next_batch_past_key_values,
             all_input_ids=next_batch_all_input_ids,
             input_lengths=next_batch_input_lengths,
+            offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
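
The offset / token_offset pair threaded through CausalLMBatch is bookkeeping for incremental detokenization: decode_token is now given a window of recent ids plus the two offsets instead of a single token id, so tokenizers that merge pieces still yield exactly the characters added by the newest token. The server's decode_token body is not part of this diff; the sketch below is only one illustrative way such offset tracking can work:

    # Illustrative offset-based incremental decoding; not the server's decode_token.
    from typing import List, Optional, Tuple

    from transformers import PreTrainedTokenizerBase


    def incremental_decode(
        tokenizer: PreTrainedTokenizerBase,
        all_input_ids: List[int],
        offset: Optional[int],
        token_offset: Optional[int],
    ) -> Tuple[str, int, int]:
        # Anchor the decode window a few tokens back on the first call.
        if token_offset is None:
            token_offset = max(len(all_input_ids) - 5, 0)
        # Decode the window with and without the newest token.
        previous_text = tokenizer.decode(all_input_ids[token_offset:-1])
        current_text = tokenizer.decode(all_input_ids[token_offset:])
        if offset is None:
            offset = len(previous_text)
        # Only the characters produced by the newest token are new output.
        new_text = current_text[offset:]
        return new_text, offset + len(new_text), token_offset
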
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (new file, mode 100644)
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed

from torch.nn import functional as F

from torch import nn
from transformers.activations import ACT2FN

# Flash attention imports
import rotary_emb
import flash_attn_cuda
import dropout_layer_norm

from flash_attn.layers.rotary import RotaryEmbedding


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()

        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states, residual=None):
        if hidden_states.shape[-1] > 8192:
            if residual is not None:
                hidden_states += residual
            residual = hidden_states

            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(
                variance + self.variance_epsilon
            )

            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)

            return self.weight * hidden_states, residual
        else:
            # faster post attention rms norm
            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
                hidden_states,
                residual,
                self.weight,
                None,
                None,
                None,
                None,
                None,
                0.0,
                self.variance_epsilon,
                1.0,
                0,
                None,
                False,
                True,  # Activate RMSNorm
            )
            if res is None:
                res = hidden_states

            return normed_hidden_states, res
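
Aside (not part of the new file): for hidden sizes of 8192 and below, LlamaRMSNorm calls the fused dropout_add_ln_fwd kernel from the flash-attn dropout_layer_norm extension; the branch above it spells the math out eagerly. As a reference, the same normalization in plain PyTorch (residual handling omitted):

    # Reference RMSNorm in plain PyTorch, equivalent to the eager branch above.
    import torch


    def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Scale by the reciprocal root-mean-square of the last dimension.
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        return weight * (x * torch.rsqrt(variance + eps)).to(x.dtype)


    x = torch.randn(4, 4096, dtype=torch.float16)
    print(rms_norm(x, torch.ones(4096, dtype=torch.float16)).shape)  # torch.Size([4, 4096])
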
class FastLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)

    def transpose_weight(self):
        self.weight = nn.Parameter(self.weight.T)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.bias is not None:
            return torch.addmm(self.bias, input, self.weight)
        return torch.matmul(input, self.weight)


class TensorParallelColumnLinear(FastLinear):
    def __init__(
        self,
        in_features,
        out_features,
        process_group: torch.distributed.ProcessGroup,
        bias=True,
        device=None,
        dtype=None,
    ):
        self.process_group = process_group
        self.tp_world_size = process_group.size()
        assert out_features % self.tp_world_size == 0
        out_features = out_features // self.tp_world_size

        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )


class TensorParallelRowLinear(FastLinear):
    def __init__(
        self,
        in_features,
        out_features,
        process_group: torch.distributed.ProcessGroup,
        reduce=True,
        bias=True,
        device=None,
        dtype=None,
    ):
        self.process_group = process_group
        self.tp_world_size = process_group.size()
        self.reduce = reduce
        assert in_features % self.tp_world_size == 0
        in_features = in_features // self.tp_world_size

        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            device=device,
            dtype=dtype,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        out = super(TensorParallelRowLinear, self).forward(input)
        if self.reduce:
            torch.distributed.all_reduce(out, group=self.process_group)

        return out


class TensorParallelEmbedding(nn.Embedding):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        process_group: torch.distributed.ProcessGroup,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        device=None,
        dtype=None,
    ):
        self.process_group = process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()

        self.original_num_embeddings = num_embeddings

        assert num_embeddings % self.tp_world_size == 0
        block_size = num_embeddings // self.tp_world_size
        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
        self.min_id = self.tp_rank * block_size
        self.max_id = (self.tp_rank + 1) * block_size

        # Additional entry that will map to zero
        # Used for masking
        self.null_idx = block_size

        super().__init__(
            block_size,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=dtype,
        )

    def add_null_idx(self):
        """Additional 0 entry used for masking"""
        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
        # translate for [0, self.max_id - self.min_id[
        input = torch.where(
            (self.min_id > input) | (input >= self.max_id),
            self.null_idx,
            input - self.min_id,
        )
        out = super().forward(input)
        torch.distributed.all_reduce(out, group=self.process_group)
        return out


class PositionRotaryEmbedding(RotaryEmbedding):
    def _update_cos_sin_cache(self, dtype, device, seqlen):
        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        ):
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            self._cos_cached = torch.cos(freqs).to(dtype)
            self._sin_cached = torch.sin(freqs).to(dtype)

    def get_cos_sin(self, position_ids: torch.Tensor, max_s: int, dtype: torch.dtype):
        """
        Return cos and sin for the asked position ids
        """
        self._update_cos_sin_cache(dtype, position_ids.device, max_s)

        cos = torch.index_select(self._cos_cached, 0, position_ids)
        sin = torch.index_select(self._sin_cached, 0, position_ids)
        return cos.unsqueeze(1), sin.unsqueeze(1)

    def forward(self, qkv: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
        rotary_dim = cos.shape[-1]
        q1 = qkv[:, 0, :, :rotary_dim]
        q2 = qkv[:, 0, :, rotary_dim : 2 * rotary_dim]
        k1 = qkv[:, 1, :, :rotary_dim]
        k2 = qkv[:, 1, :, rotary_dim : 2 * rotary_dim]

        rotary_emb.apply_rotary(q1, q2, cos, sin, q1, q2, False)
        rotary_emb.apply_rotary(k1, k2, cos, sin, k1, k2, False)
        return qkv
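
Aside (not part of the new file): PositionRotaryEmbedding builds per-position cos/sin tables once and then rotates the query and key halves in place through the rotary_emb CUDA extension. A PyTorch-only sketch of the same rotation, with shapes matching the usage above (illustrative, not the fused kernel):

    # PyTorch-only sketch of the rotation applied by rotary_emb.apply_rotary above.
    import torch


    def apply_rotary_reference(x1, x2, cos, sin):
        # Each (x1, x2) pair is rotated by the angle encoded in (cos, sin).
        return x1 * cos - x2 * sin, x1 * sin + x2 * cos


    head_size, num_tokens, num_heads = 64, 3, 8
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_size, 2).float() / head_size))
    positions = torch.arange(num_tokens).float()
    freqs = torch.outer(positions, inv_freq)                       # [tokens, head_size // 2]
    cos, sin = freqs.cos().unsqueeze(1), freqs.sin().unsqueeze(1)  # broadcast over heads

    q1 = torch.randn(num_tokens, num_heads, head_size // 2)
    q2 = torch.randn(num_tokens, num_heads, head_size // 2)
    r1, r2 = apply_rotary_reference(q1, q2, cos, sin)
    print(r1.shape, r2.shape)
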
class FlashLlamaAttention(torch.nn.Module):
    def __init__(
        self,
        num_heads,
        hidden_size,
        process_group=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_size = hidden_size // num_heads

        self.rotary_emb = PositionRotaryEmbedding(self.head_size, base=10000)
        self.softmax_scale = self.head_size ** (-0.5)

        if process_group is None:
            self.query_key_value = FastLinear(hidden_size, 3 * hidden_size, bias=False)
            self.o_proj = FastLinear(hidden_size, hidden_size, bias=False)
        else:
            self.num_heads = self.num_heads // process_group.size()
            self.query_key_value = TensorParallelColumnLinear(
                hidden_size,
                3 * hidden_size,
                bias=False,
                process_group=process_group,
            )
            self.o_proj = TensorParallelRowLinear(
                hidden_size,
                hidden_size,
                bias=False,
                process_group=process_group,
            )

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
        qkv_rot = self.rotary_emb(qkv, cos, sin)

        # Prefill
        if layer_past_present_indices is None:
            # Copy to layer past
            layer_past[...] = qkv_rot[:, 1:]

            # output
            attn_output = torch.empty_like(qkv_rot[:, 0])
            # flash attention
            flash_attn_cuda.fwd(
                qkv_rot[:, 0],
                qkv_rot[:, 1],
                qkv_rot[:, 2],
                attn_output,
                cu_seqlens,
                cu_seqlens,
                max_s,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                True,
                False,
                0,
                None,
            )
        # Decode
        else:
            query = qkv_rot[:, 0]
            # Add present to the layer_past tensor at the correct indices
            layer_past[layer_past_present_indices] = qkv_rot[:, 1:]

            # output
            attn_output = torch.empty_like(query)
            # flash attention
            flash_attn_cuda.fwd(
                query,
                layer_past[:, 0],
                layer_past[:, 1],
                attn_output,
                cu_seqlens_q,
                cu_seqlens,
                1,
                max_s,
                0.0,
                self.softmax_scale,
                False,
                False,
                False,
                0,
                None,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
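
Aside (not part of the new file): the prefill / decode split above hinges on cu_seqlens. The ragged sequences of a batch are packed into a single token dimension and the cumulative lengths tell the flash attention kernel where each sequence starts and ends; at decode time every sequence contributes exactly one query token, so cu_seqlens_q is simply 0..batch_size, and the freshly written key/value for sequence i lives at cu_seqlens[i + 1] - 1. A small sketch of how those tensors relate (illustrative, mirroring flash_causal_lm.py rather than reproducing it):

    # Sketch: three sequences of lengths 2, 5 and 3 packed into one [10, ...] tensor.
    import torch

    input_lengths = [2, 5, 3]

    # Sequence i occupies tokens cu_seqlens[i]:cu_seqlens[i + 1] of the packed batch.
    cu_seqlens = torch.zeros(len(input_lengths) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(torch.tensor(input_lengths, dtype=torch.int32), dim=0)
    max_s = max(input_lengths)

    # Decode: one query per sequence, so boundaries are simply 0, 1, ..., batch_size.
    cu_seqlens_q = torch.arange(len(input_lengths) + 1, dtype=torch.int32)

    # Last position of each sequence, as in layer_past_present_indices = cu_seqlens[1:] - 1.
    layer_past_present_indices = cu_seqlens[1:] - 1

    print(cu_seqlens.tolist(), max_s, cu_seqlens_q.tolist(), layer_past_present_indices.tolist())
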
class LlamaMLP(nn.Module):
    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
        super().__init__()
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate="tanh"
                if act in ["gelu_fast", "gelu_pytorch_tanh"]
                else None,
            )
        )

        if process_group is None:
            # Fuse gate and up proj
            self.gate_up_proj = FastLinear(hidden_size, 2 * intermediate_size, bias=False)
            self.down_proj = FastLinear(intermediate_size, hidden_size, bias=False)
            self.intermediate_size = intermediate_size
        else:
            # Fuse gate and up proj
            self.gate_up_proj = TensorParallelColumnLinear(
                hidden_size,
                2 * intermediate_size,
                bias=False,
                process_group=process_group,
            )
            self.down_proj = TensorParallelRowLinear(
                intermediate_size,
                hidden_size,
                bias=False,
                process_group=process_group,
                reduce=True,
            )
            self.intermediate_size = self.down_proj.in_features

        self.process_group = process_group

    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
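
Aside (not part of the new file): LlamaMLP fuses the gate and up projections into a single matmul and then applies the SwiGLU pattern, activation of the gate half times the up half, followed by the down projection (for Llama, hidden_act is silu). A plain PyTorch sketch of that forward with toy shapes:

    # Plain PyTorch sketch of the fused SwiGLU forward used by LlamaMLP above.
    import torch
    import torch.nn.functional as F

    tokens, hidden_size, intermediate_size = 5, 16, 64
    x = torch.randn(tokens, hidden_size)

    # One weight holds both the gate and the up projection, as in gate_up_proj.
    gate_up_weight = torch.randn(2 * intermediate_size, hidden_size)
    down_weight = torch.randn(hidden_size, intermediate_size)

    gate_up = (x @ gate_up_weight.T).view(tokens, 2, intermediate_size)
    out = (F.silu(gate_up[:, 0]) * gate_up[:, 1]) @ down_weight.T
    print(out.shape)  # torch.Size([5, 16])
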
class FlashLlamaLayer(nn.Module):
    def __init__(
        self,
        num_heads,
        act,
        hidden_size,
        intermediate_size,
        rms_norm_eps,
        process_group=None,
    ):
        super().__init__()

        self.self_attn = FlashLlamaAttention(num_heads, hidden_size, process_group)
        self.mlp = LlamaMLP(act, hidden_size, intermediate_size, process_group)

        self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlens,
        max_s,
        layer_past,
        layer_past_present_indices,
        cu_seqlens_q,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlens,
            max_s,
            layer_past,
            layer_past_present_indices,
            cu_seqlens_q,
        )

        # faster post attention rms norm
        normed_attn_res_output, attn_res = self.post_attention_layernorm(
            attn_output, res
        )

        mlp_output = self.mlp(normed_attn_res_output)

        return mlp_output, attn_res


class FlashLlamaModel(torch.nn.Module):
    def __init__(self, config, process_group=None):
        super(FlashLlamaModel, self).__init__()
        self.config = config

        self.tp_embeddings = False
        if process_group is not None:
            self.tp_rank = process_group.rank()
            self.tp_world_size = process_group.size()
            if config.vocab_size % self.tp_world_size == 0:
                self.tp_embeddings = True

        if self.tp_embeddings:
            self.embed_tokens = TensorParallelEmbedding(
                config.vocab_size, config.hidden_size, process_group=process_group
            )
        else:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        self.layers = nn.ModuleList(
            [
                FlashLlamaLayer(
                    config.num_attention_heads,
                    config.hidden_act,
                    config.hidden_size,
                    config.intermediate_size,
                    config.rms_norm_eps,
                    process_group,
                )
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads

    def post_load_weights(self):
        if isinstance(self.embed_tokens, TensorParallelEmbedding):
            self.embed_tokens.add_null_idx()
        for layer in self.layers:
            layer: FlashLlamaLayer
            layer.self_attn.query_key_value.transpose_weight()
            layer.self_attn.o_proj.transpose_weight()
            layer.mlp.gate_up_proj.transpose_weight()
            layer.mlp.down_proj.transpose_weight()

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        max_s,
        past_key_values=None,
    ):
        hidden_states = self.embed_tokens(input_ids)

        # Prefill
        if past_key_values is None:
            # Create past tensor
            past_key_values = hidden_states.new_empty(
                (
                    len(self.layers),
                    len(hidden_states),
                    2,
                    self.num_heads,
                    self.head_size,
                )
            )
            layer_past_present_indices = None
            cu_seqlens_q = None
        # Decode
        else:
            # Create indices from cumulative sequence lengths
            layer_past_present_indices = cu_seqlens[1:] - 1
            cu_seqlens_q = torch.arange(
                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
            )

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlens,
                max_s,
                past_key_values[i],
                layer_past_present_indices,
                cu_seqlens_q,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states, past_key_values


class FlashLlamaForCausalLM(torch.nn.Module):
    def __init__(self, config, process_group=None):
        super().__init__()

        self.process_group = process_group
        if self.process_group is not None:
            self.world_size = self.process_group.size()
            self.rank = self.process_group.rank()
        else:
            self.world_size = 1
            self.rank = 0

        self.model = FlashLlamaModel(config, process_group)

        if self.model.tp_embeddings:
            self.lm_head = FastLinear(
                config.hidden_size,
                config.vocab_size // process_group.size(),
                bias=False,
            )
        else:
            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

    def post_load_weights(self):
        self.model.post_load_weights()
        self.lm_head.transpose_weight()

    def forward(
        self,
        input_ids,
        position_ids,
        cu_seqlens,
        max_s,
        past_key_values=None,
    ):
        hidden_states, present = self.model(
            input_ids, position_ids, cu_seqlens, max_s, past_key_values
        )
        logits = self.lm_head(hidden_states)

        if self.model.tp_embeddings:
            # Logits are sharded, so we need to gather them
            world_logits = [torch.empty_like(logits) for _ in range(self.world_size)]
            torch.distributed.all_gather(world_logits, logits, group=self.process_group)
            world_logits = torch.cat(world_logits, dim=1)

            return world_logits, present
        return logits, present
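
Aside (not part of the new file): when the vocabulary is tensor-parallel, each rank's lm_head only produces its shard of the logits, so the forward above gathers the shards and concatenates them along the vocab dimension. A sketch of that final step, assuming torch.distributed has already been initialized and process_group is that group:

    # Sketch of the sharded-logits gather in FlashLlamaForCausalLM.forward.
    import torch
    import torch.distributed as dist


    def gather_vocab_shards(local_logits: torch.Tensor, process_group) -> torch.Tensor:
        # Each shard is [tokens, vocab_size // world_size]; concatenating on dim=1
        # restores the full [tokens, vocab_size] logits on every rank.
        shards = [torch.empty_like(local_logits) for _ in range(process_group.size())]
        dist.all_gather(shards, local_logits, group=process_group)
        return torch.cat(shards, dim=1)
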
server/text_generation_server/models/flash_causal_lm.py
@@ -44,6 +44,8 @@ class FlashCausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
+    offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]

@@ -67,6 +69,8 @@ class FlashCausalLMBatch(Batch):
         max_seqlen = 0
 
         input_lengths = []
+        offsets = []
+        token_offsets = []
         all_input_ids = []
         all_input_ids_tensor = []

@@ -84,6 +88,8 @@ class FlashCausalLMBatch(Batch):
             input_length = len(tokenized_input)
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
+            offsets.append(None)
+            token_offsets.append(None)
             all_input_ids.append(tokenized_input)
 
             tokenized_input = torch.tensor(tokenized_input, device=device)

@@ -120,6 +126,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=None,
             input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,

@@ -132,6 +140,8 @@ class FlashCausalLMBatch(Batch):
         # Batch attributes
         requests = []
         input_lengths = []
+        offsets = []
+        token_offsets = []
         all_input_ids = []
         all_input_ids_tensor = []
         next_token_choosers = []

@@ -150,6 +160,8 @@ class FlashCausalLMBatch(Batch):
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
+            offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             all_input_ids.extend(batch.all_input_ids)
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
             next_token_choosers.extend(batch.next_token_choosers)

@@ -182,6 +194,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,

@@ -279,6 +293,8 @@ class FlashCausalLM(Model):
         next_batch_max_seqlen = 0
         next_batch_past_key_values = []
         next_batch_input_lengths = []
+        next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_all_input_ids = []
         next_batch_all_input_ids_tensor = []

@@ -292,6 +308,8 @@ class FlashCausalLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
+            batch.offsets,
+            batch.token_offsets,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,

@@ -302,6 +320,8 @@ class FlashCausalLM(Model):
         for i, (
             request,
             input_length,
+            offset,
+            token_offset,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,

@@ -334,8 +354,10 @@ class FlashCausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id_item]
-            next_token_text = self.decode_token(
-                next_token_id_item,
+            next_token_text, offset, token_offset = self.decode_token(
+                all_input_ids,
+                offset,
+                token_offset,
             )
 
             # Evaluate stopping criteria

@@ -376,6 +398,8 @@ class FlashCausalLM(Model):
                 next_batch_cu_seqlens[-1] + new_input_length
             )
             next_batch_input_lengths.append(new_input_length)
+            next_batch_offsets.append(offset)
+            next_batch_token_offsets.append(token_offset)
             next_batch_all_input_ids.append(all_input_ids)
             next_batch_all_input_ids_tensor.append(all_input_ids_tensor)
             next_batch_max_seqlen = max(next_batch_max_seqlen, new_input_length)

@@ -452,6 +476,8 @@ class FlashCausalLM(Model):
             max_seqlen=next_batch_max_seqlen,
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
+            offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             all_input_ids=next_batch_all_input_ids,
             all_input_ids_tensor=next_batch_all_input_ids_tensor,
             next_token_choosers=next_batch_next_token_choosers,
server/text_generation_server/models/flash_llama.py (new file, mode 100644)
import torch
import torch.distributed

from accelerate import init_empty_weights
from opentelemetry import trace
from pathlib import Path
from safetensors import safe_open
from transformers import AutoConfig
from transformers.models.llama import LlamaTokenizer
from typing import Optional, List

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    FlashLlamaForCausalLM,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    download_weights,
    weight_hub_files,
    LocalEntryNotFoundError,
)

tracer = trace.get_tracer(__name__)


class FlashLlama(FlashCausalLM):
    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")

        if quantize:
            raise NotImplementedError("FlashLlama does not support quantization")

        tokenizer = LlamaTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
        )

        # We do not use from_pretrained as we modified the model internal module layout
        try:
            filenames = weight_files(model_id, revision, ".bin")
        # Local files not found
        except LocalEntryNotFoundError:
            hub_files = weight_hub_files(model_id, revision, ".bin")
            filenames = download_weights(hub_files, model_id, revision)

        with init_empty_weights():
            model = FlashLlamaForCausalLM(config)

        self.load_weights(model, filenames, device, dtype)
        self.model = model.eval()

        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[Path],
        device: torch.device,
        dtype: torch.dtype,
    ):
        for filename in filenames:
            state_dict = torch.load(filename, map_location="cpu")
            for key, value in state_dict.items():
                value = value.to(device).to(dtype)

                layer_name = ".".join(key.split(".")[:4])

                # Fused qkv
                if "q_proj" in key or "k_proj" in key or "v_proj" in key:
                    final_key = layer_name + ".query_key_value.weight"

                # Fused gate and up projs
                elif "gate_proj" in key or "up_proj" in key:
                    final_key = layer_name + ".gate_up_proj.weight"
                else:
                    final_key = key

                module_name, param_name = final_key.rsplit(".", 1)
                module = model.get_submodule(module_name)

                try:
                    current_parameter_tensor = module._parameters[param_name]
                except KeyError:
                    current_parameter_tensor = None

                if current_parameter_tensor is not None:
                    if current_parameter_tensor.device == torch.device("meta"):
                        # Init qkv
                        if "query_key_value" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (value.shape[0] * 3, value.shape[1])
                            )
                        # Init gate and up proj
                        elif "gate_up_proj" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (value.shape[0] * 2, value.shape[1])
                            )

                    # Copy to correct slice
                    if "q_proj" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "k_proj" in key:
                        module._parameters[param_name][
                            value.shape[0] : value.shape[0] * 2
                        ] = value
                    elif "v_proj" in key:
                        module._parameters[param_name][value.shape[0] * 2 :] = value
                    elif "gate_proj" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "up_proj" in key:
                        module._parameters[param_name][value.shape[0] :] = value
                    else:
                        if current_parameter_tensor.shape != value.shape:
                            raise ValueError(
                                f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
                            )
                        module._parameters[param_name] = value
                else:
                    module._buffers[param_name] = value

                del value

        torch.cuda.empty_cache()
        model.post_load_weights()
class FlashLlamaSharded(FlashLlama):
    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            raise NotImplementedError("FlashLlama is only available on GPU")

        if quantize:
            raise NotImplementedError("FlashLlama does not support quantization")

        tokenizer = LlamaTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
            model = FlashLlamaForCausalLM(config, process_group=self.process_group)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval()
        torch.distributed.barrier(group=self.process_group)
        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    slice_ = f.get_slice(name)

                    layer_name = ".".join(name.split(".")[:4])

                    # Fused qkv
                    if "q_proj" in name or "k_proj" in name or "v_proj" in name:
                        final_name = layer_name + ".query_key_value.weight"
                    # Fused gate and up projs
                    elif "gate_proj" in name or "up_proj" in name:
                        final_name = layer_name + ".gate_up_proj.weight"
                    else:
                        final_name = name

                    module_name, param_name = final_name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        size = slice_.get_shape()[1]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[:, start:stop]
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif name == "lm_head.weight" and model.model.tp_embeddings:
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except:
                            tensor = f.get_tensor(name)

                    tensor = tensor.contiguous().to(dtype)

                    try:
                        current_parameter_tensor = module._parameters[param_name]
                    except KeyError:
                        current_parameter_tensor = None

                    if current_parameter_tensor is not None:
                        if current_parameter_tensor.device == torch.device("meta"):
                            # Init qkv
                            if "query_key_value" in final_name:
                                module._parameters[param_name] = tensor.new_empty(
                                    (tensor.shape[0] * 3, tensor.shape[1])
                                )
                            # Init gate and up proj
                            elif "gate_up_proj" in final_name:
                                module._parameters[param_name] = tensor.new_empty(
                                    (tensor.shape[0] * 2, tensor.shape[1])
                                )

                        # Copy to correct slice
                        if "q_proj" in name:
                            module._parameters[param_name][: tensor.shape[0]] = tensor
                        elif "k_proj" in name:
                            module._parameters[param_name][
                                tensor.shape[0] : tensor.shape[0] * 2
                            ] = tensor
                        elif "v_proj" in name:
                            module._parameters[param_name][
                                tensor.shape[0] * 2 :
                            ] = tensor
                        elif "gate_proj" in name:
                            module._parameters[param_name][: tensor.shape[0]] = tensor
                        elif "up_proj" in name:
                            module._parameters[param_name][tensor.shape[0] :] = tensor
                        else:
                            if current_parameter_tensor.shape != tensor.shape:
                                raise ValueError(
                                    f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                                )
                            module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

        torch.cuda.empty_cache()
        model.post_load_weights()
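The slicing logic in both `load_weights` methods above packs the separate `q_proj`, `k_proj` and `v_proj` checkpoint tensors into one fused `query_key_value` weight (and `gate_proj`/`up_proj` into `gate_up_proj`), so each attention layer needs a single matmul instead of three. A minimal sketch of that layout, with illustrative shapes rather than real Llama ones:

import torch

hidden_size = 8  # illustrative; real Llama uses 4096 and up

# Separate projection weights as they appear in the checkpoint
q = torch.randn(hidden_size, hidden_size)
k = torch.randn(hidden_size, hidden_size)
v = torch.randn(hidden_size, hidden_size)

# Fused buffer: three blocks stacked along dim 0, as allocated with new_empty above
query_key_value = q.new_empty((hidden_size * 3, hidden_size))
query_key_value[:hidden_size] = q                   # q_proj slice
query_key_value[hidden_size : hidden_size * 2] = k  # k_proj slice
query_key_value[hidden_size * 2 :] = v              # v_proj slice

# A single matmul now produces q, k and v in one shot
x = torch.randn(1, hidden_size)
qkv = x @ query_key_value.T
q_out, k_out, v_out = qkv.split(hidden_size, dim=-1)
assert torch.allclose(q_out, x @ q.T)

Fusing the projections this way reduces the number of GPU kernel launches per layer; the sharded variant additionally slices each fused block per rank through the TensorParallel modules.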
server/text_generation_server/models/galactica.py
...
@@ -93,7 +93,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        input_lengths = []
+        offsets = []
+        token_offsets = []

         # Parse batch
         max_truncation = 0
...
@@ -101,7 +102,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         for r in pb.requests:
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            input_lengths.append(r.input_length)
+            offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
...
@@ -146,6 +148,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             past_key_values=None,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
...
server/text_generation_server/models/model.py
...
@@ -15,15 +15,6 @@ class Model(ABC):
         self.all_special_ids = set(tokenizer.all_special_ids)
         self.device = device

-        # see `decode_token` method
-        self.tokenizer.add_special_tokens(
-            {"additional_special_tokens": ["<decode-token>"]}
-        )
-        self.special_decode_token_id = self.tokenizer.convert_tokens_to_ids(
-            "<decode-token>"
-        )
-        self.special_decode_token_length = len("<decode-token>")
-
     @property
     @abstractmethod
     def batch_type(self) -> Type[B]:
...
@@ -33,11 +24,38 @@ class Model(ABC):
     def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
         raise NotImplementedError

-    def decode_token(self, token_id: int) -> str:
+    def decode_token(
+        self,
+        all_input_ids: List[int],
+        offset: Optional[int] = None,
+        token_offset: Optional[int] = None,
+    ) -> Tuple[str, Optional[int], Optional[int]]:
         """Hack to hopefully support generate_stream for the maximum number of tokenizers"""
-        # append token to special decode token and decode both
-        result = self.tokenizer.decode(
-            [self.special_decode_token_id, token_id], skip_special_tokens=False
-        )
-        # slice to remove special decode token
-        return result[self.special_decode_token_length :]
+        if all_input_ids[-1] in self.all_special_ids:
+            return (
+                self.tokenizer.decode(all_input_ids[-1], skip_special_tokens=False),
+                None,
+                None,
+            )
+
+        if token_offset is None:
+            token_offset = len(all_input_ids) - 3
+
+        # Decode token_offset token minus last one and token_offset tokens
+        results = self.tokenizer.batch_decode(
+            [all_input_ids[token_offset:-1], all_input_ids[token_offset:]],
+            skip_special_tokens=False,
+        )
+
+        # default offset is only the last token
+        if offset is None:
+            offset = len(results[0])
+
+        # get text
+        text = results[1][offset:]
+
+        # if text is utf-8
+        if text and text[-1] != "�":
+            return text, None, None
+        else:
+            return "", offset, token_offset
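The new `decode_token` works on the whole id sequence instead of a single token id: it decodes a small window with and without the newest token, returns only the newly produced text, and carries `offset`/`token_offset` forward whenever the tail still decodes to the Unicode replacement character (an incomplete multi-byte character). A standalone toy re-implementation of that idea, simplified (no special-token handling) and not taken from the diff:

from typing import List, Optional, Tuple

# Toy "tokenizer": each id maps to one UTF-8 byte, so multi-byte characters
# such as 'é' span several tokens and arrive in pieces.
VOCAB = "héllo wörld".encode("utf-8")

def toy_decode(ids: List[int]) -> str:
    return bytes(VOCAB[i] for i in ids).decode("utf-8", errors="replace")

def decode_token(
    all_input_ids: List[int],
    offset: Optional[int] = None,
    token_offset: Optional[int] = None,
) -> Tuple[str, Optional[int], Optional[int]]:
    if token_offset is None:
        token_offset = max(len(all_input_ids) - 3, 0)
    # Decode the window with and without the newest token
    without_last = toy_decode(all_input_ids[token_offset:-1])
    with_last = toy_decode(all_input_ids[token_offset:])
    if offset is None:
        offset = len(without_last)
    text = with_last[offset:]
    if text and text[-1] != "�":
        return text, None, None        # complete characters: emit and reset
    return "", offset, token_offset    # partial UTF-8: carry offsets forward

# Stream ids one at a time; 'é' only appears once both of its bytes are in.
offset = token_offset = None
ids: List[int] = []
for i in range(len(VOCAB)):
    ids.append(i)
    chunk, offset, token_offset = decode_token(ids, offset, token_offset)
    print(repr(chunk))

This is why the batches now track per-request `offsets` and `token_offsets` between calls to `generate_token`.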
server/text_generation_server/models/seq2seq_lm.py
...
@@ -38,6 +38,8 @@ class Seq2SeqLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
     decoder_input_lengths: List[int]
+    offsets: List[Optional[int]]
+    token_offsets: List[Optional[int]]

     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
...
@@ -71,6 +73,8 @@ class Seq2SeqLMBatch(Batch):
         decoder_input_ids = []
         decoder_input_lengths = []
+        offsets = []
+        token_offsets = []

         # Parse batch
         max_truncation = 0
...
@@ -80,6 +84,8 @@ class Seq2SeqLMBatch(Batch):
             # Decoder sequence only contains the bos_token
             decoder_input_ids.append(tokenizer.bos_token_id)
             decoder_input_lengths.append(1)
+            offsets.append(None)
+            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
...
@@ -117,6 +123,8 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=None,
             input_lengths=input_lengths.tolist(),
             decoder_input_lengths=decoder_input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=len(pb.requests),
...
@@ -147,6 +155,8 @@ class Seq2SeqLMBatch(Batch):
         requests = []
         input_lengths = []
         decoder_input_lengths = []
+        offsets = []
+        token_offsets = []
         next_token_choosers = []
         stopping_criterias = []
...
@@ -166,6 +176,8 @@ class Seq2SeqLMBatch(Batch):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
             decoder_input_lengths.extend(batch.decoder_input_lengths)
+            offsets.extend(batch.offsets)
+            token_offsets.extend(batch.token_offsets)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
...
@@ -303,6 +315,8 @@ class Seq2SeqLMBatch(Batch):
             past_key_values=past_key_values,
             input_lengths=input_lengths,
             decoder_input_lengths=decoder_input_lengths,
+            offsets=offsets,
+            token_offsets=token_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
...
@@ -335,7 +349,7 @@ class Seq2SeqLM(Model):
             load_in_8bit=quantize,
         ).eval()
         tokenizer = AutoTokenizer.from_pretrained(
-            model_id, revision=revision, padding_side="left"
+            model_id, revision=revision, padding_side="left", truncation_side="left"
         )
         tokenizer.bos_token_id = self.model.config.decoder_start_token_id
...
@@ -422,6 +436,8 @@ class Seq2SeqLM(Model):
         # New values for next forward
         next_batch_input_lengths = []
+        next_batch_offsets = []
+        next_batch_token_offsets = []
         next_batch_decoder_input_ids = []
         next_batch_decoder_input_lengths = []
...
@@ -437,6 +453,8 @@ class Seq2SeqLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
+            batch.offsets,
+            batch.token_offsets,
             batch.decoder_input_lengths,
             logits,
             batch.next_token_choosers,
...
@@ -448,6 +466,8 @@ class Seq2SeqLM(Model):
         for i, (
             request,
             input_length,
+            offset,
+            token_offset,
             decoder_input_length,
             logits,
             next_token_chooser,
...
@@ -466,8 +486,8 @@ class Seq2SeqLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text = self.decode_token(next_token_id_squeezed)
+            next_token_text, offset, token_offset = self.decode_token(
+                decoder_input_ids, offset, token_offset
+            )

             # Evaluate stopping criteria
...
@@ -495,6 +515,8 @@ class Seq2SeqLM(Model):
             next_batch_size += 1
             next_batch_input_lengths.append(input_length)
             next_batch_decoder_input_lengths.append(new_decoder_input_length)
+            next_batch_offsets.append(offset)
+            next_batch_token_offsets.append(token_offset)
             next_batch_max_input_length = max(
                 next_batch_max_input_length, input_length
             )
...
@@ -580,6 +602,8 @@ class Seq2SeqLM(Model):
             past_key_values=next_batch_past_key_values,
             input_lengths=next_batch_input_lengths,
             decoder_input_lengths=next_batch_decoder_input_lengths,
+            offsets=next_batch_offsets,
+            token_offsets=next_batch_token_offsets,
             next_token_choosers=next_batch_next_token_choosers,
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
...