OpenDAS / text-generation-inference · Commits

Commit 853d4eb9 (Unverified)
Authored Jul 05, 2024 by Nicolas Patry
Parent: fb2f74e2

    Hotfixing after refactor.
Showing 2 changed files with 11 additions and 9 deletions:

    server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py  (+8, -8)
    server/text_generation_server/models/model.py  (+3, -1)
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py  (view file @ 853d4eb9)
@@ -355,7 +355,7 @@ class Block(nn.Module):
         self.ln_2 = FastLayerNorm.load(
             prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
         )
-        self.attn = FlashMQAttention(
+        self.self_attn = FlashMQAttention(
             prefix=f"{prefix}.attn",
             config=config,
             weights=weights,
@@ -378,7 +378,7 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
-        hidden_states = self.attn(
+        hidden_states = self.self_attn(
             hidden_states,
             cu_seqlen_prefill,
             kv_cache,
@@ -412,7 +412,7 @@ class FlashSantacoderModel(nn.Module):
             reduce=False,
         )
-        self.h = nn.ModuleList(
+        self.layers = nn.ModuleList(
             [
                 Block(
                     layer_id,
@@ -426,8 +426,8 @@ class FlashSantacoderModel(nn.Module):
             prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
         )
-        self.head_size = self.h[0].attn.head_size
-        self.num_heads = self.h[0].attn.num_heads
+        self.head_size = self.layers[0].self_attn.head_size
+        self.num_heads = self.layers[0].self_attn.num_heads
     def forward(
         self,
@@ -446,7 +446,7 @@ class FlashSantacoderModel(nn.Module):
             torch.distributed.all_reduce(hidden_states, group=self.process_group)
         residual = None
-        for i, layer in enumerate(self.h):
+        for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -467,7 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.transpose = config.architectures[0].startswith("GPT2")
-        self.transformer = FlashSantacoderModel(config, weights)
+        self.model = FlashSantacoderModel(config, weights)
         self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
@@ -486,7 +486,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(
+        hidden_states = self.model(
             input_ids,
             position_ids,
             cu_seqlen_prefill,
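Taken together, the changes in this file are a consistent attribute rename with no change to the checkpoint prefixes passed to load(): Block.attn becomes Block.self_attn, FlashSantacoderModel.h becomes FlashSantacoderModel.layers, and FlashSantacoderForCausalLM.transformer becomes FlashSantacoderForCausalLM.model. Below is a minimal sketch of the resulting attribute layout; the classes are simplified stand-ins (attention and block internals are elided), not the real implementations, and the num_layers default is made up for illustration.

import torch.nn as nn


class Block(nn.Module):
    """Stand-in block: only the attribute names matter for this sketch."""

    def __init__(self, layer_id: int):
        super().__init__()
        # Renamed from `self.attn`; the weight prefix "<prefix>.attn" is unchanged.
        self.self_attn = nn.Identity()


class FlashSantacoderModel(nn.Module):
    def __init__(self, num_layers: int = 2):
        super().__init__()
        # Renamed from `self.h`.
        self.layers = nn.ModuleList([Block(i) for i in range(num_layers)])


class FlashSantacoderForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        # Renamed from `self.transformer`.
        self.model = FlashSantacoderModel()


# Code that walks the module tree can now rely on the
# `model.layers[i].self_attn` path rather than `transformer.h[i].attn`.
lm = FlashSantacoderForCausalLM()
assert isinstance(lm.model.layers[0].self_attn, nn.Module)

The renamed path follows the model.layers[...].self_attn convention used by other models in the repository, which is presumably what the plumbing introduced by the refactor expects; that reading is inferred from the commit title, not stated in the diff itself.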
server/text_generation_server/models/model.py  (view file @ 853d4eb9)
@@ -60,7 +60,7 @@ class Model(ABC):
         self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
             LayerAdapterWeights
         )
-        self.target_to_layer = self.adapter_target_to_layer()
+        self.target_to_layer = None
         self.loaded_adapters = set()
         self.static_adapter_id = adapter_id
@@ -187,6 +187,8 @@ class Model(ABC):
         into model. Otherwise, the adapter weights are applied during the forward
         pass and stored separately from the base model parameters.
         """
+        if self.target_to_layer is None:
+            self.target_to_layer = self.adapter_target_to_layer()
         if adapter_index in self.loaded_adapters:
             # Adapter already loaded
             return
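The model.py change makes target_to_layer lazily built: __init__ now stores None, and the mapping is computed via adapter_target_to_layer() the first time the adapter-loading path needs it. Below is a minimal sketch of that lazy-initialization pattern; the method bodies are made-up placeholders and the load_adapter signature is simplified relative to the real method.

from typing import Dict, Optional, Tuple


class Model:
    def __init__(self, adapter_id: str = "") -> None:
        # Built lazily: adapter_target_to_layer() may depend on subclass
        # attributes that do not exist yet while this base __init__ runs.
        self.target_to_layer: Optional[Dict[str, Tuple]] = None
        self.loaded_adapters: set = set()
        self.static_adapter_id = adapter_id

    def adapter_target_to_layer(self) -> Dict[str, Tuple]:
        # Placeholder: the real implementation maps adapter target names
        # to the model's layers.
        return {}

    def load_adapter(self, adapter_index: int) -> None:
        # First call builds the mapping; later calls reuse the cached dict.
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            # Adapter already loaded
            return
        # ... register the adapter weights against self.target_to_layer here ...
        self.loaded_adapters.add(adapter_index)

Deferring the call also means a model that never loads an adapter never pays for building the mapping.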