OpenDAS / text-generation-inference / Commits / 3ea8259a

Commit 3ea8259a (unverified), authored Jun 27, 2024 by Nicolas Patry, committed by GitHub on Jun 27, 2024
Parent: 0e4ab6d3

Fixing gemma2. (#2135)

* Fixing gemma2.
* Adding new model.

Showing 7 changed files with 622 additions and 9 deletions (+622 −9)
Changed files:

  docs/source/supported_models.md (+1 −0)
  server/text_generation_server/models/__init__.py (+30 −0)
  server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py (+500 −0, new file)
  server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py (+0 −2)
  server/text_generation_server/models/flash_causal_lm.py (+11 −7)
  server/text_generation_server/models/flash_gemma2.py (+75 −0, new file)
  server/text_generation_server/models/globals.py (+5 −0)
docs/source/supported_models.md (+1 −0)

@@ -10,6 +10,7 @@ Text Generation Inference enables serving optimized models on specific hardware
 - [Llama](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
 - [Phi 3](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct)
 - [Gemma](https://huggingface.co/google/gemma-7b)
+- [Gemma2](https://huggingface.co/google/gemma2-9b)
 - [Cohere](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
 - [Dbrx](https://huggingface.co/databricks/dbrx-instruct)
 - [Mamba](https://huggingface.co/state-spaces/mamba-2.8b-slimpj)
server/text_generation_server/models/__init__.py (+30 −0)

@@ -68,6 +68,9 @@ try:
     from text_generation_server.models.flash_gemma import (
         FlashGemma,
     )
+    from text_generation_server.models.flash_gemma2 import (
+        FlashGemma2,
+    )
     from text_generation_server.models.pali_gemma import (
         PaliGemma,
     )

@@ -102,6 +105,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashQwen2)
     __all__.append(FlashStarcoder2)
     __all__.append(FlashGemma)
+    __all__.append(FlashGemma2)
     __all__.append(FlashCohere)
 
 MAMBA_AVAILABLE = True

@@ -143,6 +147,11 @@ class ModelType(enum.Enum):
         "name": "Gemma",
         "url": "https://huggingface.co/google/gemma-7b",
     }
+    GEMMA2 = {
+        "type": "gemma2",
+        "name": "Gemma2",
+        "url": "https://huggingface.co/google/gemma2-9b",
+    }
     COHERE = {
         "type": "cohere",
         "name": "Cohere",

@@ -630,6 +639,27 @@ def get_model(
                 dtype=dtype,
                 trust_remote_code=trust_remote_code,
             )
+    elif model_type == GEMMA2:
+        if FLASH_ATTENTION:
+            return FlashGemma2(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
+        elif sharded:
+            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
+        else:
+            return CausalLM(
+                model_id,
+                revision,
+                quantize=quantize,
+                speculator=speculator,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+            )
     if model_type == COHERE:
         if FLASH_ATTENTION:
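The hunks above follow the same registration pattern the file already uses for other architectures: a dict-valued ModelType entry, an __all__ export, and a get_model branch that prefers the flash-attention implementation and falls back to CausalLM when it is unavailable. Below is a minimal, self-contained sketch of that dispatch pattern only; the helper names (ModelTypeSketch, get_model_sketch, flash_available) are illustrative and not part of TGI's API.

import enum


class ModelTypeSketch(enum.Enum):
    # Mirrors the dict-valued entries added in the hunk above.
    GEMMA = {"type": "gemma", "name": "Gemma", "url": "https://huggingface.co/google/gemma-7b"}
    GEMMA2 = {"type": "gemma2", "name": "Gemma2", "url": "https://huggingface.co/google/gemma2-9b"}


def get_model_sketch(model_type: str, flash_available: bool, sharded: bool) -> str:
    # Schematic version of the new GEMMA2 branch in get_model().
    if model_type == ModelTypeSketch.GEMMA2.value["type"]:
        if flash_available:
            return "FlashGemma2"  # preferred flash-attention path
        elif sharded:
            raise NotImplementedError("Sharded Gemma2 requires flash attention")
        return "CausalLM"  # plain transformers fallback
    raise ValueError(f"Unsupported model_type: {model_type}")


print(get_model_sketch("gemma2", flash_available=True, sharded=False))   # FlashGemma2
print(get_model_sketch("gemma2", flash_available=False, sharded=False))  # CausalLM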
server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py (new file, +500 −0)

# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed

from torch import nn
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple

from text_generation_server.layers.attention import (
    paged_attention,
    attention,
    reshape_and_cache,
)
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
    FastRMSNorm,
)


class Gemma2Config(PretrainedConfig):
    def __init__(
        self,
        vocab_size=256128,
        hidden_size=3072,
        intermediate_size=24576,
        num_hidden_layers=28,
        num_attention_heads=16,
        num_key_value_heads=16,
        head_dim=256,
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=8192,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=True,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class Gemma2FastRMSNorm(FastRMSNorm):
    @classmethod
    def load(cls, prefix, weights, eps=1e-6):
        dtype = weights.dtype
        weights.dtype = torch.float32
        weight = weights.get_tensor(f"{prefix}.weight") + 1
        weights.dtype = dtype
        new = cls(weight, eps)
        new.dtype = dtype
        return new

    # perform the multiplication in full precision and downcast after
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states += residual
        residual = hidden_states
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = hidden_states * self.weight
        return hidden_states.to(self.dtype), residual


def load_attention(config, prefix, weights):
    if config.num_attention_heads != config.num_key_value_heads:
        return _load_gqa(config, prefix, weights)
    else:
        return TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
            dim=0,
            weights=weights,
            bias=False,
        )


def _load_gqa(config, prefix: str, weights):
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        quantize=config.quantize,
        dim=0,
    )

    if config.quantize not in ["gptq", "awq", "marlin"]:
        weight = weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.head_dim
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
        assert list(weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    return TensorParallelColumnLinear(
        get_linear(weight, bias=None, quantize=config.quantize)
    )


class FlashGemma2Attention(torch.nn.Module):
    def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_size = config.head_dim
        self.causal = causal
        if is_sliding:
            self.window_size = config.sliding_window
        else:
            self.window_size = -1

        self.rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=self.head_size,
            base=config.rope_theta,
            device=weights.device,
        )

        # self.softmax_scale = self.head_size**-0.5
        self.softmax_scale = config.query_pre_attn_scalar**-0.5

        if self.num_heads % weights.process_group.size() != 0:
            raise ValueError(
                f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
                f"and `num_shards`: {weights.process_group.size()}"
            )
        self.num_heads = self.num_heads // weights.process_group.size()
        self.num_key_value_heads = (
            config.num_key_value_heads // weights.process_group.size()
        )

        self.query_key_value = load_attention(config, prefix, weights)

        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )
        self.num_groups = self.num_heads // self.num_key_value_heads
        self.kv_head_mapping = torch.arange(
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def forward(
        self,
        hidden_states,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        qkv = self.query_key_value(hidden_states)
        query, kv = qkv.split(
            [
                self.head_size * self.num_heads,
                2 * self.head_size * self.num_key_value_heads,
            ],
            dim=1,
        )
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        reshape_and_cache(kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots)

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            attention(
                query,
                torch.select(kv, dim=1, index=0),
                torch.select(kv, dim=1, index=1),
                attn_output,
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
                causal=self.causal,
                window_size_left=self.window_size,
            )
        # Decode
        else:
            paged_attention(
                attn_output,
                query,
                kv_cache[0],
                kv_cache[1],
                self.kv_head_mapping,
                self.softmax_scale,
                block_tables,
                input_lengths,
                max_s,
            )

        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))


class Gemma2MLP(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        act = config.hidden_act
        self.act = (
            ACT2FN[act]
            if "gelu" not in act
            else lambda x: torch.nn.functional.gelu(
                x,
                approximate=(
                    "tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
                ),
            )
        )
        # Fuse gate and up proj
        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
            weights=weights,
            dim=0,
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.down_proj",
            weights=weights,
            bias=False,
        )
        self.intermediate_size = (
            config.intermediate_size // weights.process_group.size()
        )

    def forward(self, hidden_states):
        gate_up_states = self.gate_up_proj(hidden_states)
        gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])


class FlashGemma2Layer(nn.Module):
    def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool):
        super().__init__()
        self.self_attn = FlashGemma2Attention(
            prefix=f"{prefix}.self_attn",
            config=config,
            weights=weights,
            causal=causal,
            is_sliding=is_sliding,
        )
        self.mlp = Gemma2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)

        self.input_layernorm = Gemma2FastRMSNorm.load(
            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = Gemma2FastRMSNorm.load(
            prefix=f"{prefix}.post_attention_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.pre_feedforward_layernorm = Gemma2FastRMSNorm.load(
            prefix=f"{prefix}.pre_feedforward_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )
        self.post_feedforward_layernorm = Gemma2FastRMSNorm.load(
            prefix=f"{prefix}.post_feedforward_layernorm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

    def forward(
        self,
        hidden_states,
        residual,
        cos,
        sin,
        cu_seqlen_prefill,
        kv_cache,
        block_tables,
        slots,
        input_lengths,
        max_s,
    ):
        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)

        # Self Attention
        attn_output = self.self_attn(
            normed_hidden_states,
            cos,
            sin,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
        )

        # faster post attention rms norm
        normed_attn_res_output, _ = self.post_attention_layernorm(attn_output)
        normed_attn_res_output = normed_attn_res_output + res
        res = normed_attn_res_output

        pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
        mlp_output = self.mlp(pre_normed)
        post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)

        return post_hidden_states, normed_attn_res_output


class FlashGemma2Model(torch.nn.Module):
    def __init__(self, prefix, config, weights, causal: bool):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()
        self.layers = nn.ModuleList(
            [
                FlashGemma2Layer(
                    prefix=f"{prefix}.layers.{layer_id}",
                    config=config,
                    weights=weights,
                    causal=causal,
                    is_sliding=layer_id % 2 == 0,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = Gemma2FastRMSNorm.load(
            prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
        )

        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
    ) -> torch.Tensor:
        hidden_states = inputs_embeds

        # Get rotary cos and sin for this forward
        # Avoid to index in each layer
        cos, sin = self.layers[0].self_attn.rotary_emb.get_cos_sin(
            position_ids, max_s, hidden_states.dtype
        )

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states,
                residual,
                cos,
                sin,
                cu_seqlen_prefill,
                kv_cache[i],
                block_tables,
                slots,
                input_lengths,
                max_s,
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        return hidden_states


class FlashGemma2ForCausalLM(torch.nn.Module):
    def __init__(self, prefix, config, weights, causal: bool):
        super().__init__()

        embed_norm = config.hidden_size**0.5
        if not prefix:
            prefix = "model"
        else:
            prefix = f"{prefix}.model"

        self.embed_tokens = TensorParallelEmbedding(
            prefix=f"{prefix}.embed_tokens", weights=weights
        )
        self.embed_tokens.weight *= embed_norm

        self.model = FlashGemma2Model(
            prefix=prefix, config=config, weights=weights, causal=causal
        )
        self.lm_head = SpeculativeHead.load(
            prefix=(
                f"{prefix}.embed_tokens"
                if config.tie_word_embeddings
                else f"{prefix}.lm_head"
            ),
            config=config,
            weights=weights,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        cu_seqlen_prefill: Optional[torch.Tensor],
        kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
        block_tables: torch.Tensor,
        slots: torch.Tensor,
        input_lengths: torch.Tensor,
        max_s: int,
        prefill_cache_indices: Optional[torch.Tensor],
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        input_embeds = self.embed_tokens(input_ids)
        hidden_states = self.model(
            input_embeds,
            position_ids,
            cu_seqlen_prefill,
            kv_cache,
            block_tables,
            slots,
            input_lengths,
            max_s,
        )
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        logits, speculative_logits = self.lm_head(hidden_states)
        return logits, speculative_logits
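Two details in the new modeling file are easy to miss in the diff: Gemma2FastRMSNorm.load adds 1 to the stored weight (Gemma checkpoints keep a zero-centered scale, so the effective multiplier is 1 + weight), and the forward pass normalizes in float32 and only downcasts at the end. The function below is a standalone reference sketch of that numerics path, not the class used above; note also that FlashGemma2Layer alternates sliding-window and full attention via is_sliding=layer_id % 2 == 0.

import torch


def gemma2_rmsnorm_reference(x: torch.Tensor, stored_weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # The checkpoint stores (scale - 1); add 1 back, as Gemma2FastRMSNorm.load does.
    weight = stored_weight.to(torch.float32) + 1
    x_f32 = x.to(torch.float32)
    variance = x_f32.pow(2).mean(-1, keepdim=True)
    normed = x_f32 * torch.rsqrt(variance + eps) * weight
    return normed.to(x.dtype)  # downcast only after the full-precision multiply


x = torch.randn(2, 8, dtype=torch.bfloat16)
stored = torch.zeros(8)  # a stored weight of 0 means an effective scale of 1
print(gemma2_rmsnorm_reference(x, stored).dtype)  # torch.bfloat16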
server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py (+0 −2)

@@ -375,8 +375,6 @@ class FlashGemmaModel(torch.nn.Module):
             prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps
         )
 
-        self.gradient_checkpointing = False
-
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
         self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
server/text_generation_server/models/flash_causal_lm.py (+11 −7)

@@ -28,8 +28,12 @@ from text_generation_server.models.types import (
     GeneratedText,
 )
 from text_generation_server.pb import generate_pb2
-from text_generation_server.models.globals import MEM_POOL, CUDA_GRAPHS
-import text_generation_server.models.globals as tgi_globals
+from text_generation_server.models.globals import (
+    MEM_POOL,
+    CUDA_GRAPHS,
+    get_adapter_to_index,
+    MODEL_ID,
+)
 from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
 from text_generation_server.utils.dist import MEMORY_FRACTION
 from text_generation_server.utils.segments import SegmentConcatBuilder, find_segments

@@ -233,7 +237,8 @@ class FlashCausalLMBatch(Batch):
             stopping_criterias.append(stopping_criteria)
             top_n_tokens.append(r.top_n_tokens)
-            adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(r.adapter_id, 0)
+            ADAPTER_TO_INDEX = get_adapter_to_index()
+            adapter_index = ADAPTER_TO_INDEX.get(r.adapter_id, 0)
             adapter_indices_list.append(torch.full((input_length,), adapter_index))
             adapter_set.add(adapter_index)

@@ -499,9 +504,8 @@ class FlashCausalLMBatch(Batch):
             top_n_tokens.append(self.top_n_tokens[idx])
-            adapter_index = tgi_globals.ADAPTER_TO_INDEX.get(
-                self.requests[idx].adapter_id, 0
-            )
+            ADAPTER_TO_INDEX = get_adapter_to_index()
+            adapter_index = ADAPTER_TO_INDEX.get(self.requests[idx].adapter_id, 0)
             adapter_set.add(adapter_index)
 
             remaining_tokens = (

@@ -1017,7 +1021,7 @@ class FlashCausalLM(Model):
                 tunableop_filepath = os.path.join(
                     HUGGINGFACE_HUB_CACHE,
-                    f"tunableop_{tgi_globals.MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
+                    f"tunableop_{MODEL_ID.replace('/', '-')}_tp{self.world_size}_rank{self.rank}.csv",
                 )
                 logger.info(
server/text_generation_server/models/flash_gemma2.py (new file, +75 −0)

import torch
import torch.distributed

from opentelemetry import trace
from typing import Optional
from transformers import PretrainedConfig, AutoTokenizer

from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_gemma2_modeling import (
    FlashGemma2ForCausalLM,
)
from text_generation_server.utils import (
    initialize_torch_distributed,
    weight_files,
    Weights,
)

tracer = trace.get_tracer(__name__)


class FlashGemma2(FlashCausalLM):
    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        speculator: Optional[str] = None,
        dtype: Optional[torch.dtype] = None,
        trust_remote_code: bool = False,
    ):
        self.process_group, rank, world_size = initialize_torch_distributed()
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
            dtype = torch.bfloat16 if dtype is None else dtype
        else:
            raise NotImplementedError("FlashGemma2 is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = PretrainedConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )
        config.quantize = quantize
        config.speculator = speculator

        torch.distributed.barrier(group=self.process_group)

        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        weights = Weights(filenames, device, dtype, process_group=self.process_group)
        if config.quantize in ["gptq", "awq", "marlin"]:
            weights._set_gptq_params(model_id, revision)

        # TODO hardcoded
        prefix = ""
        model = FlashGemma2ForCausalLM(prefix, config, weights, causal=True)

        torch.distributed.barrier(group=self.process_group)
        super(FlashGemma2, self).__init__(
            model_id=model_id,
            model=model,
            tokenizer=tokenizer,
            num_layers=len(model.model.layers),
            num_kv_heads=model.model.num_key_value_heads,
            head_size=model.model.head_size,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )
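With the model class above registered, a Gemma2 checkpoint is served like any other TGI model and can be queried over the standard /generate route. A minimal client sketch, assuming a server is already running with a Gemma2 checkpoint loaded; the URL, port, and prompt are placeholders.

import requests

TGI_URL = "http://localhost:8080"  # placeholder address of a running TGI instance

resp = requests.post(
    f"{TGI_URL}/generate",
    json={
        "inputs": "Write one sentence about the Gemma2 architecture.",
        "parameters": {"max_new_tokens": 64},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["generated_text"])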
server/text_generation_server/models/globals.py (+5 −0)

@@ -44,3 +44,8 @@ ADAPTER_TO_INDEX: Dict[str, int] = None
 def set_adapter_to_index(adapter_to_index: Dict[str, int]):
     global ADAPTER_TO_INDEX
     ADAPTER_TO_INDEX = adapter_to_index
+
+
+def get_adapter_to_index():
+    global ADAPTER_TO_INDEX
+    return ADAPTER_TO_INDEX
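The new get_adapter_to_index() accessor exists because ADAPTER_TO_INDEX is only populated at startup through set_adapter_to_index(): a plain "from ... import ADAPTER_TO_INDEX" would snapshot the initial None at import time, whereas calling the accessor reads the value that is current at call time. A self-contained illustration of that binding behaviour; the fake_globals module is synthesized on the fly purely for the demo.

import sys
import types

# Build a stand-in "globals" module at runtime to mimic
# text_generation_server.models.globals (illustration only).
mod = types.ModuleType("fake_globals")
mod.ADAPTER_TO_INDEX = None
mod.get_adapter_to_index = lambda: mod.ADAPTER_TO_INDEX
sys.modules["fake_globals"] = mod

from fake_globals import ADAPTER_TO_INDEX, get_adapter_to_index

mod.ADAPTER_TO_INDEX = {"my-adapter": 3}  # set later, as set_adapter_to_index() does

print(ADAPTER_TO_INDEX)        # None -> the from-import captured the old value
print(get_adapter_to_index())  # {'my-adapter': 3} -> the accessor reads the live value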