OpenDAS / text-generation-inference

Commit 30620a9a authored Apr 10, 2024 by OlivierDehaene

hotfix: mixtral

parent ad9d6288
Showing 1 changed file with 14 additions and 8 deletions

server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py  (+14, -8)
@@ -464,9 +464,9 @@ class DenseMoE(nn.Module):
 class MixtralLayer(nn.Module):
-    def __init__(self, layer_id, config, weights):
+    def __init__(self, prefix, layer_id, config, weights):
         super().__init__()
-        prefix = f"model.layers.{layer_id}"
+        prefix = f"{prefix}.layers.{layer_id}"
         self.self_attn = MixtralAttention(
             prefix=f"{prefix}.self_attn", config=config, weights=weights
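For orientation, a minimal standalone sketch of what this hunk changes: MixtralLayer used to hard-code its weight namespace as "model.layers.{layer_id}", whereas it now appends ".layers.{layer_id}" to whatever prefix the caller passes. The "wrapper.model" value below is a hypothetical parent prefix used only for illustration, not something this commit defines.

# Sketch only: the same f-string logic as the hunk above, outside the model code.
def layer_prefix(prefix: str, layer_id: int) -> str:
    return f"{prefix}.layers.{layer_id}"

print(layer_prefix("model", 0))          # model.layers.0  (matches the old hard-coded name)
print(layer_prefix("wrapper.model", 0))  # wrapper.model.layers.0  (hypothetical nested use)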
@@ -525,16 +525,20 @@ class MixtralLayer(nn.Module):
 class MixtralModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.embed_tokens = TensorParallelEmbedding(
-            prefix="model.embed_tokens", weights=weights
+            prefix=(
+                "model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
+            ),
+            weights=weights,
         )
         self.layers = nn.ModuleList(
             [
                 MixtralLayer(
+                    "model" if not prefix else f"{prefix}.model",
                     layer_id,
                     config,
                     weights,
@@ -543,7 +547,9 @@ class MixtralModel(torch.nn.Module):
             ]
         )
         self.norm = FastRMSNorm.load(
-            prefix="model.norm", weights=weights, eps=config.rms_norm_eps
+            prefix="model.norm" if not prefix else f"{prefix}.model.norm",
+            weights=weights,
+            eps=config.rms_norm_eps,
         )

         self.head_size = self.layers[0].self_attn.head_size
@@ -593,13 +599,13 @@ class MixtralModel(torch.nn.Module):
 class FlashMixtralForCausalLM(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, prefix, config, weights):
         super().__init__()
-        self.model = MixtralModel(config, weights)
+        self.model = MixtralModel(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix="lm_head" if not prefix else f"{prefix}.lm_head",
             weights=weights,
         )
         self.max_past = config.sliding_window
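Taken together, the pattern throughout the diff is: use the bare name if no prefix is given, otherwise nest it under the parent prefix. A minimal sketch of how the resulting weight keys resolve, assuming a hypothetical parent prefix "wrapper" (the resolve helper below is illustrative, not part of the commit):

# Sketch only: replays the conditional-prefix expressions from the diff as plain
# string logic, to show which weight keys each module would look up.
def resolve(name: str, prefix: str) -> str:
    # Mirrors e.g.: prefix = "lm_head" if not prefix else f"{prefix}.lm_head"
    return name if not prefix else f"{prefix}.{name}"

for outer in ("", "wrapper"):  # "" = standalone Mixtral, "wrapper" = hypothetical parent
    print(f"outer prefix {outer!r}:")
    print("  embed_tokens ->", resolve("model.embed_tokens", outer))
    print("  layer 0 attn ->", resolve("model", outer) + ".layers.0.self_attn")
    print("  final norm   ->", resolve("model.norm", outer))
    print("  lm_head      ->", resolve("lm_head", outer))

With an empty prefix the printed keys are exactly the names the file hard-coded before this commit ("model.embed_tokens", "model.layers.0.self_attn", "model.norm", "lm_head"), so standalone Mixtral checkpoints keep loading unchanged, while a non-empty prefix nests the same keys under a parent namespace.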