Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
eefea5ee
Unverified
Commit
eefea5ee
authored
Apr 12, 2024
by
OlivierDehaene
Committed by
GitHub
Apr 12, 2024
Browse files
feat: medusa v2 (#1734)
parent
1b2670c8
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
147 additions
and
46 deletions
+147
-46
server/text_generation_server/models/__init__.py
server/text_generation_server/models/__init__.py
+1
-1
server/text_generation_server/models/flash_causal_lm.py
server/text_generation_server/models/flash_causal_lm.py
+9
-17
server/text_generation_server/utils/layers.py
server/text_generation_server/utils/layers.py
+137
-28
No files found.
server/text_generation_server/models/__init__.py
View file @
eefea5ee
...
@@ -145,7 +145,7 @@ def get_model(
...
@@ -145,7 +145,7 @@ def get_model(
if
speculate
is
not
None
:
if
speculate
is
not
None
:
if
speculate
>
speculate_medusa
:
if
speculate
>
speculate_medusa
:
raise
RuntimeError
(
raise
RuntimeError
(
"Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
f
"Speculate is set to `
{
speculate
}
` but this medusa models only has `
{
speculate_medusa
}
` heads, please make them match"
)
)
else
:
else
:
set_speculate
(
speculate
)
set_speculate
(
speculate
)
...
...
server/text_generation_server/models/flash_causal_lm.py
View file @
eefea5ee
...
@@ -814,7 +814,7 @@ class FlashCausalLM(Model):
...
@@ -814,7 +814,7 @@ class FlashCausalLM(Model):
for
bs
in
CUDA_GRAPHS
:
for
bs
in
CUDA_GRAPHS
:
if
self
.
speculate
is
None
or
self
.
speculate
+
1
<=
bs
:
if
self
.
speculate
is
None
or
self
.
speculate
+
1
<=
bs
:
self
.
cuda_graph_warmup
(
bs
,
max_s
,
max_bt
)
self
.
cuda_graph_warmup
(
bs
,
max_s
,
max_bt
)
except
Exception
:
except
torch
.
cuda
.
OutOfMemoryError
:
logger
.
exception
(
f
"Decode cuda graph warmup failed"
)
logger
.
exception
(
f
"Decode cuda graph warmup failed"
)
return
int
(
num_blocks
*
BLOCK_SIZE
)
return
int
(
num_blocks
*
BLOCK_SIZE
)
...
@@ -874,22 +874,14 @@ class FlashCausalLM(Model):
...
@@ -874,22 +874,14 @@ class FlashCausalLM(Model):
lm_head_indices
=
batch
.
prefill_head_indices
lm_head_indices
=
batch
.
prefill_head_indices
bs
=
input_ids
.
shape
[
0
]
bs
=
input_ids
.
shape
[
0
]
padded_bs
=
bs
sorted_padded_bs
=
sorted
([
k
for
k
in
self
.
cuda_graphs
.
keys
()
if
k
>=
bs
])
if
bs
==
3
:
if
sorted_padded_bs
:
padded_bs
=
4
# Get associated cuda graph
elif
3
<
bs
<=
8
:
cuda_graph
=
self
.
cuda_graphs
[
sorted_padded_bs
[
0
]]
padded_bs
=
8
else
:
elif
bs
>
8
:
cuda_graph
=
None
padded_bs
=
(
bs
+
7
)
//
8
*
8
# Try to find an associated cuda graph
cuda_graph
=
self
.
cuda_graphs
.
get
(
padded_bs
,
None
)
if
(
if
cu_seqlen_prefill
is
not
None
or
cuda_graph
is
None
:
cu_seqlen_prefill
is
not
None
or
cuda_graph
is
None
or
batch
.
speculative_ids
is
not
None
):
return
self
.
model
.
forward
(
return
self
.
model
.
forward
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
position_ids
=
position_ids
,
position_ids
=
position_ids
,
...
...
server/text_generation_server/utils/layers.py
View file @
eefea5ee
...
@@ -432,12 +432,12 @@ class ResBlock(torch.nn.Module):
...
@@ -432,12 +432,12 @@ class ResBlock(torch.nn.Module):
class
MedusaModel
(
torch
.
nn
.
Module
):
class
MedusaModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
config
,
weights
):
def
__init__
(
self
,
config
,
medusa_config
,
weights
):
super
().
__init__
()
super
().
__init__
()
self
.
heads
=
torch
.
nn
.
ModuleList
(
self
.
heads
=
torch
.
nn
.
ModuleList
(
[
[
MedusaHead
(
config
,
prefix
=
f
"
{
i
}
"
,
weights
=
weights
)
MedusaHead
(
config
,
medusa_config
,
prefix
=
f
"
{
i
}
"
,
weights
=
weights
)
for
i
in
range
(
config
[
"medusa_num_heads"
])
for
i
in
range
(
medusa_
config
[
"medusa_num_heads"
])
]
]
)
)
...
@@ -447,12 +447,12 @@ class MedusaModel(torch.nn.Module):
...
@@ -447,12 +447,12 @@ class MedusaModel(torch.nn.Module):
class
MedusaHead
(
torch
.
nn
.
Module
):
class
MedusaHead
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
config
,
prefix
,
weights
):
def
__init__
(
self
,
config
,
medusa_config
,
prefix
,
weights
):
super
().
__init__
()
super
().
__init__
()
self
.
blocks
=
torch
.
nn
.
ModuleList
(
self
.
blocks
=
torch
.
nn
.
ModuleList
(
[
[
ResBlock
(
config
,
prefix
=
f
"
{
prefix
}
.
{
i
}
"
,
weights
=
weights
)
ResBlock
(
config
,
prefix
=
f
"
{
prefix
}
.
{
i
}
"
,
weights
=
weights
)
for
i
in
range
(
config
[
"medusa_num_layers"
])
for
i
in
range
(
medusa_
config
[
"medusa_num_layers"
])
]
]
)
)
n
=
len
(
self
.
blocks
)
n
=
len
(
self
.
blocks
)
...
@@ -467,7 +467,7 @@ class MedusaHead(torch.nn.Module):
...
@@ -467,7 +467,7 @@ class MedusaHead(torch.nn.Module):
return
x
return
x
class
Speculative
Head
(
nn
.
Module
):
class
Medusa
Head
V1
(
nn
.
Module
):
def
__init__
(
self
,
lm_head
,
medusa
):
def
__init__
(
self
,
lm_head
,
medusa
):
super
().
__init__
()
super
().
__init__
()
self
.
lm_head
=
lm_head
self
.
lm_head
=
lm_head
...
@@ -475,38 +475,147 @@ class SpeculativeHead(nn.Module):
...
@@ -475,38 +475,147 @@ class SpeculativeHead(nn.Module):
@
staticmethod
@
staticmethod
def
load
(
config
,
prefix
:
str
,
weights
):
def
load
(
config
,
prefix
:
str
,
weights
):
lm_head
=
TensorParallelHead
.
load
(
config
,
prefix
,
weights
)
from
pathlib
import
Path
from
safetensors
import
safe_open
import
json
use_medusa
=
config
.
use_medusa
use_medusa
=
config
.
use_medusa
if
use_medusa
:
medusa_config
=
str
(
Path
(
use_medusa
)
/
"config.json"
)
filename
=
str
(
Path
(
use_medusa
)
/
"medusa_lm_head.safetensors"
)
with
open
(
medusa_config
,
"r"
)
as
f
:
medusa_config
=
json
.
load
(
f
)
routing
=
weights
.
routing
with
safe_open
(
filename
,
framework
=
"pytorch"
)
as
f
:
for
k
in
f
.
keys
():
if
k
in
routing
and
routing
[
k
]
!=
filename
:
raise
RuntimeError
(
f
"Key
{
k
}
was found in multiple files:
{
filename
}
and
{
routing
[
k
]
}
"
)
routing
[
k
]
=
filename
medusa
=
MedusaModel
(
config
,
medusa_config
,
weights
)
lm_head
=
TensorParallelHead
.
load
(
config
,
prefix
,
weights
)
return
MedusaHeadV1
(
lm_head
,
medusa
)
def
forward
(
self
,
input
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
logits
=
self
.
lm_head
(
input
)
speculative_logits
=
self
.
medusa
(
input
)
return
logits
,
speculative_logits
class
MedusaHeadV2
(
nn
.
Module
):
def
__init__
(
self
,
config
,
prefix
,
weights
):
super
().
__init__
()
from
pathlib
import
Path
from
pathlib
import
Path
from
safetensors
import
safe_open
from
safetensors
import
safe_open
import
json
import
json
use_medusa
=
config
.
use_medusa
medusa_config
=
str
(
Path
(
use_medusa
)
/
"config.json"
)
medusa_config
=
str
(
Path
(
use_medusa
)
/
"config.json"
)
filename
=
str
(
Path
(
use_medusa
)
/
"medusa_lm_head.safetensors"
)
filename
=
str
(
Path
(
use_medusa
)
/
"medusa_lm_head.safetensors"
)
with
open
(
medusa_config
,
"r"
)
as
f
:
with
open
(
medusa_config
,
"r"
)
as
f
:
config
=
json
.
load
(
f
)
medusa_
config
=
json
.
load
(
f
)
routing
=
weights
.
routing
routing
=
weights
.
routing
with
safe_open
(
filename
,
framework
=
"pytorch"
)
as
f
:
with
safe_open
(
filename
,
framework
=
"pytorch"
)
as
f
:
for
k
in
f
.
keys
():
for
k
in
f
.
keys
():
if
k
in
routing
:
if
k
in
routing
and
routing
[
k
]
!=
filename
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Key
{
k
}
was found in multiple files:
{
filename
}
and
{
routing
[
k
]
}
"
f
"Key
{
k
}
was found in multiple files:
{
filename
}
and
{
routing
[
k
]
}
"
)
)
weights
.
routing
[
k
]
=
filename
routing
[
k
]
=
filename
self
.
n_medusa_heads
=
medusa_config
[
"medusa_num_heads"
]
assert
medusa_config
[
"medusa_num_layers"
]
==
1
self
.
linear
=
TensorParallelColumnLinear
.
load_multi
(
config
,
prefixes
=
[
f
"
{
i
}
.0.linear"
for
i
in
range
(
self
.
n_medusa_heads
)],
dim
=
0
,
weights
=
weights
,
bias
=
True
,
)
self
.
process_group
=
weights
.
process_group
self
.
world_size
=
self
.
process_group
.
size
()
self
.
rank
=
self
.
process_group
.
rank
()
medusa
=
MedusaModel
(
config
,
weights
)
self
.
act
=
torch
.
nn
.
SiLU
()
self
.
lm_head
=
TensorParallelHead
.
load
(
config
,
prefix
,
weights
)
def
forward
(
self
,
x
):
size
=
x
.
shape
[
-
1
]
block_size
=
(
size
+
self
.
world_size
-
1
)
//
self
.
world_size
start
=
self
.
rank
*
block_size
stop
=
(
self
.
rank
+
1
)
*
block_size
x_block
=
x
[:,
start
:
stop
]
# Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
medusa_res
=
self
.
act
(
self
.
linear
(
x
)).
reshape
(
*
x_block
.
shape
[:
-
1
],
self
.
n_medusa_heads
,
x_block
.
shape
[
-
1
]
)
# Apply all residual medusa heads
output
=
x
[:,
start
:
stop
].
unsqueeze
(
-
2
)
+
medusa_res
# Gather medusa heads
world_output
=
[
torch
.
empty_like
(
output
)
for
_
in
range
(
self
.
process_group
.
size
())
]
torch
.
distributed
.
all_gather
(
world_output
,
output
,
group
=
self
.
process_group
)
world_output
=
torch
.
cat
(
world_output
,
dim
=-
1
)
# Stack x and medusa residual x
stacked_x
=
torch
.
cat
([
x
.
unsqueeze
(
-
2
),
world_output
],
dim
=-
2
)
# Compute lm head on x + medusa residual x
logits
=
self
.
lm_head
(
stacked_x
)
# Finally, split logits from speculative logits
logits
,
speculative_logits
=
torch
.
split
(
logits
,
[
1
,
self
.
n_medusa_heads
],
dim
=-
2
)
# Squeeze added dimension
logits
=
logits
.
squeeze
(
-
2
)
return
logits
,
speculative_logits
class
SpeculativeHead
(
nn
.
Module
):
def
__init__
(
self
,
lm_head
,
medusa
):
super
().
__init__
()
self
.
head
=
lm_head
self
.
medusa
=
medusa
@
staticmethod
def
load
(
config
,
prefix
:
str
,
weights
):
use_medusa
=
config
.
use_medusa
if
use_medusa
:
lm_head
=
None
try
:
medusa
=
MedusaHeadV1
.
load
(
config
,
prefix
,
weights
)
except
:
medusa
=
MedusaHeadV2
(
config
,
prefix
,
weights
)
else
:
else
:
lm_head
=
TensorParallelHead
.
load
(
config
,
prefix
,
weights
)
medusa
=
None
medusa
=
None
return
SpeculativeHead
(
lm_head
,
medusa
)
return
SpeculativeHead
(
lm_head
,
medusa
)
def
forward
(
def
forward
(
self
,
input
:
torch
.
Tensor
self
,
input
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
logits
=
self
.
lm_head
(
input
)
if
self
.
medusa
is
not
None
:
speculative_logits
=
self
.
medusa
(
input
)
if
self
.
medusa
is
not
None
else
None
return
self
.
medusa
(
input
)
return
logits
,
speculative_logits
assert
self
.
head
is
not
None
logits
=
self
.
head
(
input
)
return
logits
,
None
class
TensorParallelHead
(
SuperLayer
):
class
TensorParallelHead
(
SuperLayer
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment