OpenDAS / text-generation-inference · Commits · 4b49c50f

Unverified commit 4b49c50f, authored Jul 26, 2024 by Daniël de Kok; committed by GitHub on Jul 26, 2024.
Support tied embeddings in 0.5B and 1.5B Qwen2 models (#2313)
Parent: 3905f854

Changes: 1 changed file with 9 additions and 5 deletions (+9 −5).
server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py
@@ -262,6 +262,9 @@ class Qwen2Layer(nn.Module):
 class Qwen2Model(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()
+
+        prefix = f"{prefix}.model" if prefix else "model"
+
         process_group = weights.process_group
         self.tp_rank = process_group.rank()
         self.tp_world_size = process_group.size()
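With this change, Qwen2Model derives its own weight prefix rather than expecting the caller to pass one that already ends in ".model". A minimal sketch of the resulting name resolution; the helper name and the nested parent prefix below are illustrative, not taken from the repository:

def resolve_model_prefix(prefix: str) -> str:
    # Mirrors the added line: nest under "<prefix>.model" when a parent
    # prefix is given, otherwise use the top-level "model" namespace.
    return f"{prefix}.model" if prefix else "model"


# Standalone Qwen2ForCausalLM: weights live under "model.*".
assert resolve_model_prefix("") == "model"
# Hypothetical parent prefix, e.g. a text backbone nested inside a larger module.
assert resolve_model_prefix("language_model") == "language_model.model"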
@@ -335,15 +338,16 @@ class Qwen2ForCausalLM(torch.nn.Module):
     def __init__(self, prefix: str, config, weights):
         super().__init__()

-        if not prefix:
-            prefix = "model"
+        self.model = Qwen2Model(prefix, config, weights)
+
+        if config.tie_word_embeddings:
+            suffix = "model.embed_tokens"
         else:
-            prefix = f"{prefix}.model"
+            suffix = "lm_head"

-        self.model = Qwen2Model(prefix, config, weights)
         self.lm_head = SpeculativeHead.load(
             config,
-            prefix="lm_head",
+            prefix=f"{prefix}.{suffix}" if prefix else suffix,
             weights=weights,
         )
         self.max_past = config.sliding_window
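For context on the lm_head change: the Qwen2 0.5B and 1.5B checkpoints set config.tie_word_embeddings, meaning the output projection reuses the input embedding matrix and the checkpoint ships no separate lm_head.weight, so the head has to be loaded from the model.embed_tokens prefix instead. A minimal, self-contained sketch of weight tying with plain torch modules (toy sizes; this is illustrative and not TGI's SpeculativeHead code):

import torch
import torch.nn as nn

vocab_size, hidden_size = 32, 8

embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tie the output projection to the input embedding: both modules now share
# one parameter tensor, so a checkpoint only needs to store the embedding weight.
lm_head.weight = embed_tokens.weight

hidden = torch.randn(1, hidden_size)
logits = lm_head(hidden)  # equivalent to hidden @ embed_tokens.weight.T
assert logits.shape == (1, vocab_size)
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()

Untied models keep a distinct lm_head.weight in the checkpoint, which is what the else branch in the diff continues to load.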