OpenDAS / text-generation-inference

Unverified commit c86f58d3, authored Feb 21, 2024 by OlivierDehaene, committed by GitHub on Feb 21, 2024

feat: add support for Gemma (#1583)

parent fa8a8e05

Showing 8 changed files with 1338 additions and 0 deletions (+1338 -0)
integration-tests/conftest.py                                                                +3    -0
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json               +89   -0
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json    +89   -0
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json          +358  -0
integration-tests/models/test_flash_gemma.py                                                +61   -0
server/text_generation_server/models/__init__.py                                            +25   -0
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py                +609  -0
server/text_generation_server/models/flash_gemma.py                                         +104  -0
integration-tests/conftest.py  (view file @ c86f58d3)

...
@@ -40,6 +40,9 @@ class ResponseComparator(JSONSnapshotExtension):
        exclude=None,
        matcher=None,
    ):
        if isinstance(data, Response):
            data = data.dict()

        if isinstance(data, List):
            data = [d.dict() for d in data]
...
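Note: the List branch lets the snapshot comparator normalize a list of responses (as returned by the new load test below) the same way it normalizes a single response. A minimal standalone sketch of that normalization idea, not part of this commit, using a hypothetical dataclass instead of the real client Response type:

from dataclasses import dataclass, asdict
from typing import List, Union

@dataclass
class FakeResponse:  # hypothetical stand-in for the text-generation client Response
    generated_text: str

def normalize(data: Union[FakeResponse, List[FakeResponse]]):
    # Single responses and lists of responses both become plain dicts
    # before being compared against the stored JSON snapshot.
    if isinstance(data, FakeResponse):
        return asdict(data)
    if isinstance(data, list):
        return [asdict(d) for d in data]
    return data

print(normalize(FakeResponse("hi")))                      # {'generated_text': 'hi'}
print(normalize([FakeResponse("a"), FakeResponse("b")]))  # [{'generated_text': 'a'}, ...]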
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json  (new file, mode 100644, view file @ c86f58d3)

{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 2, "logprob": null, "text": "<bos>" },
      { "id": 2015, "logprob": -10.0, "text": "Test" },
      { "id": 3853, "logprob": -10.875, "text": " request" }
    ],
    "seed": null,
    "tokens": [
      { "id": 1736, "logprob": -2.09375, "special": false, "text": " form" },
      { "id": 109, "logprob": -1.8671875, "special": false, "text": "\n\n" },
      { "id": 651, "logprob": -2.4375, "special": false, "text": "The" },
      { "id": 2121, "logprob": -1.8203125, "special": false, "text": " test" },
      { "id": 3853, "logprob": -0.23242188, "special": false, "text": " request" },
      { "id": 1736, "logprob": -0.08544922, "special": false, "text": " form" },
      { "id": 603, "logprob": -0.9375, "special": false, "text": " is" },
      { "id": 1671, "logprob": -1.671875, "special": false, "text": " used" },
      { "id": 577, "logprob": -0.40429688, "special": false, "text": " to" },
      { "id": 3853, "logprob": -1.1875, "special": false, "text": " request" }
    ],
    "top_tokens": null
  },
  "generated_text": " form\n\nThe test request form is used to request"
}
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json  (new file, mode 100644, view file @ c86f58d3)

{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 2, "logprob": null, "text": "<bos>" },
      { "id": 2015, "logprob": -10.0, "text": "Test" },
      { "id": 3853, "logprob": -10.875, "text": " request" }
    ],
    "seed": 0,
    "tokens": [
      { "id": 7539, "logprob": -0.73046875, "special": false, "text": " forms" },
      { "id": 708, "logprob": 0.0, "special": false, "text": " are" },
      { "id": 671, "logprob": -1.703125, "special": false, "text": " an" },
      { "id": 8727, "logprob": 0.0, "special": false, "text": " essential" },
      { "id": 1702, "logprob": 0.0, "special": false, "text": " part" },
      { "id": 576, "logprob": 0.0, "special": false, "text": " of" },
      { "id": 573, "logprob": 0.0, "special": false, "text": " the" },
      { "id": 11859, "logprob": -1.6953125, "special": false, "text": " lab" },
      { "id": 2185, "logprob": -1.3125, "special": false, "text": " process" },
      { "id": 578, "logprob": -1.5, "special": false, "text": " and" }
    ],
    "top_tokens": null
  },
  "generated_text": "Test request forms are an essential part of the lab process and"
}
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json  (new file, mode 100644, view file @ c86f58d3)

[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 2, "logprob": null, "text": "<bos>" },
        { "id": 2015, "logprob": -10.0, "text": "Test" },
        { "id": 3853, "logprob": -10.875, "text": " request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 1736, "logprob": -2.09375, "special": false, "text": " form" },
        { "id": 109, "logprob": -1.9140625, "special": false, "text": "\n\n" },
        { "id": 651, "logprob": -2.453125, "special": false, "text": "The" },
        { "id": 2121, "logprob": -1.8984375, "special": false, "text": " test" },
        { "id": 3853, "logprob": -0.23535156, "special": false, "text": " request" },
        { "id": 1736, "logprob": -0.091308594, "special": false, "text": " form" },
        { "id": 603, "logprob": -0.96875, "special": false, "text": " is" },
        { "id": 1671, "logprob": -1.6484375, "special": false, "text": " used" },
        { "id": 577, "logprob": -0.43164062, "special": false, "text": " to" },
        { "id": 3853, "logprob": -1.2421875, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": " form\n\nThe test request form is used to request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 2, "logprob": null, "text": "<bos>" },
        { "id": 2015, "logprob": -10.0, "text": "Test" },
        { "id": 3853, "logprob": -10.875, "text": " request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 1736, "logprob": -2.09375, "special": false, "text": " form" },
        { "id": 109, "logprob": -1.9140625, "special": false, "text": "\n\n" },
        { "id": 651, "logprob": -2.453125, "special": false, "text": "The" },
        { "id": 2121, "logprob": -1.8984375, "special": false, "text": " test" },
        { "id": 3853, "logprob": -0.23535156, "special": false, "text": " request" },
        { "id": 1736, "logprob": -0.091308594, "special": false, "text": " form" },
        { "id": 603, "logprob": -0.96875, "special": false, "text": " is" },
        { "id": 1671, "logprob": -1.6484375, "special": false, "text": " used" },
        { "id": 577, "logprob": -0.43164062, "special": false, "text": " to" },
        { "id": 3853, "logprob": -1.2421875, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": " form\n\nThe test request form is used to request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 2, "logprob": null, "text": "<bos>" },
        { "id": 2015, "logprob": -10.0, "text": "Test" },
        { "id": 3853, "logprob": -10.875, "text": " request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 1736, "logprob": -2.09375, "special": false, "text": " form" },
        { "id": 109, "logprob": -1.9140625, "special": false, "text": "\n\n" },
        { "id": 651, "logprob": -2.453125, "special": false, "text": "The" },
        { "id": 2121, "logprob": -1.8984375, "special": false, "text": " test" },
        { "id": 3853, "logprob": -0.23535156, "special": false, "text": " request" },
        { "id": 1736, "logprob": -0.091308594, "special": false, "text": " form" },
        { "id": 603, "logprob": -0.96875, "special": false, "text": " is" },
        { "id": 1671, "logprob": -1.6484375, "special": false, "text": " used" },
        { "id": 577, "logprob": -0.43164062, "special": false, "text": " to" },
        { "id": 3853, "logprob": -1.2421875, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": " form\n\nThe test request form is used to request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 2, "logprob": null, "text": "<bos>" },
        { "id": 2015, "logprob": -10.0, "text": "Test" },
        { "id": 3853, "logprob": -10.875, "text": " request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 1736, "logprob": -2.09375, "special": false, "text": " form" },
        { "id": 109, "logprob": -1.9140625, "special": false, "text": "\n\n" },
        { "id": 651, "logprob": -2.453125, "special": false, "text": "The" },
        { "id": 2121, "logprob": -1.8984375, "special": false, "text": " test" },
        { "id": 3853, "logprob": -0.23535156, "special": false, "text": " request" },
        { "id": 1736, "logprob": -0.091308594, "special": false, "text": " form" },
        { "id": 603, "logprob": -0.96875, "special": false, "text": " is" },
        { "id": 1671, "logprob": -1.6484375, "special": false, "text": " used" },
        { "id": 577, "logprob": -0.43164062, "special": false, "text": " to" },
        { "id": 3853, "logprob": -1.2421875, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": " form\n\nThe test request form is used to request"
  }
]
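Note: all four entries in this load snapshot are identical, which is what the load test below asserts (every concurrent request produces the same text). A small sketch, not part of this commit, that re-checks the stored snapshot directly:

import json

with open(
    "integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json"
) as f:
    snapshot = json.load(f)

assert len(snapshot) == 4
# One unique generated_text across the four concurrent responses.
assert len({entry["generated_text"] for entry in snapshot}) == 1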
integration-tests/models/test_flash_gemma.py  (new file, mode 100644, view file @ c86f58d3)

import pytest


@pytest.fixture(scope="module")
def flash_gemma_handle(launcher):
    with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gemma(flash_gemma_handle):
    await flash_gemma_handle.health(300)
    return flash_gemma_handle.client


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot):
    response = await flash_gemma.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
    response = await flash_gemma.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
    responses = await generate_load(flash_gemma, "Test request", max_new_tokens=10, n=4)

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
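Note: outside the pytest harness, roughly the same request can be issued with the text-generation Python client against an already running server. A hedged sketch, not part of this commit; the URL is an assumption and a Gemma server must already be listening there:

import asyncio

from text_generation import AsyncClient  # pip install text-generation


async def main():
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )
    print(response.details.generated_tokens)  # expected: 10
    print(response.generated_text)


asyncio.run(main())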
server/text_generation_server/models/__init__.py  (view file @ c86f58d3)

...
@@ -52,6 +52,9 @@ try:
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )
...
@@ -312,6 +315,28 @@ def get_model(
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(
                FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate")
            )
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
...
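Note: get_model dispatches on the model_type field of the checkpoint's config.json, so this hunk is what makes a "gemma" checkpoint route to FlashGemma when flash attention is available. A minimal sketch of that lookup, not part of this commit; the repo id is an assumption borrowed from the tests above and may require gated access:

import json

from huggingface_hub import hf_hub_download

# Hypothetical repo id; the integration tests use "gg-hf/gemma-2b".
config_path = hf_hub_download("gg-hf/gemma-2b", filename="config.json")
with open(config_path) as f:
    model_type = json.load(f)["model_type"]

# With this commit, model_type == "gemma" returns FlashGemma when
# FLASH_ATTENTION is available and falls back to CausalLM otherwise.
print(model_type)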
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py  (new file, mode 100644, view file @ c86f58d3)

# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed

import os
from shutil import copyfile
from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple
from tokenizers import processors
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import logging

from text_generation_server.utils import paged_attention, flash_attn
from text_generation_server.utils.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    PositionRotaryEmbedding,
    TensorParallelHead,
    get_linear,
    FastRMSNorm,
)

GemmaTokenizer = None

logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {
    "vocab_file": "tokenizer.model",
    "tokenizer_file": "tokenizer.json",
}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
    },
    "tokenizer_file": {
        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
    },
}
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# fmt: off
DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \
answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure \
that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \
correct. If you don't know the answer to a question, please don't share false information."""
# fmt: on


class GemmaTokenizerFast(PreTrainedTokenizerFast):
    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    slow_tokenizer_class = GemmaTokenizer
    padding_side = "left"
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file=None,
        tokenizer_file=None,
        clean_up_tokenization_spaces=False,
        unk_token="<unk>",
        bos_token="<bos>",
        eos_token="<eos>",
        pad_token="<pad>",
        add_bos_token=True,
        add_eos_token=False,
        use_default_system_prompt=False,
        **kwargs,
    ):
        super().__init__(
            vocab_file=vocab_file,
            tokenizer_file=tokenizer_file,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            add_bos_token=add_bos_token,
            add_eos_token=add_eos_token,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )
        self._add_bos_token = add_bos_token
        self._add_eos_token = add_eos_token
        self.update_post_processor()
        self.use_default_system_prompt = use_default_system_prompt
        self.vocab_file = vocab_file

    @property
    def can_save_slow_tokenizer(self) -> bool:
        return os.path.isfile(self.vocab_file) if self.vocab_file else False

    def update_post_processor(self):
        """
        Updates the underlying post processor with the current `bos_token` and `eos_token`.
        """
        bos = self.bos_token
        bos_token_id = self.bos_token_id
        if bos is None and self.add_bos_token:
            raise ValueError("add_bos_token = True but bos_token = None")

        eos = self.eos_token
        eos_token_id = self.eos_token_id
        if eos is None and self.add_eos_token:
            raise ValueError("add_eos_token = True but eos_token = None")

        single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
        pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"

        special_tokens = []
        if self.add_bos_token:
            special_tokens.append((bos, bos_token_id))
        if self.add_eos_token:
            special_tokens.append((eos, eos_token_id))
        self._tokenizer.post_processor = processors.TemplateProcessing(
            single=single, pair=pair, special_tokens=special_tokens
        )

    @property
    def add_eos_token(self):
        return self._add_eos_token

    @property
    def add_bos_token(self):
        return self._add_bos_token

    @add_eos_token.setter
    def add_eos_token(self, value):
        self._add_eos_token = value
        self.update_post_processor()

    @add_bos_token.setter
    def add_bos_token(self, value):
        self._add_bos_token = value
        self.update_post_processor()

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        if not self.can_save_slow_tokenizer:
            raise ValueError(
                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
                "tokenizer."
            )

        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"],
        )

        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
            copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)

    @property
    def default_chat_template(self):
        raise NotImplementedError

    # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        bos_token_id = [self.bos_token_id] if self.add_bos_token else []
        eos_token_id = [self.eos_token_id] if self.add_eos_token else []

        output = bos_token_id + token_ids_0 + eos_token_id

        if token_ids_1 is not None:
            output = output + bos_token_id + token_ids_1 + eos_token_id

        return output


class GemmaConfig(PretrainedConfig):
    def __init__(
        self,
        vocab_size=256128,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class GemmaFastRMSNorm(FastRMSNorm):
    @classmethod
    def load(cls, prefix, weights, eps=1e-6):
        weight = weights.get_tensor(f"{prefix}.weight") + 1
        return cls(weight, eps)


def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        quantize=config.quantize,
        dim=0,
    )

    if config.quantize not in ["gptq", "awq"]:
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.head_dim
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        assert list(weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    return TensorParallelColumnLinear(
        get_linear(weight, bias=None, quantize=config.quantize)
    )


class FlashGemmaAttention(torch.nn.Module):
    def __init__(
        self,
        prefix: str,
        config,
        weights,
    ):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_size = config.head_dim

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
            device=weights.device,
        )

        self.softmax_scale = self.head_size**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        paged_attention.reshape_and_cache(
            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            flash_attn.attention(
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
                attn_output,
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
            )
        # Decode
        else:
            paged_attention.attention(
                attn_output,
                query,
                kv_cache[0],
                kv_cache[1],
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                input_lengths,
                max_s,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


class GemmaMLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])


class FlashGemmaLayer(nn.Module):
    def __init__(self, layer_id, config, weights):
        super().__init__()
        prefix = f"model.layers.{layer_id}"
        self.self_attn = FlashGemmaAttention(
            prefix=f"{prefix}.self_attn", config=config, weights=weights
        )
        self.mlp = GemmaMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = GemmaFastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = GemmaFastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
        )

        # faster post attention rms norm
        normed_attn_res_output, attn_res = self.post_attention_layernorm(
            attn_output, res
        )

        mlp_output = self.mlp(normed_attn_res_output)

        return mlp_output, attn_res


class FlashGemmaModel(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        embed_norm = config.hidden_size**0.5
        self.embed_tokens = TensorParallelEmbedding(
            prefix="model.embed_tokens", weights=weights
        )
        self.embed_tokens.weight *= embed_norm

        self.layers = nn.ModuleList(
            [
                FlashGemmaLayer(
                    layer_id,
                    config,
                    weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = GemmaFastRMSNorm.load(
            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.gradient_checkpointing = False

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                input_lengths,
                max_s,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashGemmaForCausalLM(torch.nn.Module):
    def __init__(self, config, weights):
        super().__init__()

        self.model = FlashGemmaModel(config, weights)
        self.lm_head = TensorParallelHead.load(
            config,
            prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head",
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        lm_head_indices: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits = self.lm_head(hidden_states)
        return logits
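Note: two Gemma-specific details in this file are folded into the weights at load time instead of being applied on every forward pass: the RMSNorm scale is stored as w with the layer defined as normalize(x) * (1 + w), hence the "+ 1" in GemmaFastRMSNorm.load, and the token embeddings are multiplied by sqrt(hidden_size), hence "self.embed_tokens.weight *= embed_norm". A small self-contained sketch of the same arithmetic in plain PyTorch, not part of this commit:

import torch

hidden_size, eps = 3072, 1e-6
x = torch.randn(4, hidden_size)
w = torch.zeros(hidden_size)  # a freshly initialized Gemma RMSNorm weight

# Reference Gemma RMSNorm: normalize(x) * (1 + w)
rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
reference = rms * (1 + w)

# What the commit does: add 1 to the stored weight once at load time,
# then run a standard RMSNorm kernel with the fused weight.
fused = rms * (w + 1)
assert torch.allclose(reference, fused)

# Embedding scaling: instead of multiplying activations by sqrt(hidden_size)
# on every lookup, the scale is folded into the embedding matrix once.
embedding = torch.randn(8, hidden_size)  # tiny stand-in vocabulary
token_ids = torch.tensor([3, 5])
scaled_lookup = embedding[token_ids] * hidden_size**0.5
embedding *= hidden_size**0.5
assert torch.allclose(embedding[token_ids], scaled_lookup)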
server/text_generation_server/models/flash_gemma.py  (new file, mode 100644, view file @ c86f58d3)

import torch
import torch.distributed

from opentelemetry import trace
from typing import Optional

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
    GemmaTokenizerFast,
    FlashGemmaForCausalLM,
    GemmaConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashGemma(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        use_medusa: Optional[str] = None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashGemma is only available on GPU")

        tokenizer = GemmaTokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
            use_fast=True,
            from_slow=False,
        )

        config = GemmaConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashGemmaForCausalLM(config, weights)
        if use_medusa:
            from text_generation_server.utils.medusa import MedusaModel
            from huggingface_hub import hf_hub_download
            import json
            import os
            from pathlib import Path

            is_local_model = (
                Path(use_medusa).exists() and Path(use_medusa).is_dir()
            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

            if not is_local_model:
                medusa_config = hf_hub_download(
                    use_medusa, revision=revision, filename="config.json"
                )
                medusa_head = hf_hub_download(
                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
                )
            else:
                medusa_config = str(Path(use_medusa) / "config.json")
                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")

            with open(medusa_config, "r") as f:
                config = json.load(f)
            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
            weights = Weights(
                [medusa_sf], device, dtype, process_group=self.process_group
            )
            lm_head = model.lm_head
            model.lm_head = MedusaModel(config, weights, lm_head)

        torch.distributed.barrier(group=self.process_group)
        super(FlashGemma, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
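Note: in normal operation this class is built by get_model (see the __init__.py hunk above) inside the gRPC server process. A hedged sketch of constructing it directly, not part of this commit; it assumes a CUDA GPU, the environment needed by initialize_torch_distributed() for a single-process group, and access to a Gemma checkpoint with .safetensors weights (the tests above use "gg-hf/gemma-2b", which may be gated):

import torch

from text_generation_server.models.flash_gemma import FlashGemma

model = FlashGemma(
    "gg-hf/gemma-2b",        # assumption: same repo id as the integration tests
    revision=None,
    quantize=None,
    dtype=torch.bfloat16,    # also the default chosen above when dtype is None
    trust_remote_code=False,
    use_medusa=None,
)
# The resulting FlashCausalLM is driven by the server's prefill/decode loop
# rather than being called directly here.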