Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
be0b3af9
Unverified
Commit
be0b3af9
authored
Jun 29, 2024
by
wangding zeng
Committed by
GitHub
Jun 28, 2024
Browse files
Support Deepseek-V2 (#4650)
Co-authored-by:
Philipp Moritz
<
pcmoritz@gmail.com
>
parent
2cd402e1
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
700 additions
and
1 deletion
+700
-1
vllm/config.py
vllm/config.py
+6
-0
vllm/model_executor/layers/fused_moe/__init__.py
vllm/model_executor/layers/fused_moe/__init__.py
+2
-1
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+31
-0
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+126
-0
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+1
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+534
-0
No files found.
vllm/config.py
View file @
be0b3af9
...
...
@@ -297,6 +297,12 @@ class ModelConfig:
return
self
.
hf_text_config
.
hidden_size
def
get_head_size
(
self
)
->
int
:
# TODO remove hard code
if
hasattr
(
self
.
hf_text_config
,
"model_type"
)
and
self
.
hf_text_config
.
model_type
==
'deepseek_v2'
:
# FlashAttention supports only head_size 32, 64, 128, 256,
# we need to pad head_size 192 to 256
return
256
if
hasattr
(
self
.
hf_text_config
,
"head_dim"
):
return
self
.
hf_text_config
.
head_dim
# FIXME(woosuk): This may not be true for all models.
...
...
vllm/model_executor/layers/fused_moe/__init__.py
View file @
be0b3af9
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
,
fused_moe
,
fused_topk
,
get_config_file_name
)
fused_experts
,
fused_moe
,
fused_topk
,
get_config_file_name
,
grouped_topk
)
__all__
=
[
"fused_moe"
,
"fused_topk"
,
"fused_experts"
,
"get_config_file_name"
,
"grouped_topk"
,
]
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
be0b3af9
...
...
@@ -367,6 +367,37 @@ def fused_topk(
return
topk_weights
,
topk_ids
# This is used by the Deepseek-V2 model
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
,
):
scores
=
torch
.
softmax
(
gating_output
,
dim
=-
1
)
num_token
=
scores
.
shape
[
0
]
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
max
(
dim
=-
1
).
values
# [n, n_group]
group_idx
=
torch
.
topk
(
group_scores
,
k
=
topk_group
,
dim
=-
1
,
sorted
=
False
)[
1
]
# [n, top_k_group]
group_mask
=
torch
.
zeros_like
(
group_scores
)
# [n, n_group]
group_mask
.
scatter_
(
1
,
group_idx
,
1
)
# [n, n_group]
score_mask
=
group_mask
.
unsqueeze
(
-
1
).
expand
(
num_token
,
num_expert_group
,
scores
.
shape
[
-
1
]
//
num_expert_group
).
reshape
(
num_token
,
-
1
)
# [n, e]
tmp_scores
=
scores
.
masked_fill
(
~
score_mask
.
bool
(),
0.0
)
# [n, e]
topk_weights
,
topk_ids
=
torch
.
topk
(
tmp_scores
,
k
=
topk
,
dim
=-
1
,
sorted
=
False
)
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
return
topk_weights
,
topk_ids
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
be0b3af9
...
...
@@ -610,6 +610,119 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
return
query
.
flatten
(
-
2
),
key
.
flatten
(
-
2
)
def
yarn_get_mscale
(
scale
:
float
=
1
,
mscale
:
float
=
1
)
->
float
:
if
scale
<=
1
:
return
1.0
return
0.1
*
mscale
*
math
.
log
(
scale
)
+
1.0
class
DeepseekScalingRotaryEmbedding
(
RotaryEmbedding
):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
"""
def
__init__
(
self
,
head_size
:
int
,
rotary_dim
:
int
,
max_position_embeddings
:
int
,
base
:
int
,
is_neox_style
:
bool
,
scaling_factor
:
float
,
dtype
:
torch
.
dtype
,
*
,
extrapolation_factor
:
float
=
1
,
attn_factor
:
float
=
1
,
beta_fast
:
int
=
32
,
beta_slow
:
int
=
1
,
mscale
:
float
=
1
,
mscale_all_dim
:
float
=
0
,
)
->
None
:
self
.
scaling_factor
=
scaling_factor
self
.
extrapolation_factor
=
extrapolation_factor
self
.
attn_factor
=
attn_factor
self
.
beta_fast
=
beta_fast
self
.
beta_slow
=
beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self
.
mscale
=
float
(
yarn_get_mscale
(
self
.
scaling_factor
,
float
(
mscale
))
/
yarn_get_mscale
(
self
.
scaling_factor
,
float
(
mscale_all_dim
))
*
attn_factor
)
super
().
__init__
(
head_size
,
rotary_dim
,
max_position_embeddings
,
base
,
is_neox_style
,
dtype
)
def
_compute_inv_freq
(
self
,
scaling_factor
:
float
)
->
torch
.
Tensor
:
pos_freqs
=
self
.
base
**
(
torch
.
arange
(
0
,
self
.
rotary_dim
,
2
,
dtype
=
torch
.
float
,
device
=
"cuda"
)
/
self
.
rotary_dim
)
inv_freq_extrapolation
=
1.0
/
pos_freqs
inv_freq_interpolation
=
1.0
/
(
scaling_factor
*
pos_freqs
)
low
,
high
=
_yarn_find_correction_range
(
self
.
beta_fast
,
self
.
beta_slow
,
self
.
rotary_dim
,
self
.
base
,
self
.
max_position_embeddings
)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask
=
(
1
-
_yarn_linear_ramp_mask
(
low
,
high
,
self
.
rotary_dim
//
2
,
dtype
=
torch
.
float
))
*
self
.
extrapolation_factor
inv_freq
=
inv_freq_interpolation
*
(
1
-
inv_freq_mask
)
+
inv_freq_extrapolation
*
inv_freq_mask
return
inv_freq
def
_compute_cos_sin_cache
(
self
)
->
torch
.
Tensor
:
inv_freq
=
self
.
_compute_inv_freq
(
self
.
scaling_factor
)
t
=
torch
.
arange
(
self
.
max_position_embeddings
*
self
.
scaling_factor
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
freqs
=
torch
.
einsum
(
"i,j -> ij"
,
t
,
inv_freq
)
cos
=
(
freqs
.
cos
()
*
self
.
mscale
)
sin
=
(
freqs
.
sin
()
*
self
.
mscale
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
print
(
"Cache shape"
,
cache
.
shape
)
return
cache
def
forward
(
self
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""PyTorch-native implementation equivalent to forward()."""
query_rot
=
query
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
if
self
.
rotary_dim
<
self
.
head_size
:
query_pass
=
query
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
positions
.
device
)
cos_sin
=
self
.
cos_sin_cache
[
torch
.
add
(
positions
,
offsets
)
if
offsets
is
not
None
else
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
if
self
.
is_neox_style
:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos
=
cos
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
else
:
cos
=
cos
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
rotate_fn
=
_rotate_neox
if
self
.
is_neox_style
else
_rotate_gptj
query_rot
=
query_rot
*
cos
+
rotate_fn
(
query_rot
)
*
sin
key_rot
=
key_rot
*
cos
+
rotate_fn
(
key_rot
)
*
sin
if
self
.
rotary_dim
<
self
.
head_size
:
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
)
else
:
query
=
query_rot
key
=
key_rot
return
query
,
key
class
GemmaRotaryEmbedding
(
RotaryEmbedding
):
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
...
...
@@ -679,6 +792,19 @@ def get_rope(
base
,
is_neox_style
,
scaling_factor
,
dtype
,
**
extra_kwargs
)
elif
scaling_type
==
"deepseek_yarn"
:
original_max_position
=
rope_scaling
[
"original_max_position_embeddings"
]
# assert max_position == original_max_position * scaling_factor
extra_kwargs
=
{
k
:
v
for
k
,
v
in
rope_scaling
.
items
()
if
k
in
(
"extrapolation_factor"
,
"attn_factor"
,
"beta_fast"
,
"beta_slow"
,
"mscale"
,
"mscale_all_dim"
)
}
rotary_emb
=
DeepseekScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
original_max_position
,
base
,
is_neox_style
,
scaling_factor
,
dtype
,
**
extra_kwargs
)
# The correct one should be "longrope" but keep "su" here
# for backward compatible
elif
scaling_type
==
"su"
or
scaling_type
==
"longrope"
:
...
...
vllm/model_executor/models/__init__.py
View file @
be0b3af9
...
...
@@ -21,6 +21,7 @@ _GENERATION_MODELS = {
"DbrxForCausalLM"
:
(
"dbrx"
,
"DbrxForCausalLM"
),
"DeciLMForCausalLM"
:
(
"decilm"
,
"DeciLMForCausalLM"
),
"DeepseekForCausalLM"
:
(
"deepseek"
,
"DeepseekForCausalLM"
),
"DeepseekV2ForCausalLM"
:
(
"deepseek_v2"
,
"DeepseekV2ForCausalLM"
),
"FalconForCausalLM"
:
(
"falcon"
,
"FalconForCausalLM"
),
"GemmaForCausalLM"
:
(
"gemma"
,
"GemmaForCausalLM"
),
"Gemma2ForCausalLM"
:
(
"gemma2"
,
"Gemma2ForCausalLM"
),
...
...
vllm/model_executor/models/deepseek_v2.py
0 → 100644
View file @
be0b3af9
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment