OpenDAS / text-generation-inference · Commits · 9b56d3fb

Unverified commit 9b56d3fb, authored Dec 15, 2023 by OlivierDehaene, committed by GitHub on Dec 15, 2023.

feat: relax mistral requirements (#1351)

Close #1253
Close #1279
Parent: f3aea78f

Showing 8 changed files with 671 additions and 621 deletions (+671 −621)
.github/workflows/build.yaml (+40 −38)
server/poetry.lock (+587 −508)
server/pyproject.toml (+4 −4)
server/requirements_cuda.txt (+10 −10)
server/requirements_rocm.txt (+9 −9)
server/text_generation_server/models/__init__.py (+15 −29)
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py (+0 −9)
server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py (+6 −14)
.github/workflows/build.yaml

@@ -146,11 +146,50 @@ jobs:
           cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min
           cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache,mode=min

+  integration-tests:
+    concurrency:
+      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
+      cancel-in-progress: true
+    needs:
+      - start-runner
+      - build-and-push-image # Wait for the docker image to be built
+    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
+    env:
+      DOCKER_VOLUME: /cache
+    steps:
+      - uses: actions/checkout@v2
+      - name: Inject slug/short variables
+        uses: rlespinasse/github-slug-action@v4.4.1
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.9
+      - name: Tailscale
+        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
+        with:
+          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
+      - name: Prepare disks
+        run: |
+          sudo mkfs -t ext4 /dev/nvme1n1
+          sudo mkdir ${{ env.DOCKER_VOLUME }}
+          sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
+      - name: Install
+        run: |
+          make install-integration-tests
+      - name: Run tests
+        run: |
+          export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
+          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          pytest -s -vv integration-tests
+
   build-and-push-image-rocm:
     concurrency:
       group: ${{ github.workflow }}-build-and-push-image-rocm-${{ github.head_ref || github.run_id }}
       cancel-in-progress: true
-    needs: start-runner # required to start the main job when the runner is ready
+    needs:
+      - start-runner
+      - build-and-push-image # Wait for the main docker image to be built
+      - integration-tests # Wait for the main integration-tests
     runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
     permissions:
       contents: write

@@ -235,43 +274,6 @@ jobs:
           cache-from: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min
           cache-to: type=registry,ref=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:cache-rocm,mode=min

-  integration-tests:
-    concurrency:
-      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
-      cancel-in-progress: true
-    needs:
-      - start-runner
-      - build-and-push-image # Wait for the docker image to be built
-      - build-and-push-image-rocm
-    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
-    env:
-      DOCKER_VOLUME: /cache
-    steps:
-      - uses: actions/checkout@v2
-      - name: Inject slug/short variables
-        uses: rlespinasse/github-slug-action@v4.4.1
-      - name: Set up Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: 3.9
-      - name: Tailscale
-        uses: tailscale/github-action@7bd8039bf25c23c4ab1b8d6e2cc2da2280601966
-        with:
-          authkey: ${{ secrets.TAILSCALE_AUTHKEY }}
-      - name: Prepare disks
-        run: |
-          sudo mkfs -t ext4 /dev/nvme1n1
-          sudo mkdir ${{ env.DOCKER_VOLUME }}
-          sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
-      - name: Install
-        run: |
-          make install-integration-tests
-      - name: Run tests
-        run: |
-          export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv integration-tests
-
   stop-runner:
     name: Stop self-hosted EC2 runner
     needs:
server/poetry.lock

(diff collapsed; not shown)
server/pyproject.toml

@@ -15,7 +15,7 @@ grpcio-status = "^1.51.1"
 grpcio-reflection = "^1.51.1"
 grpc-interceptor = "^0.15.0"
 typer = "^0.6.1"
-accelerate = { version = "^0.20.0", optional = true }
+accelerate = { version = "^0.25.0", optional = true }
 bitsandbytes = { version = "^0.41.1", optional = true }
 safetensors = "^0.3.2"
 loguru = "^0.6.0"
@@ -24,9 +24,9 @@ opentelemetry-exporter-otlp = "^1.15.0"
 opentelemetry-instrumentation-grpc = "^0.36b0"
 hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
-tokenizers = "^0.13.3"
-huggingface-hub = "^0.16.4"
-transformers = "^4.32.1"
+tokenizers = "^0.15.0"
+huggingface-hub = "^0.19.3"
+transformers = "^4.36.1"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
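These raised floors (accelerate 0.25, tokenizers 0.15, huggingface-hub 0.19.3, transformers 4.36.1) can be checked against a local environment before upgrading. The snippet below is an illustrative sketch, not part of the repository; it assumes the packaging library is installed and uses importlib.metadata from the standard library.

# Illustrative sketch (not in the repo): verify an environment against the new
# minimum versions from server/pyproject.toml.
from importlib.metadata import version, PackageNotFoundError
from packaging.version import Version

MINIMUMS = {
    "accelerate": "0.25.0",
    "tokenizers": "0.15.0",
    "huggingface-hub": "0.19.3",
    "transformers": "4.36.1",
}

for name, floor in MINIMUMS.items():
    try:
        installed = Version(version(name))
    except PackageNotFoundError:
        print(f"{name}: not installed")
        continue
    status = "ok" if installed >= Version(floor) else f"too old (needs >= {floor})"
    print(f"{name}: {installed} {status}")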
server/requirements_cuda.txt

 backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-bitsandbytes==0.41.2.post2 ; python_version >= "3.9" and python_version < "3.13"
+bitsandbytes==0.41.3.post2 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -8,14 +8,14 @@ deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
 fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -37,11 +37,11 @@ safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/requirements_rocm.txt

@@ -7,14 +7,14 @@ deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
 fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
 hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
-huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
-idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
 numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -36,11 +36,11 @@ safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
 setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
-tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.15.0 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.36.1 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
 urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/text_generation_server/models/__init__.py

@@ -55,10 +55,14 @@ try:
         FlashSantacoderSharded,
     )
     from text_generation_server.models.idefics import IDEFICSSharded
+    from text_generation_server.models.flash_mistral import FlashMistral
+    from text_generation_server.models.flash_mixtral import FlashMixtral
+    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
+    HAS_FLASH_ATTN_V2_CUDA = False

 if FLASH_ATTENTION:
     __all__.append(FlashNeoXSharded)
@@ -66,25 +70,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
     __all__.append(IDEFICSSharded)
-
-MISTRAL = True
-try:
-    from text_generation_server.models.flash_mistral import FlashMistral
-except ImportError as e:
-    logger.warning(f"Could not import Mistral model: {e}")
-    MISTRAL = False
-
-if MISTRAL:
-    __all__.append(FlashMistral)
-
-MIXTRAL = True
-try:
-    from text_generation_server.models.flash_mixtral import FlashMixtral
-except ImportError as e:
-    logger.warning(f"Could not import Mixtral model: {e}")
-    MIXTRAL = False
-
-if MIXTRAL:
-    __all__.append(FlashMixtral)
+    __all__.append(FlashMistral)
+    __all__.append(FlashMixtral)
@@ -295,7 +281,9 @@ def get_model(
             )

     if model_type == "mistral":
-        if MISTRAL:
+        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
+            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
+        ):
             return FlashMistral(
                 model_id,
                 revision,
@@ -303,10 +291,11 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
         raise NotImplementedError("Mistral models requires flash attention v2")

     if model_type == "mixtral":
-        if MIXTRAL:
+        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
+            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
+        ):
             return FlashMixtral(
                 model_id,
                 revision,
@@ -314,9 +303,6 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
         raise NotImplementedError(
             "Mixtral models requires flash attention v2, stk and megablocks"
         )

     if model_type == "opt":
         return OPTSharded(
@@ -348,17 +334,17 @@ def get_model(
         raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

     if sharded:
-        raise ValueError("sharded is not supported for AutoModel")
+        raise NotImplementedError("sharded is not supported for AutoModel")
     if quantize == "gptq":
-        raise ValueError(
+        raise NotImplementedError(
             "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
         )
     if quantize == "awq":
-        raise ValueError("awq quantization is not supported for AutoModel")
+        raise NotImplementedError("awq quantization is not supported for AutoModel")
     elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
-        raise ValueError("4bit quantization is not supported for AutoModel")
+        raise NotImplementedError("4bit quantization is not supported for AutoModel")
     elif quantize == "eetq":
-        raise ValueError("Eetq quantization is not supported for AutoModel")
+        raise NotImplementedError("Eetq quantization is not supported for AutoModel")

     if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
         return CausalLM(
             model_id,
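Read together, the dispatch hunks above implement one rule: a Mistral or Mixtral checkpoint loads on any working flash-attention build when its config has no sliding window, and only requires the flash-attn v2 CUDA kernels when a positive sliding window is set. Below is a minimal sketch of that rule; it is illustrative only, and supports_flash_model is a hypothetical helper, not a function in this repository.

# Minimal sketch (not the repository's code) of the relaxed dispatch rule in get_model().
def supports_flash_model(sliding_window, flash_attention, has_flash_attn_v2_cuda):
    # No sliding window configured: any working flash attention build is enough.
    if sliding_window is None and flash_attention:
        return True
    # A positive sliding window needs the windowed kernels from flash-attn v2 on CUDA.
    if sliding_window is not None and sliding_window > 0 and has_flash_attn_v2_cuda:
        return True
    return False

# Example: a config with sliding_window=4096 requires flash-attn v2 CUDA,
# while a config with sliding_window=null only needs FLASH_ATTENTION.
assert supports_flash_model(4096, True, True)
assert not supports_flash_model(4096, True, False)
assert supports_flash_model(None, True, False)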
server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py

@@ -27,11 +27,6 @@ from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import (
-    attention,
-    HAS_FLASH_ATTN_V2_ROCM,
-    HAS_FLASH_ATTN_V2_CUDA,
-)
 from text_generation_server.utils.layers import (
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -43,10 +38,6 @@ from text_generation_server.utils.layers import (
 )

-if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
-    raise ImportError("Mistral model requires flash attn v2")
-

 class MistralConfig(PretrainedConfig):
     model_type = "mistral"
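The net effect of these deletions is that flash_mistral_modeling can be imported on machines without flash-attn v2; the capability check now lives in get_model, which raises only for configs that actually need the missing kernels. A standalone sketch of that import-time-to-call-time shift follows; load_mistral is a hypothetical stand-in for the caller, not code from the repository.

# Sketch of the pattern: expose a capability flag instead of raising at import time.
try:
    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
except ImportError:
    # e.g. a CPU-only environment: the module still imports, the flag is just False
    HAS_FLASH_ATTN_V2_CUDA = False

def load_mistral(sliding_window):
    # Hypothetical caller-side check, analogous to get_model() above.
    if sliding_window and not HAS_FLASH_ATTN_V2_CUDA:
        raise NotImplementedError("Mistral models requires flash attention v2")
    return "loaded"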
server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py

@@ -27,12 +27,9 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
+from loguru import logger

 from text_generation_server.utils import paged_attention, flash_attn
-from text_generation_server.utils.flash_attn import (
-    HAS_FLASH_ATTN_V2_ROCM,
-    HAS_FLASH_ATTN_V2_CUDA,
-)
 from text_generation_server.utils.layers import (
     FastLinear,
     FastRMSNorm,
@@ -44,18 +41,13 @@ from text_generation_server.utils.layers import (
     get_linear,
 )

-if not HAS_FLASH_ATTN_V2_CUDA and not HAS_FLASH_ATTN_V2_ROCM:
-    raise ImportError("Mixtral model requires flash attn v2")
-
-try:
-    import megablocks.ops as ops
-except ImportError:
-    raise ImportError("Mixtral model requires megablocks to be installed")
-
+HAS_MEGABLOCKS = True
 try:
     import stk
+    import megablocks.ops as ops
 except ImportError:
-    raise ImportError("Mixtral model requires stk to be installed")
+    logger.warning("Mixtral: megablocks is not installed")
+    HAS_MEGABLOCKS = False


 class MixtralConfig(PretrainedConfig):
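The rewritten try/except turns stk and megablocks from hard requirements into an optional acceleration guarded by HAS_MEGABLOCKS. Below is a generic sketch of the same idiom; the names, such as some_optional_kernel_package, are placeholders and not the repository's code.

# Generic sketch of the optional-dependency idiom used above: try the fast path,
# remember whether it is available, and fall back quietly otherwise.
import logging

logger = logging.getLogger(__name__)

HAS_FAST_KERNELS = True
try:
    import some_optional_kernel_package  # hypothetical stand-in for stk/megablocks
except ImportError:
    logger.warning("fast kernels are not installed, falling back to the dense path")
    HAS_FAST_KERNELS = False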
@@ -590,7 +582,7 @@ class BlockSparseMoE(nn.Module):
         return out

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        if len(x) > 256:
+        if len(x) > 256 and HAS_MEGABLOCKS:
             return self.sparse_forward(x)
         # This is faster when there is not a lot of tokens
         return self.dense_forward(x)