OpenDAS / text-generation-inference / Commits

Commit 1e9bcd9d (Unverified)
Authored Mar 22, 2024 by OlivierDehaene, committed by GitHub on Mar 22, 2024

feat: cohere (#1660)

parent f171bdc8
Showing 11 changed files with 1182 additions and 785 deletions (+1182 −785)
server/Makefile  +1 −1
server/poetry.lock  +590 −548
server/pyproject.toml  +3 −3
server/requirements_common.txt  +0 −46
server/requirements_cuda.txt  +12 −13
server/requirements_rocm.txt  +12 −12
server/text_generation_server/models/__init__.py  +27 −0
server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py  +461 −0
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py  +0 −161
server/text_generation_server/models/flash_cohere.py  +75 −0
server/text_generation_server/models/flash_gemma.py  +1 −1
server/Makefile

@@ -29,5 +29,5 @@ run-dev:
 	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

 export-requirements:
-	poetry export -o requirements_cuda.txt --extras bnb --without-hashes
+	poetry export -o requirements_cuda.txt --without-hashes
 	poetry export -o requirements_rocm.txt --without-hashes
server/poetry.lock

(diff collapsed)
server/pyproject.toml

@@ -17,7 +17,7 @@ grpc-interceptor = "^0.15.0"
 typer = "^0.6.1"
 accelerate = { version = "^0.28.0", optional = true }
 bitsandbytes = { version = "^0.43.0", optional = true }
-safetensors = "^0.4.1"
+safetensors = "^0.4"
 loguru = "^0.6.0"
 opentelemetry-api = "^1.15.0"
 opentelemetry-exporter-otlp = "^1.15.0"
@@ -26,11 +26,11 @@ hf-transfer = "^0.1.2"
 sentencepiece = "^0.1.97"
 tokenizers = "^0.15.0"
 huggingface-hub = "^0.19.3"
-transformers = "^4.38.2"
+transformers = "^4.38"
 einops = "^0.6.1"
 texttable = { version = "^1.6.7", optional = true }
 datasets = { version = "^2.14.0", optional = true }
-peft = { version = "^0.9.0", optional = true }
+peft = { version = "^0.9", optional = true }
 torch = { version = "^2.1.1", optional = true }
 scipy = "^1.11.1"
 pillow = "^10.0.0"
server/requirements_common.txt (deleted)
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/requirements_cuda.txt

 backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
-bitsandbytes==0.41.3.post2 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
 click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
@@ -7,13 +6,13 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
-hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
+hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -27,21 +26,21 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi
 opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
 pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
-safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.1.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.0 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/requirements_rocm.txt

@@ -6,13 +6,13 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p
 deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
 einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
 grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
-hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
+hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
 huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
 idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
 loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -26,21 +26,21 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi
 opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
 pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
 protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
 pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
 regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
 requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
-safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
 scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
 sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.1.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.0 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
 win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
 wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/text_generation_server/models/__init__.py

@@ -57,6 +57,9 @@ try:
     from text_generation_server.models.flash_qwen2 import (
         FlashQwen2,
     )
+    from text_generation_server.models.flash_cohere import (
+        FlashCohere,
+    )
     from text_generation_server.models.flash_gemma import (
         FlashGemma,
     )
@@ -86,6 +89,8 @@ if FLASH_ATTENTION:
     __all__.append(FlashPhi)
     __all__.append(FlashQwen2)
     __all__.append(FlashStarcoder2)
+    __all__.append(FlashGemma)
+    __all__.append(FlashCohere)

 MAMBA_AVAILABLE = True
 try:
@@ -354,6 +359,28 @@ def get_model(
             trust_remote_code=trust_remote_code,
         )

+    if model_type == "cohere":
+        if FLASH_ATTENTION:
+            return FlashCohere(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        elif sharded:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
+        else:
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
     if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
         if sharded:
             if FLASH_ATTENTION:
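For orientation only: the branch added above routes any checkpoint whose config reports model_type == "cohere" to the new FlashCohere implementation when flash attention is available, falls back to the generic CausalLM path otherwise, and rejects sharding without flash attention. A minimal sketch of that routing, with a hypothetical helper name and simplified signature that are not part of the diff:

# Hypothetical, simplified restatement of the dispatch added to get_model above.
def route_cohere(flash_attention: bool, sharded: bool) -> str:
    if flash_attention:
        return "FlashCohere"  # paged/flash-attention implementation
    if sharded:
        # mirrors raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
        raise NotImplementedError("Sharded Cohere requires flash attention")
    return "CausalLM"  # generic transformers fallback


assert route_cohere(flash_attention=True, sharded=True) == "FlashCohere"
assert route_cohere(flash_attention=False, sharded=False) == "CausalLM"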
server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py (new file)
# coding=utf-8
# Copyright 2024 Cohere team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    PositionRotaryEmbedding,
    SpeculativeHead,
    get_linear,
    FastRMSNorm,
)


class CohereConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=256000,
        hidden_size=8192,
        intermediate_size=22528,
        num_hidden_layers=40,
        num_attention_heads=64,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=5,
        eos_token_id=255001,
        pretraining_tp=1,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        logit_scale=1.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.logit_scale = logit_scale

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=config.attention_bias,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        quantize=config.quantize,
        dim=0,
    )

    if config.quantize not in ["gptq", "awq"]:
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        assert list(weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    if config.attention_bias:
        w = [
            weights.get_sharded(f"{p}.bias", dim=0)
            for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
        ]
        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
    else:
        bias = None

    return TensorParallelColumnLinear(
        get_linear(weight, bias=bias, quantize=config.quantize)
    )
class FlashCohereAttention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
            device=weights.device,
        )

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=config.attention_bias,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        paged_attention.reshape_and_cache(
            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            flash_attn.attention(
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
                attn_output,
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
            )
        # Decode
        else:
            paged_attention.attention(
                attn_output,
                query,
                kv_cache[0],
                kv_cache[1],
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                input_lengths,
                max_s,
            )

        return self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size), reduce=False
        )
class CohereMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(
            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False
        )
class FlashCohereLayer(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        prefix = f"model.layers.{layer_id}"
        self.self_attn = FlashCohereAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )
        self.process_group = weights.process_group

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
        )

        mlp_output = self.mlp(normed_hidden_states)
        output = attn_output + mlp_output

        if self.process_group.size() > 1:
            torch.distributed.all_reduce(output, group=self.process_group)

        return output, res
class FlashCohereModel(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.layers = nn.ModuleList(
            [
                FlashCohereLayer(
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = FastRMSNorm.load(
            prefix="model.norm", weights=weights, eps=config.layer_norm_eps
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                input_lengths,
                max_s,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states
class FlashCohereForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        self.model = FlashCohereModel(config, weights)
        try:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="lm_head",
                weights=weights,
            )
        except RuntimeError:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="model.embed_tokens",
                weights=weights,
            )
        self.logit_scale = config.logit_scale

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        logits *= self.logit_scale
        if speculative_logits is not None:
            speculative_logits *= self.logit_scale
        return logits, speculative_logits
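One detail worth noting in the new modeling code: FlashCohereLayer.forward applies a single input_layernorm, feeds the same normalized activations to both self-attention and the MLP, and sums their outputs (the residual addition is handled by the fused FastRMSNorm on the next call), and FlashCohereForCausalLM scales the logits by config.logit_scale. The following is a minimal conceptual sketch of that parallel-residual block, using plain PyTorch modules as stand-ins; it is an illustration, not the served implementation:

import torch
from torch import nn


class ParallelResidualSketch(nn.Module):
    """Toy stand-in for FlashCohereLayer: one shared norm, attention and MLP in parallel."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.attn = nn.Linear(hidden_size, hidden_size)  # placeholder for self-attention
        self.mlp = nn.Linear(hidden_size, hidden_size)   # placeholder for the gated MLP

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        normed = self.norm(hidden_states)
        # Both branches read the same normed input and their outputs are summed,
        # unlike the usual norm -> attn -> norm -> mlp sequential residual.
        return hidden_states + self.attn(normed) + self.mlp(normed)


x = torch.randn(2, 4, 64)
print(ParallelResidualSketch(64)(x).shape)  # torch.Size([2, 4, 64])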
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py
@@ -20,16 +20,11 @@
 import torch
 import torch.distributed

-import os
-from shutil import copyfile
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
-from tokenizers import processors
-from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
-from transformers.utils import logging

 from text_generation_server.utils import paged_attention, flash_attn
 from text_generation_server.utils.layers import (
@@ -42,162 +37,6 @@ from text_generation_server.utils.layers import (
     FastRMSNorm,
 )

-GemmaTokenizer = None
-logger = logging.get_logger(__name__)
-
-VOCAB_FILES_NAMES = {
-    "vocab_file": "tokenizer.model",
-    "tokenizer_file": "tokenizer.json",
-}
-
-PRETRAINED_VOCAB_FILES_MAP = {
-    "vocab_file": {
-        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
-    },
-    "tokenizer_file": {
-        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
-    },
-}
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
-# fmt: off
-DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
-answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
-that your responses are socially unbiased and positive in nature.
-
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
-correct. If you don't know the answer to a question, please don't share false information."""
-# fmt: on
-
-
-class GemmaTokenizerFast(PreTrainedTokenizerFast):
-    vocab_files_names = VOCAB_FILES_NAMES
-    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
-    slow_tokenizer_class = GemmaTokenizer
-    padding_side = "left"
-    model_input_names = ["input_ids", "attention_mask"]
-
-    def __init__(
-        self,
-        vocab_file=None,
-        tokenizer_file=None,
-        clean_up_tokenization_spaces=False,
-        unk_token="<unk>",
-        bos_token="<bos>",
-        eos_token="<eos>",
-        pad_token="<pad>",
-        add_bos_token=True,
-        add_eos_token=False,
-        use_default_system_prompt=False,
-        **kwargs,
-    ):
-        super().__init__(
-            vocab_file=vocab_file,
-            tokenizer_file=tokenizer_file,
-            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            unk_token=unk_token,
-            bos_token=bos_token,
-            eos_token=eos_token,
-            pad_token=pad_token,
-            add_bos_token=add_bos_token,
-            add_eos_token=add_eos_token,
-            use_default_system_prompt=use_default_system_prompt,
-            **kwargs,
-        )
-        self._add_bos_token = add_bos_token
-        self._add_eos_token = add_eos_token
-        self.update_post_processor()
-        self.use_default_system_prompt = use_default_system_prompt
-        self.vocab_file = vocab_file
-
-    @property
-    def can_save_slow_tokenizer(self) -> bool:
-        return os.path.isfile(self.vocab_file) if self.vocab_file else False
-
-    def update_post_processor(self):
-        """
-        Updates the underlying post processor with the current `bos_token` and `eos_token`.
-        """
-        bos = self.bos_token
-        bos_token_id = self.bos_token_id
-        if bos is None and self.add_bos_token:
-            raise ValueError("add_bos_token = True but bos_token = None")
-
-        eos = self.eos_token
-        eos_token_id = self.eos_token_id
-        if eos is None and self.add_eos_token:
-            raise ValueError("add_eos_token = True but eos_token = None")
-
-        single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
-        pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
-
-        special_tokens = []
-        if self.add_bos_token:
-            special_tokens.append((bos, bos_token_id))
-        if self.add_eos_token:
-            special_tokens.append((eos, eos_token_id))
-        self._tokenizer.post_processor = processors.TemplateProcessing(
-            single=single, pair=pair, special_tokens=special_tokens
-        )
-
-    @property
-    def add_eos_token(self):
-        return self._add_eos_token
-
-    @property
-    def add_bos_token(self):
-        return self._add_bos_token
-
-    @add_eos_token.setter
-    def add_eos_token(self, value):
-        self._add_eos_token = value
-        self.update_post_processor()
-
-    @add_bos_token.setter
-    def add_bos_token(self, value):
-        self._add_bos_token = value
-        self.update_post_processor()
-
-    def save_vocabulary(
-        self, save_directory: str, filename_prefix: Optional[str] = None
-    ) -> Tuple[str]:
-        if not self.can_save_slow_tokenizer:
-            raise ValueError(
-                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
-                "tokenizer."
-            )
-
-        if not os.path.isdir(save_directory):
-            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
-            return
-        out_vocab_file = os.path.join(
-            save_directory,
-            (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
-        )
-
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
-            copyfile(self.vocab_file, out_vocab_file)
-
-        return (out_vocab_file,)
-
-    @property
-    def default_chat_template(self):
-        raise NotImplementedError
-
-    # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
-    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
-        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
-        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
-
-        output = bos_token_id + token_ids_0 + eos_token_id
-
-        if token_ids_1 is not None:
-            output = output + bos_token_id + token_ids_1 + eos_token_id
-
-        return output
-
-
 class GemmaConfig(PretrainedConfig):
     def __init__(
server/text_generation_server/models/flash_cohere.py (new file)
import torch
import torch.distributed

from opentelemetry import trace
from typing import Optional
from transformers.models.llama import LlamaTokenizerFast

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
    FlashCohereForCausalLM,
    CohereConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashCohere(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashCohere is only available on GPU")

        tokenizer = LlamaTokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
            use_fast=True,
            from_slow=False,
        )

        config = CohereConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.use_medusa = use_medusa

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashCohereForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(FlashCohere, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
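For a rough sense of the geometry this wires into FlashCausalLM (num_layers, num_kv_heads, and head_size drive KV-cache allocation): with the CohereConfig defaults shown in the new modeling file (hidden_size 8192, 64 attention heads, 40 layers, num_key_value_heads defaulting to num_attention_heads), head_size comes out to 128. A back-of-the-envelope sketch of the per-token KV-cache footprint, assuming the bfloat16 default FlashCohere picks on GPU; actual checkpoints may override these config values:

# Back-of-the-envelope KV-cache sizing from the CohereConfig defaults above.
hidden_size = 8192
num_attention_heads = 64
num_hidden_layers = 40
num_key_value_heads = num_attention_heads  # default when not set in the config
bytes_per_value = 2                        # bfloat16

head_size = hidden_size // num_attention_heads                                  # 128
kv_values_per_token = 2 * num_hidden_layers * num_key_value_heads * head_size   # K and V
print(head_size, kv_values_per_token * bytes_per_value / 1e6, "MB per token")   # 128 1.31072 MB per token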
server/text_generation_server/models/flash_gemma.py
@@ -3,10 +3,10 @@ import torch.distributed
 from opentelemetry import trace
 from typing import Optional
+from transformers.models.gemma import GemmaTokenizerFast

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
-    GemmaTokenizerFast,
     FlashGemmaForCausalLM,
     GemmaConfig,
 )