text-generation-inference (OpenDAS) · commit 853d4eb9 (unverified)

Authored Jul 05, 2024 by Nicolas Patry

Hotfixing after refactor.

Parent: fb2f74e2
Showing 2 changed files with 11 additions and 9 deletions:

  server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py (+8 -8)
  server/text_generation_server/models/model.py (+3 -1)
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -355,7 +355,7 @@ class Block(nn.Module):
         self.ln_2 = FastLayerNorm.load(
             prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
         )
-        self.attn = FlashMQAttention(
+        self.self_attn = FlashMQAttention(
             prefix=f"{prefix}.attn",
             config=config,
             weights=weights,
@@ -378,7 +378,7 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
-        hidden_states = self.attn(
+        hidden_states = self.self_attn(
             hidden_states,
             cu_seqlen_prefill,
             kv_cache,
@@ -412,7 +412,7 @@ class FlashSantacoderModel(nn.Module):
                 reduce=False,
             )

-        self.h = nn.ModuleList(
+        self.layers = nn.ModuleList(
             [
                 Block(
                     layer_id,
@@ -426,8 +426,8 @@ class FlashSantacoderModel(nn.Module):
             prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
         )

-        self.head_size = self.h[0].attn.head_size
-        self.num_heads = self.h[0].attn.num_heads
+        self.head_size = self.layers[0].self_attn.head_size
+        self.num_heads = self.layers[0].self_attn.num_heads

     def forward(
         self,
@@ -446,7 +446,7 @@ class FlashSantacoderModel(nn.Module):
             torch.distributed.all_reduce(hidden_states, group=self.process_group)

         residual = None
-        for i, layer in enumerate(self.h):
+        for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -467,7 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.transpose = config.architectures[0].startswith("GPT2")
-        self.transformer = FlashSantacoderModel(config, weights)
+        self.model = FlashSantacoderModel(config, weights)
         self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
@@ -486,7 +486,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(
+        hidden_states = self.model(
             input_ids,
             position_ids,
             cu_seqlen_prefill,
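
Aside: taken together, the renames in this file (attn → self_attn, h → layers, transformer → model) line the attribute names up with the dotted module paths used elsewhere, presumably so downstream code that looks modules up by path finds them. A minimal illustrative sketch (the class names Attn/Block/Inner/ForCausalLM are stand-ins, not code from this commit): torch derives named_modules() paths from attribute names, so a lookup like "model.layers.0.self_attn" only resolves when every model uses the same names.

# Illustrative sketch only; class names are hypothetical stand-ins.
# torch builds named_modules() paths from attribute names, so
# dotted-path lookups depend on naming consistency across models.
import torch.nn as nn

class Attn(nn.Module):
    pass

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.self_attn = Attn()                  # was `self.attn`

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([Block()])   # was `self.h`

class ForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Inner()                     # was `self.transformer`

print([name for name, _ in ForCausalLM().named_modules()])
# ['', 'model', 'model.layers', 'model.layers.0', 'model.layers.0.self_attn']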
server/text_generation_server/models/model.py
@@ -60,7 +60,7 @@ class Model(ABC):
         self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
             LayerAdapterWeights
         )
-        self.target_to_layer = self.adapter_target_to_layer()
+        self.target_to_layer = None
         self.loaded_adapters = set()
         self.static_adapter_id = adapter_id
@@ -187,6 +187,8 @@ class Model(ABC):
         into model. Otherwise, the adapter weights are applied during the forward
         pass and stored separately from the base model parameters.
         """
+        if self.target_to_layer is None:
+            self.target_to_layer = self.adapter_target_to_layer()
         if adapter_index in self.loaded_adapters:
             # Adapter already loaded
             return
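
Aside: the two model.py hunks together are a lazy-initialization change: target_to_layer is no longer built eagerly in __init__ but on first use inside load_adapter. A minimal sketch of the pattern (illustrative only, trimmed to the essentials; not the real class):

# Illustrative sketch of the lazy-init pattern, not the repo's class.
class Model:
    def __init__(self):
        # Defer the (potentially expensive) mapping; models that never
        # load an adapter skip the cost entirely.
        self.target_to_layer = None
        self.loaded_adapters = set()

    def adapter_target_to_layer(self):
        # Stand-in for the real traversal over the model's modules.
        return {}

    def load_adapter(self, adapter_index):
        # Build the mapping on first use only.
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            return  # adapter already loaded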