OpenDAS / text-generation-inference / Commits

Commit 1e9bcd9d (unverified), authored Mar 22, 2024 by OlivierDehaene, committed by GitHub on Mar 22, 2024

feat: cohere (#1660)
parent f171bdc8

Showing 11 changed files with 1182 additions and 785 deletions (+1182, -785)
server/Makefile  (+1, -1)
server/poetry.lock  (+590, -548)
server/pyproject.toml  (+3, -3)
server/requirements_common.txt  (+0, -46)
server/requirements_cuda.txt  (+12, -13)
server/requirements_rocm.txt  (+12, -12)
server/text_generation_server/models/__init__.py  (+27, -0)
server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py  (+461, -0)
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py  (+0, -161)
server/text_generation_server/models/flash_cohere.py  (+75, -0)
server/text_generation_server/models/flash_gemma.py  (+1, -1)
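For orientation (not part of the diff): once a server built from this commit is launched with a Cohere checkpoint, it is queried through the same HTTP API as any other model. A minimal hedged example follows; the host, port, and model id passed to the launcher are assumptions, not taken from this commit.

# Illustrative client call against a running text-generation-inference server.
# The /generate endpoint and payload shape are the standard TGI API; the
# address and the served checkpoint are assumptions.
import requests

resp = requests.post(
    "http://127.0.0.1:3000/generate",
    json={
        "inputs": "Write a one-line summary of retrieval-augmented generation.",
        "parameters": {"max_new_tokens": 64, "temperature": 0.3},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])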
server/Makefile  View file @ 1e9bcd9d

@@ -29,5 +29,5 @@ run-dev:
	SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded

export-requirements:
-	poetry export -o requirements_cuda.txt --extras bnb --without-hashes
+	poetry export -o requirements_cuda.txt --without-hashes
	poetry export -o requirements_rocm.txt --without-hashes
server/poetry.lock  View file @ 1e9bcd9d

This diff is collapsed.
server/pyproject.toml  View file @ 1e9bcd9d

@@ -17,7 +17,7 @@ grpc-interceptor = "^0.15.0"
typer = "^0.6.1"
accelerate = { version = "^0.28.0", optional = true }
bitsandbytes = { version = "^0.43.0", optional = true }
-safetensors = "^0.4.1"
+safetensors = "^0.4"
loguru = "^0.6.0"
opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0"
@@ -26,11 +26,11 @@ hf-transfer = "^0.1.2"
sentencepiece = "^0.1.97"
tokenizers = "^0.15.0"
huggingface-hub = "^0.19.3"
-transformers = "^4.38.2"
+transformers = "^4.38"
einops = "^0.6.1"
texttable = { version = "^1.6.7", optional = true }
datasets = { version = "^2.14.0", optional = true }
-peft = { version = "^0.9.0", optional = true }
+peft = { version = "^0.9", optional = true }
torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1"
pillow = "^10.0.0"
server/requirements_common.txt  deleted 100644 → 0  View file @ f171bdc8
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.61.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.59.3 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.4 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.16.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.2 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.1.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.1 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.10.3 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.11.4 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.2 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.13.3 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.33.3 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/requirements_cuda.txt  View file @ 1e9bcd9d
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
bitsandbytes==0.41.3.post2 ; python_version >= "3.9" and python_version < "3.13"
certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
...
...
@@ -7,13 +6,13 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
-hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
...
...
@@ -27,21 +26,21 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
-safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.1.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/requirements_rocm.txt  View file @ 1e9bcd9d
...
...
@@ -6,13 +6,13 @@ colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_p
deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
einops==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
-fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
-googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.63.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-reflection==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio-status==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
-grpcio==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
-hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.62.1 ; python_version >= "3.9" and python_version < "3.13"
+hf-transfer==0.1.6 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
...
...
@@ -26,21 +26,21 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_versi
opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
+packaging==24.0 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
-safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.2 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
-setuptools==69.1.0 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==69.2.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.39.0 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
-typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
-urllib3==2.2.0 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.10.0 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
server/text_generation_server/models/__init__.py  View file @ 1e9bcd9d

@@ -57,6 +57,9 @@ try:
    from text_generation_server.models.flash_qwen2 import (
        FlashQwen2,
    )
+    from text_generation_server.models.flash_cohere import (
+        FlashCohere,
+    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
@@ -86,6 +89,8 @@ if FLASH_ATTENTION:
    __all__.append(FlashPhi)
    __all__.append(FlashQwen2)
    __all__.append(FlashStarcoder2)
    __all__.append(FlashGemma)
+    __all__.append(FlashCohere)

MAMBA_AVAILABLE = True
try:
@@ -354,6 +359,28 @@ def get_model(
                trust_remote_code=trust_remote_code,
            )

+    if model_type == "cohere":
+        if FLASH_ATTENTION:
+            return FlashCohere(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        elif sharded:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Cohere"))
+        else:
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                use_medusa=use_medusa,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+
    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
            if FLASH_ATTENTION:
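The dispatch added above keys on the model type reported by the checkpoint's config.json. A small hedged sketch of checking that field yourself, using huggingface_hub (tooling and model id are assumptions, not part of this commit):

# Illustrative only: verify that a checkpoint advertises model_type "cohere",
# which is what routes it to FlashCohere (or CausalLM without flash attention).
import json
from huggingface_hub import hf_hub_download

config_path = hf_hub_download("CohereForAI/c4ai-command-r-v01", "config.json")  # placeholder id
with open(config_path) as f:
    print(json.load(f)["model_type"])  # expected: "cohere"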
server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py  0 → 100644  View file @ 1e9bcd9d
# coding=utf-8
# Copyright 2024 Cohere team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    PositionRotaryEmbedding,
    SpeculativeHead,
    get_linear,
    FastRMSNorm,
)


class CohereConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=256000,
        hidden_size=8192,
        intermediate_size=22528,
        num_hidden_layers=40,
        num_attention_heads=64,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=5,
        eos_token_id=255001,
        pretraining_tp=1,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        attention_bias=False,
        attention_dropout=0.0,
        logit_scale=1.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.logit_scale = logit_scale

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=config.attention_bias,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        quantize=config.quantize,
        dim=0,
    )

    if config.quantize not in ["gptq", "awq"]:
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        assert list(weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    if config.attention_bias:
        w = [
            weights.get_sharded(f"{p}.bias", dim=0)
            for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
        ]
        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
    else:
        bias = None

    return TensorParallelColumnLinear(
        get_linear(weight, bias=bias, quantize=config.quantize)
    )


class FlashCohereAttention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_heads

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
            device=weights.device,
        )

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=config.attention_bias,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        paged_attention.reshape_and_cache(
            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            flash_attn.attention(
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
                attn_output,
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
            )
        # Decode
        else:
            paged_attention.attention(
                attn_output,
                query,
                kv_cache[0],
                kv_cache[1],
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                input_lengths,
                max_s,
            )

        return self.o_proj(
            attn_output.view(-1, self.num_heads * self.head_size), reduce=False
        )


class CohereMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(
            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], reduce=False
        )


class FlashCohereLayer(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        prefix = f"model.layers.{layer_id}"
        self.self_attn = FlashCohereAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = CohereMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm",
            weights=weights,
            eps=config.layer_norm_eps,
        )
        self.process_group = weights.process_group

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
        )

        mlp_output = self.mlp(normed_hidden_states)

        output = attn_output + mlp_output

        if self.process_group.size() > 1:
            torch.distributed.all_reduce(output, group=self.process_group)

        return output, res


class FlashCohereModel(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.layers = nn.ModuleList(
            [
                FlashCohereLayer(
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = FastRMSNorm.load(
            prefix="model.norm", weights=weights, eps=config.layer_norm_eps
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                input_lengths,
                max_s,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashCohereForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        self.model = FlashCohereModel(config, weights)
        try:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="lm_head",
                weights=weights,
            )
        except RuntimeError:
            self.lm_head = SpeculativeHead.load(
                config,
                prefix="model.embed_tokens",
                weights=weights,
            )
        self.logit_scale = config.logit_scale

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        logits *= self.logit_scale
        if speculative_logits is not None:
            speculative_logits *= self.logit_scale
        return logits, speculative_logits
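Two details in this new modeling file differ from the Llama-style flash models already in the repository: each decoder layer feeds the same normalized input to both attention and the MLP and sums the two outputs (a "parallel residual"), and the language-model head multiplies logits by config.logit_scale. The toy sketch below is illustrative only; it uses plain torch modules with made-up shapes as stand-ins for the tensor-parallel layers above, and the logit_scale value is arbitrary.

# Illustrative toy version of the parallel-residual layer and logit scaling.
# Not the commit's implementation; shapes, modules, and values are placeholders.
import torch
import torch.nn as nn

hidden = 16
x = torch.randn(4, hidden)           # [tokens, hidden]
norm = nn.LayerNorm(hidden)          # stand-in for the layer norm
attn = nn.Linear(hidden, hidden)     # stand-in for FlashCohereAttention
mlp = nn.Linear(hidden, hidden)      # stand-in for CohereMLP
lm_head = nn.Linear(hidden, 32)      # stand-in for SpeculativeHead
logit_scale = 0.0625                 # config.logit_scale in the real model (value made up)

normed = norm(x)
layer_out = x + attn(normed) + mlp(normed)   # attention and MLP in parallel, summed onto the residual
logits = lm_head(layer_out) * logit_scale    # logits are scaled before sampling
print(logits.shape)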
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py  View file @ 1e9bcd9d
@@ -20,16 +20,11 @@
import torch
import torch.distributed

-import os
-from shutil import copyfile

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
-from tokenizers import processors
-from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
-from transformers.utils import logging

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
@@ -42,162 +37,6 @@ from text_generation_server.utils.layers import (
    FastRMSNorm,
)

-GemmaTokenizer = None
-logger = logging.get_logger(__name__)
-
-VOCAB_FILES_NAMES = {
-    "vocab_file": "tokenizer.model",
-    "tokenizer_file": "tokenizer.json",
-}
-
-PRETRAINED_VOCAB_FILES_MAP = {
-    "vocab_file": {
-        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
-    },
-    "tokenizer_file": {
-        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
-    },
-}
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
-# fmt: off
-DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
-answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
-that your responses are socially unbiased and positive in nature.
-
-If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
-correct. If you don't know the answer to a question, please don't share false information."""
-# fmt: on
-
-
-class GemmaTokenizerFast(PreTrainedTokenizerFast):
-    vocab_files_names = VOCAB_FILES_NAMES
-    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
-    slow_tokenizer_class = GemmaTokenizer
-    padding_side = "left"
-    model_input_names = ["input_ids", "attention_mask"]
-
-    def __init__(
-        self,
-        vocab_file=None,
-        tokenizer_file=None,
-        clean_up_tokenization_spaces=False,
-        unk_token="<unk>",
-        bos_token="<bos>",
-        eos_token="<eos>",
-        pad_token="<pad>",
-        add_bos_token=True,
-        add_eos_token=False,
-        use_default_system_prompt=False,
-        **kwargs,
-    ):
-        super().__init__(
-            vocab_file=vocab_file,
-            tokenizer_file=tokenizer_file,
-            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-            unk_token=unk_token,
-            bos_token=bos_token,
-            eos_token=eos_token,
-            pad_token=pad_token,
-            add_bos_token=add_bos_token,
-            add_eos_token=add_eos_token,
-            use_default_system_prompt=use_default_system_prompt,
-            **kwargs,
-        )
-        self._add_bos_token = add_bos_token
-        self._add_eos_token = add_eos_token
-        self.update_post_processor()
-        self.use_default_system_prompt = use_default_system_prompt
-        self.vocab_file = vocab_file
-
-    @property
-    def can_save_slow_tokenizer(self) -> bool:
-        return os.path.isfile(self.vocab_file) if self.vocab_file else False
-
-    def update_post_processor(self):
-        """
-        Updates the underlying post processor with the current `bos_token` and `eos_token`.
-        """
-        bos = self.bos_token
-        bos_token_id = self.bos_token_id
-        if bos is None and self.add_bos_token:
-            raise ValueError("add_bos_token = True but bos_token = None")
-
-        eos = self.eos_token
-        eos_token_id = self.eos_token_id
-        if eos is None and self.add_eos_token:
-            raise ValueError("add_eos_token = True but eos_token = None")
-
-        single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
-        pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
-
-        special_tokens = []
-        if self.add_bos_token:
-            special_tokens.append((bos, bos_token_id))
-        if self.add_eos_token:
-            special_tokens.append((eos, eos_token_id))
-        self._tokenizer.post_processor = processors.TemplateProcessing(
-            single=single, pair=pair, special_tokens=special_tokens
-        )
-
-    @property
-    def add_eos_token(self):
-        return self._add_eos_token
-
-    @property
-    def add_bos_token(self):
-        return self._add_bos_token
-
-    @add_eos_token.setter
-    def add_eos_token(self, value):
-        self._add_eos_token = value
-        self.update_post_processor()
-
-    @add_bos_token.setter
-    def add_bos_token(self, value):
-        self._add_bos_token = value
-        self.update_post_processor()
-
-    def save_vocabulary(
-        self, save_directory: str, filename_prefix: Optional[str] = None
-    ) -> Tuple[str]:
-        if not self.can_save_slow_tokenizer:
-            raise ValueError(
-                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
-                "tokenizer."
-            )
-
-        if not os.path.isdir(save_directory):
-            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
-            return
-        out_vocab_file = os.path.join(
-            save_directory,
-            (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
-        )
-
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
-            copyfile(self.vocab_file, out_vocab_file)
-
-        return (out_vocab_file,)
-
-    @property
-    def default_chat_template(self):
-        raise NotImplementedError
-
-    # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
-    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
-        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
-        eos_token_id = [self.eos_token_id] if self.add_eos_token else []
-
-        output = bos_token_id + token_ids_0 + eos_token_id
-
-        if token_ids_1 is not None:
-            output = output + bos_token_id + token_ids_1 + eos_token_id
-
-        return output
-

class GemmaConfig(PretrainedConfig):
    def __init__(
server/text_generation_server/models/flash_cohere.py  0 → 100644  View file @ 1e9bcd9d
import torch
import torch.distributed

from opentelemetry import trace
from typing import Optional
from transformers.models.llama import LlamaTokenizerFast

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_cohere_modeling import (
    FlashCohereForCausalLM,
    CohereConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashCohere(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        use_medusa: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashCohere is only available on GPU")

        tokenizer = LlamaTokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
            use_fast=True,
            from_slow=False,
        )

        config = CohereConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.use_medusa = use_medusa

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashCohereForCausalLM(config, weights)

        torch.distributed.barrier(group=self.process_group)
        super(FlashCohere, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
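For illustration only, this is roughly how the new class would be constructed, mirroring what the dispatch in models/__init__.py does. It requires a CUDA GPU, the flash-attention kernels, and already-downloaded safetensors weights; the model id is a placeholder, not something named by this commit.

# Illustrative construction of the new model class (not part of the commit).
from text_generation_server.models.flash_cohere import FlashCohere

model = FlashCohere(
    model_id="CohereForAI/c4ai-command-r-v01",  # placeholder checkpoint id
    revision=None,
    quantize=None,
    use_medusa=None,
    dtype=None,                 # defaults to torch.bfloat16 on GPU
    trust_remote_code=False,
)
print(type(model.tokenizer))    # LlamaTokenizerFast, loaded from the checkpoint's tokenizer.json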
server/text_generation_server/models/flash_gemma.py  View file @ 1e9bcd9d
@@ -3,10 +3,10 @@ import torch.distributed

from opentelemetry import trace
from typing import Optional
+from transformers.models.gemma import GemmaTokenizerFast

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
-    GemmaTokenizerFast,
    FlashGemmaForCausalLM,
    GemmaConfig,
)