OpenDAS / text-generation-inference / Commits

Commit c86f58d3 (unverified)
Authored Feb 21, 2024 by OlivierDehaene, committed via GitHub on Feb 21, 2024

feat: add support for Gemma (#1583)
Parent: fa8a8e05

Showing 8 changed files with 1338 additions and 0 deletions (+1338 −0)
integration-tests/conftest.py                                                               +3   −0
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json              +89   −0
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json   +89   −0
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json        +358   −0
integration-tests/models/test_flash_gemma.py                                               +61   −0
server/text_generation_server/models/__init__.py                                           +25   −0
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py              +609   −0
server/text_generation_server/models/flash_gemma.py                                       +104   −0
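With this change, a Gemma checkpoint is served like any other flash-attention model and queried over the usual HTTP API. The snippet below is a minimal sketch, not part of this commit: it assumes a server was started locally (for example with text-generation-launcher --model-id gg-hf/gemma-2b --num-shard 1, the checkpoint used by the integration tests) and that the text-generation Python client is installed.

# Minimal sketch (not part of this commit): query a locally served Gemma model.
# Assumes text-generation-launcher is already running on http://127.0.0.1:8080.
from text_generation import Client

client = Client("http://127.0.0.1:8080")

response = client.generate("Test request", max_new_tokens=10, decoder_input_details=True)
print(response.generated_text)            # e.g. " form\n\nThe test request form is used to request"
print(response.details.generated_tokens)  # 10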
integration-tests/conftest.py

@@ -40,6 +40,9 @@ class ResponseComparator(JSONSnapshotExtension):
        exclude=None,
        matcher=None,
    ):
        if isinstance(data, Response):
            data = data.dict()

        if isinstance(data, List):
            data = [d.dict() for d in data]
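This comparator change is needed because the new test_flash_gemma_load test snapshots a list of Response objects rather than a single one, so each element must be converted to a plain dict before serialization. A minimal sketch of the same normalization pattern in isolation follows; the Response class here is an illustrative stand-in for the client's pydantic model, not the real one.

from dataclasses import dataclass, asdict
from typing import Any


@dataclass
class Response:
    # Illustrative stand-in for the client's Response model.
    generated_text: str

    def dict(self) -> dict:
        return asdict(self)


def normalize(data: Any) -> Any:
    """Convert a Response, or a list of Responses, into plain dicts before snapshot comparison."""
    if isinstance(data, Response):
        return data.dict()
    if isinstance(data, list):
        return [d.dict() for d in data]
    return data


print(normalize(Response("hi")))                  # {'generated_text': 'hi'}
print(normalize([Response("a"), Response("b")]))  # [{'generated_text': 'a'}, {'generated_text': 'b'}]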
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json (new file, mode 100644)

{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 2, "logprob": null, "text": "<bos>" },
      { "id": 2015, "logprob": -10.0, "text": "Test" },
      { "id": 3853, "logprob": -10.875, "text": " request" }
    ],
    "seed": null,
    "tokens": [
      { "id": 1736, "logprob": -2.09375, "special": false, "text": " form" },
      { "id": 109, "logprob": -1.8671875, "special": false, "text": "\n\n" },
      { "id": 651, "logprob": -2.4375, "special": false, "text": "The" },
      { "id": 2121, "logprob": -1.8203125, "special": false, "text": " test" },
      { "id": 3853, "logprob": -0.23242188, "special": false, "text": " request" },
      { "id": 1736, "logprob": -0.08544922, "special": false, "text": " form" },
      { "id": 603, "logprob": -0.9375, "special": false, "text": " is" },
      { "id": 1671, "logprob": -1.671875, "special": false, "text": " used" },
      { "id": 577, "logprob": -0.40429688, "special": false, "text": " to" },
      { "id": 3853, "logprob": -1.1875, "special": false, "text": " request" }
    ],
    "top_tokens": null
  },
  "generated_text": " form\n\nThe test request form is used to request"
}
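For reference, the logprob fields in these snapshots are natural-log token probabilities; exponentiating gives the probability the model assigned to each sampled token. A quick check of the two most confident continuations above (values taken from the snapshot):

import math

# logprobs from the snapshot above: natural log of the token probability
for text, logprob in [(" form", -0.08544922), (" request", -0.23242188)]:
    print(f"{text!r}: p = {math.exp(logprob):.2f}")
# ' form': p = 0.92
# ' request': p = 0.79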
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_all_params.json (new file, mode 100644)

{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "length",
    "generated_tokens": 10,
    "prefill": [
      { "id": 2, "logprob": null, "text": "<bos>" },
      { "id": 2015, "logprob": -10.0, "text": "Test" },
      { "id": 3853, "logprob": -10.875, "text": " request" }
    ],
    "seed": 0,
    "tokens": [
      { "id": 7539, "logprob": -0.73046875, "special": false, "text": " forms" },
      { "id": 708, "logprob": 0.0, "special": false, "text": " are" },
      { "id": 671, "logprob": -1.703125, "special": false, "text": " an" },
      { "id": 8727, "logprob": 0.0, "special": false, "text": " essential" },
      { "id": 1702, "logprob": 0.0, "special": false, "text": " part" },
      { "id": 576, "logprob": 0.0, "special": false, "text": " of" },
      { "id": 573, "logprob": 0.0, "special": false, "text": " the" },
      { "id": 11859, "logprob": -1.6953125, "special": false, "text": " lab" },
      { "id": 2185, "logprob": -1.3125, "special": false, "text": " process" },
      { "id": 578, "logprob": -1.5, "special": false, "text": " and" }
    ],
    "top_tokens": null
  },
  "generated_text": "Test request forms are an essential part of the lab process and"
}
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma_load.json (new file, mode 100644)

[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 2, "logprob": null, "text": "<bos>" },
        { "id": 2015, "logprob": -10.0, "text": "Test" },
        { "id": 3853, "logprob": -10.875, "text": " request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 1736, "logprob": -2.09375, "special": false, "text": " form" },
        { "id": 109, "logprob": -1.9140625, "special": false, "text": "\n\n" },
        { "id": 651, "logprob": -2.453125, "special": false, "text": "The" },
        { "id": 2121, "logprob": -1.8984375, "special": false, "text": " test" },
        { "id": 3853, "logprob": -0.23535156, "special": false, "text": " request" },
        { "id": 1736, "logprob": -0.091308594, "special": false, "text": " form" },
        { "id": 603, "logprob": -0.96875, "special": false, "text": " is" },
        { "id": 1671, "logprob": -1.6484375, "special": false, "text": " used" },
        { "id": 577, "logprob": -0.43164062, "special": false, "text": " to" },
        { "id": 3853, "logprob": -1.2421875, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": " form\n\nThe test request form is used to request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 2, "logprob": null, "text": "<bos>" },
        { "id": 2015, "logprob": -10.0, "text": "Test" },
        { "id": 3853, "logprob": -10.875, "text": " request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 1736, "logprob": -2.09375, "special": false, "text": " form" },
        { "id": 109, "logprob": -1.9140625, "special": false, "text": "\n\n" },
        { "id": 651, "logprob": -2.453125, "special": false, "text": "The" },
        { "id": 2121, "logprob": -1.8984375, "special": false, "text": " test" },
        { "id": 3853, "logprob": -0.23535156, "special": false, "text": " request" },
        { "id": 1736, "logprob": -0.091308594, "special": false, "text": " form" },
        { "id": 603, "logprob": -0.96875, "special": false, "text": " is" },
        { "id": 1671, "logprob": -1.6484375, "special": false, "text": " used" },
        { "id": 577, "logprob": -0.43164062, "special": false, "text": " to" },
        { "id": 3853, "logprob": -1.2421875, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": " form\n\nThe test request form is used to request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 2, "logprob": null, "text": "<bos>" },
        { "id": 2015, "logprob": -10.0, "text": "Test" },
        { "id": 3853, "logprob": -10.875, "text": " request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 1736, "logprob": -2.09375, "special": false, "text": " form" },
        { "id": 109, "logprob": -1.9140625, "special": false, "text": "\n\n" },
        { "id": 651, "logprob": -2.453125, "special": false, "text": "The" },
        { "id": 2121, "logprob": -1.8984375, "special": false, "text": " test" },
        { "id": 3853, "logprob": -0.23535156, "special": false, "text": " request" },
        { "id": 1736, "logprob": -0.091308594, "special": false, "text": " form" },
        { "id": 603, "logprob": -0.96875, "special": false, "text": " is" },
        { "id": 1671, "logprob": -1.6484375, "special": false, "text": " used" },
        { "id": 577, "logprob": -0.43164062, "special": false, "text": " to" },
        { "id": 3853, "logprob": -1.2421875, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": " form\n\nThe test request form is used to request"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "length",
      "generated_tokens": 10,
      "prefill": [
        { "id": 2, "logprob": null, "text": "<bos>" },
        { "id": 2015, "logprob": -10.0, "text": "Test" },
        { "id": 3853, "logprob": -10.875, "text": " request" }
      ],
      "seed": null,
      "tokens": [
        { "id": 1736, "logprob": -2.09375, "special": false, "text": " form" },
        { "id": 109, "logprob": -1.9140625, "special": false, "text": "\n\n" },
        { "id": 651, "logprob": -2.453125, "special": false, "text": "The" },
        { "id": 2121, "logprob": -1.8984375, "special": false, "text": " test" },
        { "id": 3853, "logprob": -0.23535156, "special": false, "text": " request" },
        { "id": 1736, "logprob": -0.091308594, "special": false, "text": " form" },
        { "id": 603, "logprob": -0.96875, "special": false, "text": " is" },
        { "id": 1671, "logprob": -1.6484375, "special": false, "text": " used" },
        { "id": 577, "logprob": -0.43164062, "special": false, "text": " to" },
        { "id": 3853, "logprob": -1.2421875, "special": false, "text": " request" }
      ],
      "top_tokens": null
    },
    "generated_text": " form\n\nThe test request form is used to request"
  }
]
integration-tests/models/test_flash_gemma.py (new file, mode 100644)

import pytest


@pytest.fixture(scope="module")
def flash_gemma_handle(launcher):
    with launcher("gg-hf/gemma-2b", num_shard=1) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_gemma(flash_gemma_handle):
    await flash_gemma_handle.health(300)
    return flash_gemma_handle.client


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma(flash_gemma, response_snapshot):
    response = await flash_gemma.generate(
        "Test request", max_new_tokens=10, decoder_input_details=True
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_all_params(flash_gemma, response_snapshot):
    response = await flash_gemma.generate(
        "Test request",
        max_new_tokens=10,
        repetition_penalty=1.2,
        return_full_text=True,
        stop_sequences=["test"],
        temperature=0.5,
        top_p=0.9,
        top_k=10,
        truncate=5,
        typical_p=0.9,
        watermark=True,
        decoder_input_details=True,
        seed=0,
    )

    assert response.details.generated_tokens == 10
    assert response == response_snapshot


@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_gemma_load(flash_gemma, generate_load, response_snapshot):
    responses = await generate_load(flash_gemma, "Test request", max_new_tokens=10, n=4)

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
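The generate_load fixture comes from the shared conftest.py and is not shown in this diff; conceptually it fires n identical requests at the server concurrently and returns the responses, which is why test_flash_gemma_load can assert that all four generations are identical. The sketch below illustrates that pattern under those assumptions; the names are illustrative and this is not the fixture's actual implementation.

import asyncio


async def generate_load(client, prompt: str, max_new_tokens: int, n: int):
    # Fire n identical generation requests concurrently and collect the responses.
    futures = [
        client.generate(prompt, max_new_tokens=max_new_tokens) for _ in range(n)
    ]
    return await asyncio.gather(*futures)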
server/text_generation_server/models/__init__.py

@@ -52,6 +52,9 @@ try:
    from text_generation_server.models.flash_llama import (
        FlashLlama,
    )
    from text_generation_server.models.flash_gemma import (
        FlashGemma,
    )
    from text_generation_server.models.flash_santacoder import (
        FlashSantacoderSharded,
    )

@@ -312,6 +315,28 @@ def get_model(
            dtype=dtype,
            trust_remote_code=trust_remote_code,
        )

    if model_type == "gemma":
        if FLASH_ATTENTION:
            return FlashGemma(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
                use_medusa=use_medusa,
            )
        elif sharded:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Golden Gate"))
        else:
            return CausalLM(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )

    if model_type in ["RefinedWeb", "RefinedWebModel", "falcon"]:
        if sharded:
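get_model dispatches on the model_type field of the checkpoint's config.json, so Gemma checkpoints, which declare "model_type": "gemma", are routed to the new FlashGemma class when flash attention is available and fall back to the generic CausalLM path otherwise. A minimal sketch of inspecting that field yourself, assuming a hypothetical local checkpoint directory:

import json
from pathlib import Path

# Hypothetical local checkpoint directory; any Gemma checkpoint exposes the same field.
checkpoint = Path("/models/gemma-2b")

with open(checkpoint / "config.json") as f:
    config = json.load(f)

print(config["model_type"])  # "gemma" -> routed to FlashGemma by get_model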
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py (new file, mode 100644, +609 −0)

This diff is collapsed in the web view and is not reproduced here.
server/text_generation_server/models/flash_gemma.py (new file, mode 100644)

import torch
import torch.distributed

from opentelemetry import trace
from typing import Optional

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma_modeling import (
    GemmaTokenizerFast,
    FlashGemmaForCausalLM,
    GemmaConfig,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashGemma(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
        use_medusa: Optional[str] = None,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashGemma is only available on GPU")

        tokenizer = GemmaTokenizerFast.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
            use_fast=True,
            from_slow=False,
        )

        config = GemmaConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq"]:
            weights._set_gptq_params(model_id, revision)

        model = FlashGemmaForCausalLM(config, weights)
        if use_medusa:
            from text_generation_server.utils.medusa import MedusaModel
            from huggingface_hub import hf_hub_download
            import json
            import os
            from pathlib import Path

            is_local_model = (
                Path(use_medusa).exists() and Path(use_medusa).is_dir()
            ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None

            if not is_local_model:
                medusa_config = hf_hub_download(
                    use_medusa, revision=revision, filename="config.json"
                )
                medusa_head = hf_hub_download(
                    use_medusa, revision=revision, filename="medusa_lm_head.pt"
                )
            else:
                medusa_config = str(Path(use_medusa) / "config.json")
                medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt")

            with open(medusa_config, "r") as f:
                config = json.load(f)
            medusa_sf = medusa_head[: -len(".pt")] + ".safetensors"
            weights = Weights(
                [medusa_sf], device, dtype, process_group=self.process_group
            )
            lm_head = model.lm_head
            model.lm_head = MedusaModel(config, weights, lm_head)

        torch.distributed.barrier(group=self.process_group)
        super(FlashGemma, self).__init__(
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
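One detail worth noting in the Medusa branch above: the checkpoint is located as medusa_lm_head.pt, but the weights are then loaded from a .safetensors file with the same stem, so a converted copy has to exist alongside the .pt file. Below is a minimal conversion sketch, assuming the .pt file holds a plain tensor state dict; the file names are taken from the code above, and the conversion itself is not part of this commit.

import torch
from safetensors.torch import save_file

# Convert medusa_lm_head.pt to medusa_lm_head.safetensors so Weights() can load it.
state_dict = torch.load("medusa_lm_head.pt", map_location="cpu")
state_dict = {k: v.contiguous() for k, v in state_dict.items()}
save_file(state_dict, "medusa_lm_head.safetensors")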