OpenDAS / text-generation-inference

Commit 0b7df771 (unverified)
Authored Sep 26, 2024 by Alvaro Bartolome, committed via GitHub on Sep 26, 2024

Add LoRA adapters support for Gemma2 (#2567)

* Add LoRA adapters support for Gemma2
* Make `black` formatting happy
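The change wires Gemma2's attention and MLP projections through text-generation-inference's adapter-aware linear layers so that per-request LoRA weights can be applied on top of the frozen base model. As background, LoRA (Low-Rank Adaptation) keeps a pretrained projection W frozen and adds a trainable low-rank correction, y = W x + (alpha / r) * B A x. The snippet below is a minimal, standalone sketch of that idea in plain PyTorch; it is not the TGI implementation, and the `LoRALinear` name, rank, and alpha values are illustrative only.

```python
# Minimal, standalone sketch of the LoRA idea in plain PyTorch.
# NOT the text-generation-inference implementation; names and
# hyperparameters (rank, alpha) are illustrative only.
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base  # frozen pretrained projection
        self.base.weight.requires_grad_(False)
        # Low-rank update: W + (alpha / rank) * B @ A
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base output plus the scaled low-rank correction.
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling


if __name__ == "__main__":
    layer = LoRALinear(nn.Linear(32, 64), rank=4)
    print(layer(torch.randn(2, 32)).shape)  # torch.Size([2, 64])
```

In TGI, the adapter-aware wrappers used by this commit are `TensorParallelMultiAdapterLinear` (for fused column-parallel projections such as QKV and gate/up) and `TensorParallelAdapterRowLinear` (for row-parallel outputs such as `o_proj` and `down_proj`), as the diff below shows.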
Parent: 7efcb5e0

Changes: 1 file changed, 70 additions (+), 14 deletions (−)

server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py  +70 −14
server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py (view file @ 0b7df771)

@@ -38,6 +38,8 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     SpeculativeHead,
     get_linear,
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
@@ -161,7 +163,9 @@ def _load_gqa(config, prefix: str, weights):
 
 
 class FlashGemma2Attention(torch.nn.Module):
-    def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
+    def __init__(
+        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+    ):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim
@@ -192,14 +196,32 @@ class FlashGemma2Attention(torch.nn.Module):
         )
         self.softcap = config.attn_logit_softcapping
 
-        self.query_key_value = load_attention(config, prefix, weights)
+        query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = TensorParallelMultiAdapterLinear.load(
+            query_key_value,
+            layer_id,
+            ["q_proj", "k_proj", "v_proj"],
+            sizes=[
+                self.head_size * config.num_attention_heads,
+                self.head_size * config.num_key_value_heads,
+                self.head_size * config.num_key_value_heads,
+            ],
+            process_group=weights.process_group,
+        )
 
-        self.o_proj = TensorParallelRowLinear.load(
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            layer_id,
+            "o_proj",
+            process_group=weights.process_group,
+        )
+
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -216,8 +238,9 @@ class FlashGemma2Attention(torch.nn.Module):
         slots,
         seqlen,
         max_s,
+        adapter_data,
     ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@@ -260,11 +283,13 @@ class FlashGemma2Attention(torch.nn.Module):
             softcap=self.softcap,
         )
 
-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )
 
 
 class Gemma2MLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_id):
         super().__init__()
         act = config.hidden_activation
         self.act = (
@@ -278,40 +303,65 @@ class Gemma2MLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
             bias=False,
         )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            layer_id,
+            ["gate_proj", "up_proj"],
+            sizes=[
+                config.intermediate_size,
+                config.intermediate_size,
+            ],
+            process_group=weights.process_group,
+        )
+
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=False,
         )
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            layer_id,
+            "down_proj",
+            process_group=weights.process_group,
+        )
+
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
 
-    def forward(self, hidden_states):
-        gate_up_states = self.gate_up_proj(hidden_states)
+    def forward(self, hidden_states, adapter_data):
+        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+        )
 
 
 class FlashGemma2Layer(nn.Module):
-    def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
+    def __init__(
+        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+    ):
         super().__init__()
         self.self_attn = FlashGemma2Attention(
             prefix=f"{prefix}.self_attn",
             config=config,
             weights=weights,
+            layer_id=layer_id,
             causal=causal,
             is_sliding=is_sliding,
         )
-        self.mlp = Gemma2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.mlp = Gemma2MLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
+        )
 
         self.input_layernorm = Gemma2FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -344,6 +394,7 @@ class FlashGemma2Layer(nn.Module):
         slots,
         seqlen,
         max_s,
+        adapter_data,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
 
@@ -358,6 +409,7 @@ class FlashGemma2Layer(nn.Module):
             slots,
             seqlen,
             max_s,
+            adapter_data,
         )
 
         # faster post attention rms norm
@@ -366,7 +418,7 @@ class FlashGemma2Layer(nn.Module):
         res = normed_attn_res_output
 
         pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
-        mlp_output = self.mlp(pre_normed)
+        mlp_output = self.mlp(pre_normed, adapter_data)
         post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
 
         return post_hidden_states, normed_attn_res_output
@@ -385,6 +437,7 @@ class FlashGemma2Model(torch.nn.Module):
                 prefix=f"{prefix}.layers.{layer_id}",
                 config=config,
                 weights=weights,
+                layer_id=layer_id,
                 causal=causal,
                 is_sliding=layer_id % 2 == 0,
             )
@@ -409,6 +462,7 @@ class FlashGemma2Model(torch.nn.Module):
         slots: torch.Tensor,
         seqlen: Seqlen,
         max_s: int,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
 
@@ -431,6 +485,7 @@ class FlashGemma2Model(torch.nn.Module):
                 slots,
                 seqlen,
                 max_s,
+                adapter_data,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -492,6 +547,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
             slots,
             seqlen,
             max_s,
+            adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
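Taken together, the diff applies one pattern per projection: load the base tensor-parallel linear as before, wrap it in an adapter-aware layer keyed by `layer_id` and projection name, and thread `adapter_data` through every `forward` call so the wrapper can add the active adapter's low-rank update. The sketch below mirrors that wiring for the MLP block in plain PyTorch; `AdapterLinear` and `MLPBlock` are hypothetical stand-ins rather than TGI classes, and the real wrappers additionally take a `process_group` and handle tensor-parallel sharding and per-request adapter batching.

```python
# Illustrative analogue of the wiring pattern in this commit, in plain PyTorch.
# AdapterLinear / MLPBlock are hypothetical stand-ins, NOT TGI classes.
from typing import Optional

import torch
import torch.nn as nn


class AdapterLinear(nn.Module):
    """Wraps a base projection and applies an optional per-layer LoRA delta."""

    def __init__(self, base: nn.Linear, layer_id: int, name: str):
        super().__init__()
        self.base, self.layer_id, self.name = base, layer_id, name

    def forward(self, x: torch.Tensor, adapter_data: Optional[dict]) -> torch.Tensor:
        out = self.base(x)
        if adapter_data is not None and (self.layer_id, self.name) in adapter_data:
            # lora_a: (in_features, r), lora_b: (r, out_features)
            lora_a, lora_b, scaling = adapter_data[(self.layer_id, self.name)]
            out = out + (x @ lora_a @ lora_b) * scaling
        return out


class MLPBlock(nn.Module):
    """Mirrors Gemma2MLP after the change: adapter wrappers + adapter_data arg."""

    def __init__(self, hidden: int, intermediate: int, layer_id: int):
        super().__init__()
        self.gate_up_proj = AdapterLinear(
            nn.Linear(hidden, 2 * intermediate, bias=False), layer_id, "gate_up_proj"
        )
        self.down_proj = AdapterLinear(
            nn.Linear(intermediate, hidden, bias=False), layer_id, "down_proj"
        )
        self.act = nn.GELU(approximate="tanh")
        self.intermediate_size = intermediate

    def forward(self, hidden_states, adapter_data=None):
        gate_up = self.gate_up_proj(hidden_states, adapter_data)
        gate_up = gate_up.view(-1, 2, self.intermediate_size)
        return self.down_proj(self.act(gate_up[:, 0]) * gate_up[:, 1], adapter_data)


if __name__ == "__main__":
    mlp = MLPBlock(hidden=32, intermediate=64, layer_id=0)
    print(mlp(torch.randn(4, 32), adapter_data=None).shape)  # torch.Size([4, 32])
```

In the sketch, passing `adapter_data=None` falls back to the plain base projections, which matches the expected behaviour when no LoRA adapter is requested.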