Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7c4f76e3
"docs/vscode:/vscode.git/clone" did not exist on "3fb17d26c83d6314c50c1c2eedc61625738a047d"
Commit
7c4f76e3
authored
Apr 15, 2024
by
zhuwenwen
Browse files
merge v0.4.0
parents
2da0dd3e
51c31bc1
Changes
332
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1775 additions
and
473 deletions
+1775
-473
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+48
-45
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+31
-29
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+29
-26
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+27
-25
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+26
-24
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+67
-30
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+457
-0
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+36
-32
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+36
-36
vllm/model_executor/models/xverse.py
vllm/model_executor/models/xverse.py
+372
-0
vllm/model_executor/neuron_model_loader.py
vllm/model_executor/neuron_model_loader.py
+91
-23
vllm/model_executor/parallel_utils/communication_op.py
vllm/model_executor/parallel_utils/communication_op.py
+9
-11
vllm/model_executor/parallel_utils/cupy_utils.py
vllm/model_executor/parallel_utils/cupy_utils.py
+0
-130
vllm/model_executor/parallel_utils/custom_all_reduce.py
vllm/model_executor/parallel_utils/custom_all_reduce.py
+46
-9
vllm/model_executor/parallel_utils/parallel_state.py
vllm/model_executor/parallel_utils/parallel_state.py
+18
-18
vllm/model_executor/parallel_utils/pynccl.py
vllm/model_executor/parallel_utils/pynccl.py
+264
-0
vllm/model_executor/parallel_utils/pynccl_utils.py
vllm/model_executor/parallel_utils/pynccl_utils.py
+69
-0
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+126
-9
vllm/model_executor/utils.py
vllm/model_executor/utils.py
+0
-17
vllm/model_executor/weight_utils.py
vllm/model_executor/weight_utils.py
+23
-9
No files found.
vllm/model_executor/models/olmo.py
View file @
7c4f76e3
...
@@ -40,33 +40,27 @@ from typing import List, Optional, Tuple
...
@@ -40,33 +40,27 @@ from typing import List, Optional, Tuple
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
# this model must need this dependency
from
hf_olmo
import
OLMoConfig
from
torch
import
nn
from
torch
import
nn
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
ColumnParallelLinear
,
QKVParallelLinear
,
LinearMethodBase
,
RowParallelLinear
)
QKVParallelLinear
,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
RowParallelLinear
,
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
VocabParallelEmbedding
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
,
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
default_weight_loader
,
hf_model_weights_iterator
)
hf_model_weights_iterator
,
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
# this model must need this dependency
from
hf_olmo
import
OLMoConfig
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
SwiGLU
(
nn
.
Module
):
class
SwiGLU
(
nn
.
Module
):
...
@@ -81,7 +75,8 @@ class SwiGLU(nn.Module):
...
@@ -81,7 +75,8 @@ class SwiGLU(nn.Module):
class
OlmoAttention
(
nn
.
Module
):
class
OlmoAttention
(
nn
.
Module
):
"""
"""
This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
This is the attention block where the output is computed as
``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
(plus another skip connection).
"""
"""
...
@@ -94,11 +89,12 @@ class OlmoAttention(nn.Module):
...
@@ -94,11 +89,12 @@ class OlmoAttention(nn.Module):
self
.
config
=
config
self
.
config
=
config
self
.
hidden_size
=
config
.
d_model
self
.
hidden_size
=
config
.
d_model
assert
config
.
d_model
%
config
.
n_heads
==
0
assert
config
.
d_model
%
config
.
n_heads
==
0
tensor_model_parallel_world_size
=
get_tensor_model_parallel_world_size
(
tensor_model_parallel_world_size
=
(
)
get_tensor_model_parallel_world_size
()
)
self
.
total_num_heads
=
self
.
config
.
n_heads
self
.
total_num_heads
=
self
.
config
.
n_heads
assert
self
.
total_num_heads
%
tensor_model_parallel_world_size
==
0
assert
self
.
total_num_heads
%
tensor_model_parallel_world_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tensor_model_parallel_world_size
self
.
num_heads
=
(
self
.
total_num_heads
//
tensor_model_parallel_world_size
)
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
# Layer norms.
# Layer norms.
...
@@ -126,9 +122,9 @@ class OlmoAttention(nn.Module):
...
@@ -126,9 +122,9 @@ class OlmoAttention(nn.Module):
base
=
rope_theta
,
base
=
rope_theta
,
)
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
scale
=
self
.
scaling
)
scale
=
self
.
scaling
)
# Attention output projection.
# Attention output projection.
self
.
attn_out
=
RowParallelLinear
(
self
.
attn_out
=
RowParallelLinear
(
...
@@ -142,23 +138,23 @@ class OlmoAttention(nn.Module):
...
@@ -142,23 +138,23 @@ class OlmoAttention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
attn_norm
(
hidden_states
)
hidden_states
=
self
.
attn_norm
(
hidden_states
)
qkv
,
_
=
self
.
att_proj
(
hidden_states
)
qkv
,
_
=
self
.
att_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
if
self
.
config
.
rope
:
if
self
.
config
.
rope
:
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
output
,
_
=
self
.
attn_out
(
attn_output
)
output
,
_
=
self
.
attn_out
(
attn_output
)
return
output
return
output
class
OlmoMLP
(
nn
.
Module
):
class
OlmoMLP
(
nn
.
Module
):
"""
"""
This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
This is the MLP block where the output is computed as
``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
(plus another skip connection).
"""
"""
...
@@ -217,7 +213,8 @@ class OlmoMLP(nn.Module):
...
@@ -217,7 +213,8 @@ class OlmoMLP(nn.Module):
class
OlmoBlock
(
nn
.
Module
):
class
OlmoBlock
(
nn
.
Module
):
"""
"""
This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
This is a typical transformer block where the output is
computed as ``MLP(LN(x + Attention(LN(x))))``
(plus another skip connection).
(plus another skip connection).
"""
"""
...
@@ -235,12 +232,12 @@ class OlmoBlock(nn.Module):
...
@@ -235,12 +232,12 @@ class OlmoBlock(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]]]:
# Attention block.
# Attention block.
og_x
=
hidden_states
og_x
=
hidden_states
x
=
self
.
attn
(
positions
,
hidden_states
,
kv_cache
,
input
_metadata
)
x
=
self
.
attn
(
positions
,
hidden_states
,
kv_cache
,
attn
_metadata
)
x
=
x
+
og_x
x
=
x
+
og_x
# MLP block.
# MLP block.
...
@@ -290,8 +287,8 @@ class OlmoModel(nn.Module):
...
@@ -290,8 +287,8 @@ class OlmoModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
:param input_ids: A tensor of shape `(batch_size, seq_len)`.
...
@@ -307,7 +304,7 @@ class OlmoModel(nn.Module):
...
@@ -307,7 +304,7 @@ class OlmoModel(nn.Module):
positions
,
positions
,
x
,
x
,
kv_caches
[
block_idx
],
kv_caches
[
block_idx
],
input
_metadata
,
attn
_metadata
,
)
)
# Apply final layer norm.
# Apply final layer norm.
...
@@ -331,30 +328,36 @@ class OLMoForCausalLM(nn.Module):
...
@@ -331,30 +328,36 @@ class OLMoForCausalLM(nn.Module):
self
.
lm_head_weight
=
(
self
.
model
.
transformer
.
wte
.
weight
self
.
lm_head_weight
=
(
self
.
model
.
transformer
.
wte
.
weight
if
config
.
weight_tying
else
if
config
.
weight_tying
else
self
.
model
.
transformer
.
ff_out
.
weight
)
self
.
model
.
transformer
.
ff_out
.
weight
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
def
load_weights
(
...
...
vllm/model_executor/models/opt.py
View file @
7c4f76e3
...
@@ -17,20 +17,20 @@
...
@@ -17,20 +17,20 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""Inference-only OPT model compatible with HuggingFace weights."""
"""Inference-only OPT model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
OPTConfig
from
transformers
import
OPTConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -41,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
...
@@ -41,8 +41,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
OPTLearnedPositionalEmbedding
(
nn
.
Embedding
):
class
OPTLearnedPositionalEmbedding
(
nn
.
Embedding
):
...
@@ -89,21 +87,19 @@ class OPTAttention(nn.Module):
...
@@ -89,21 +87,19 @@ class OPTAttention(nn.Module):
bias
=
bias
,
bias
=
bias
,
linear_method
=
linear_method
,
linear_method
=
linear_method
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
scale
=
self
.
scaling
)
scale
=
self
.
scaling
)
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
key_cache
,
value_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
key_cache
,
value_cache
,
input_metadata
)
output
,
_
=
self
.
out_proj
(
attn_output
)
output
,
_
=
self
.
out_proj
(
attn_output
)
return
output
return
output
...
@@ -151,8 +147,8 @@ class OPTDecoderLayer(nn.Module):
...
@@ -151,8 +147,8 @@ class OPTDecoderLayer(nn.Module):
def
forward
(
def
forward
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
# Self Attention
residual
=
hidden_states
residual
=
hidden_states
...
@@ -161,7 +157,7 @@ class OPTDecoderLayer(nn.Module):
...
@@ -161,7 +157,7 @@ class OPTDecoderLayer(nn.Module):
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn_layer_norm
(
hidden_states
)
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
hidden_states
=
self
.
self_attn
(
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
)
attn
_metadata
=
attn
_metadata
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
# 350m applies layer norm AFTER attention
# 350m applies layer norm AFTER attention
if
not
self
.
do_layer_norm_before
:
if
not
self
.
do_layer_norm_before
:
...
@@ -240,8 +236,8 @@ class OPTDecoder(nn.Module):
...
@@ -240,8 +236,8 @@ class OPTDecoder(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
pos_embeds
=
self
.
embed_positions
(
positions
)
pos_embeds
=
self
.
embed_positions
(
positions
)
...
@@ -251,7 +247,7 @@ class OPTDecoder(nn.Module):
...
@@ -251,7 +247,7 @@ class OPTDecoder(nn.Module):
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
input
_metadata
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn
_metadata
)
if
self
.
final_layer_norm
is
not
None
:
if
self
.
final_layer_norm
is
not
None
:
hidden_states
=
self
.
final_layer_norm
(
hidden_states
)
hidden_states
=
self
.
final_layer_norm
(
hidden_states
)
...
@@ -274,10 +270,10 @@ class OPTModel(nn.Module):
...
@@ -274,10 +270,10 @@ class OPTModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
decoder
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
return
self
.
decoder
(
input_ids
,
positions
,
kv_caches
,
attn
_metadata
)
class
OPTForCausalLM
(
nn
.
Module
):
class
OPTForCausalLM
(
nn
.
Module
):
...
@@ -292,26 +288,32 @@ class OPTForCausalLM(nn.Module):
...
@@ -292,26 +288,32 @@ class OPTForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
OPTModel
(
config
,
linear_method
)
self
.
model
=
OPTModel
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
model
.
decoder
.
embed_tokens
.
weight
self
.
lm_head_weight
=
self
.
model
.
decoder
.
embed_tokens
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/orion.py
View file @
7c4f76e3
...
@@ -10,17 +10,17 @@ import torch
...
@@ -10,17 +10,17 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -28,8 +28,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
...
@@ -28,8 +28,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
OrionMLP
(
nn
.
Module
):
class
OrionMLP
(
nn
.
Module
):
...
@@ -118,23 +116,22 @@ class OrionAttention(nn.Module):
...
@@ -118,23 +116,22 @@ class OrionAttention(nn.Module):
base
=
rope_theta
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
scaling
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -177,8 +174,8 @@ class OrionDecoderLayer(nn.Module):
...
@@ -177,8 +174,8 @@ class OrionDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -188,7 +185,7 @@ class OrionDecoderLayer(nn.Module):
...
@@ -188,7 +185,7 @@ class OrionDecoderLayer(nn.Module):
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -226,8 +223,8 @@ class OrionModel(nn.Module):
...
@@ -226,8 +223,8 @@ class OrionModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
residual
=
None
...
@@ -237,7 +234,7 @@ class OrionModel(nn.Module):
...
@@ -237,7 +234,7 @@ class OrionModel(nn.Module):
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
residual
,
)
)
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
...
@@ -256,26 +253,32 @@ class OrionForCausalLM(nn.Module):
...
@@ -256,26 +253,32 @@ class OrionForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
OrionModel
(
config
,
linear_method
)
self
.
model
=
OrionModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/phi.py
View file @
7c4f76e3
...
@@ -35,23 +35,23 @@
...
@@ -35,23 +35,23 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
"""Inference-only Phi-1.5 model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -59,8 +59,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
...
@@ -59,8 +59,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
PhiAttention
(
nn
.
Module
):
class
PhiAttention
(
nn
.
Module
):
...
@@ -108,20 +106,19 @@ class PhiAttention(nn.Module):
...
@@ -108,20 +106,19 @@ class PhiAttention(nn.Module):
max_position
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
base
=
rope_theta
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_size
,
scaling
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_size
,
scaling
)
def
forward
(
def
forward
(
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
position_ids
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
output
,
_
=
self
.
dense
(
attn_output
)
output
,
_
=
self
.
dense
(
attn_output
)
return
output
return
output
...
@@ -171,8 +168,8 @@ class PhiLayer(nn.Module):
...
@@ -171,8 +168,8 @@ class PhiLayer(nn.Module):
self
,
self
,
position_ids
:
torch
.
Tensor
,
position_ids
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
residual
=
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
...
@@ -180,7 +177,7 @@ class PhiLayer(nn.Module):
...
@@ -180,7 +177,7 @@ class PhiLayer(nn.Module):
position_ids
=
position_ids
,
position_ids
=
position_ids
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
)
feed_forward_hidden_states
=
self
.
mlp
(
hidden_states
)
feed_forward_hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
attn_outputs
+
feed_forward_hidden_states
+
residual
hidden_states
=
attn_outputs
+
feed_forward_hidden_states
+
residual
...
@@ -208,8 +205,8 @@ class PhiModel(nn.Module):
...
@@ -208,8 +205,8 @@ class PhiModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
self
.
embed_tokens
(
input_ids
)
for
i
in
range
(
self
.
config
.
num_hidden_layers
):
for
i
in
range
(
self
.
config
.
num_hidden_layers
):
...
@@ -218,7 +215,7 @@ class PhiModel(nn.Module):
...
@@ -218,7 +215,7 @@ class PhiModel(nn.Module):
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
)
hidden_states
=
self
.
final_layernorm
(
hidden_states
)
hidden_states
=
self
.
final_layernorm
(
hidden_states
)
...
@@ -240,28 +237,33 @@ class PhiForCausalLM(nn.Module):
...
@@ -240,28 +237,33 @@ class PhiForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
True
)
bias
=
True
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
,
self
.
lm_head
.
bias
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
head
=
self
.
lm_head
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
head
.
weight
,
hidden_states
,
sampling_metadata
,
head
.
bias
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/qwen.py
View file @
7c4f76e3
...
@@ -10,18 +10,18 @@ import torch
...
@@ -10,18 +10,18 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -29,8 +29,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
...
@@ -29,8 +29,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
QWenMLP
(
nn
.
Module
):
class
QWenMLP
(
nn
.
Module
):
...
@@ -104,21 +102,19 @@ class QWenAttention(nn.Module):
...
@@ -104,21 +102,19 @@ class QWenAttention(nn.Module):
base
=
rope_theta
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
rope_scaling
=
rope_scaling
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
)
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
qkv
,
_
=
self
.
c_attn
(
hidden_states
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
,
v
=
qkv
.
chunk
(
chunks
=
3
,
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
output
,
_
=
self
.
c_proj
(
attn_output
)
output
,
_
=
self
.
c_proj
(
attn_output
)
return
output
return
output
...
@@ -152,8 +148,8 @@ class QWenBlock(nn.Module):
...
@@ -152,8 +148,8 @@ class QWenBlock(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -166,7 +162,7 @@ class QWenBlock(nn.Module):
...
@@ -166,7 +162,7 @@ class QWenBlock(nn.Module):
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -200,8 +196,8 @@ class QWenModel(nn.Module):
...
@@ -200,8 +196,8 @@ class QWenModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
wte
(
input_ids
)
hidden_states
=
self
.
wte
(
input_ids
)
residual
=
None
residual
=
None
...
@@ -211,7 +207,7 @@ class QWenModel(nn.Module):
...
@@ -211,7 +207,7 @@ class QWenModel(nn.Module):
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
residual
,
)
)
hidden_states
,
_
=
self
.
ln_f
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
ln_f
(
hidden_states
,
residual
)
...
@@ -230,26 +226,32 @@ class QWenLMHeadModel(nn.Module):
...
@@ -230,26 +226,32 @@ class QWenLMHeadModel(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
transformer
=
QWenModel
(
config
,
linear_method
)
self
.
transformer
=
QWenModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/qwen2.py
View file @
7c4f76e3
...
@@ -28,18 +28,19 @@ import torch
...
@@ -28,18 +28,19 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
Qwen2Config
from
transformers
import
Qwen2Config
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -47,8 +48,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
...
@@ -47,8 +48,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
Qwen2MLP
(
nn
.
Module
):
class
Qwen2MLP
(
nn
.
Module
):
...
@@ -135,24 +134,23 @@ class Qwen2Attention(nn.Module):
...
@@ -135,24 +134,23 @@ class Qwen2Attention(nn.Module):
max_position
=
max_position
,
max_position
=
max_position
,
base
=
self
.
rope_theta
,
base
=
self
.
rope_theta
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
scaling
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
sliding_window
=
self
.
sliding_window
)
sliding_window
=
self
.
sliding_window
)
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -169,7 +167,8 @@ class Qwen2DecoderLayer(nn.Module):
...
@@ -169,7 +167,8 @@ class Qwen2DecoderLayer(nn.Module):
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
# Requires transformers > 4.32.0
# Requires transformers > 4.32.0
rope_theta
=
getattr
(
config
,
"rope_theta"
,
1000000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
1000000
)
use_sliding_window
=
config
.
use_sliding_window
and
layer_idx
<
config
.
max_window_layers
use_sliding_window
=
(
config
.
use_sliding_window
and
layer_idx
<
config
.
max_window_layers
)
self
.
self_attn
=
Qwen2Attention
(
self
.
self_attn
=
Qwen2Attention
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_heads
=
config
.
num_attention_heads
,
...
@@ -194,8 +193,8 @@ class Qwen2DecoderLayer(nn.Module):
...
@@ -194,8 +193,8 @@ class Qwen2DecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
residual
:
Optional
[
torch
.
Tensor
],
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
...
@@ -209,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
...
@@ -209,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
)
# Fully Connected
# Fully Connected
...
@@ -245,8 +244,8 @@ class Qwen2Model(nn.Module):
...
@@ -245,8 +244,8 @@ class Qwen2Model(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
residual
=
None
...
@@ -256,7 +255,7 @@ class Qwen2Model(nn.Module):
...
@@ -256,7 +255,7 @@ class Qwen2Model(nn.Module):
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
residual
,
residual
,
)
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
...
@@ -264,37 +263,73 @@ class Qwen2Model(nn.Module):
...
@@ -264,37 +263,73 @@ class Qwen2Model(nn.Module):
class
Qwen2ForCausalLM
(
nn
.
Module
):
class
Qwen2ForCausalLM
(
nn
.
Module
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
]
embedding_modules
=
{}
embedding_padding_modules
=
[]
def
__init__
(
def
__init__
(
self
,
self
,
config
:
Qwen2Config
,
config
:
Qwen2Config
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
)
->
None
:
del
lora_config
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
Qwen2Model
(
config
,
linear_method
)
self
.
model
=
Qwen2Model
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
if
config
.
tie_word_embeddings
:
self
.
lm_head_weight
=
self
.
model
.
embed_tokens
.
weight
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head_weight
=
self
.
lm_head
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
@@ -310,11 +345,13 @@ class Qwen2ForCausalLM(nn.Module):
...
@@ -310,11 +345,13 @@ class Qwen2ForCausalLM(nn.Module):
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
]
params_dict
=
dict
(
self
.
named_parameters
())
params_dict
=
dict
(
self
.
named_parameters
(
remove_duplicate
=
False
))
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
if
self
.
config
.
tie_word_embeddings
and
"lm_head.weight"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
...
...
vllm/model_executor/models/qwen2_moe.py
0 → 100644
View file @
7c4f76e3
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
class
Qwen2MoeMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
reduce_results
:
bool
=
True
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
reduce_results
=
reduce_results
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
Qwen2MoeSparseMoeBlock
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
n_routed_experts
=
config
.
num_experts
self
.
top_k
=
config
.
num_experts_per_tok
if
self
.
tp_size
>
self
.
n_routed_experts
:
raise
ValueError
(
f
"Tensor parallel size
{
self
.
tp_size
}
is greater than "
f
"the number of experts
{
self
.
n_routed_experts
}
."
)
self
.
experts
=
nn
.
ModuleList
([
Qwen2MoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
moe_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
reduce_results
=
False
)
for
idx
in
range
(
self
.
n_routed_experts
)
])
self
.
pack_params
()
self
.
gate
=
ReplicatedLinear
(
config
.
hidden_size
,
self
.
n_routed_experts
,
bias
=
False
,
linear_method
=
None
)
if
config
.
shared_expert_intermediate_size
>
0
:
self
.
shared_expert
=
Qwen2MoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
shared_expert_intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
reduce_results
=
False
,
)
else
:
self
.
shared_expert
=
None
self
.
shared_expert_gate
=
torch
.
nn
.
Linear
(
config
.
hidden_size
,
1
,
bias
=
False
)
def
pack_params
(
self
):
w1
=
[]
w2
=
[]
for
expert
in
self
.
experts
:
w1
.
append
(
expert
.
gate_up_proj
.
weight
)
w2
.
append
(
expert
.
down_proj
.
weight
)
self
.
w1
=
torch
.
_utils
.
_flatten_dense_tensors
(
w1
)
w1s
=
torch
.
_utils
.
_unflatten_dense_tensors
(
self
.
w1
,
w1
)
for
data
,
param
in
zip
(
w1s
,
w1
):
param
.
data
=
data
self
.
w1
=
self
.
w1
.
view
(
len
(
w1
),
*
w1s
[
0
].
shape
)
self
.
w2
=
torch
.
_utils
.
_flatten_dense_tensors
(
w2
)
w2s
=
torch
.
_utils
.
_unflatten_dense_tensors
(
self
.
w2
,
w2
)
for
data
,
param
in
zip
(
w2s
,
w2
):
param
.
data
=
data
self
.
w2
=
self
.
w2
.
view
(
len
(
w2
),
*
w2s
[
0
].
shape
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
shared_output
=
None
if
self
.
shared_expert
is
not
None
:
shared_output
=
self
.
shared_expert
(
hidden_states
)
if
self
.
shared_expert_gate
is
not
None
:
shared_output
=
F
.
sigmoid
(
self
.
shared_expert_gate
(
hidden_states
))
*
shared_output
# router_logits: (num_tokens, n_experts)
router_logits
,
_
=
self
.
gate
(
hidden_states
)
final_hidden_states
=
fused_moe
(
hidden_states
,
self
.
w1
,
self
.
w2
,
router_logits
,
self
.
top_k
,
renormalize
=
self
.
config
.
norm_topk_prob
,
inplace
=
True
)
if
shared_output
is
not
None
:
final_hidden_states
=
final_hidden_states
+
shared_output
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
.
view
(
num_tokens
,
hidden_dim
)
class
Qwen2MoeAttention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
if
self
.
total_num_kv_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_kv_heads
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
True
,
linear_method
=
linear_method
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
Qwen2MoeDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
layer_idx
:
int
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
self
.
self_attn
=
Qwen2MoeAttention
(
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
config
.
num_key_value_heads
,
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
)
if
(
config
.
num_experts
is
not
None
and
(
layer_idx
+
1
)
%
config
.
decoder_sparse_step
==
0
):
self
.
mlp
=
Qwen2MoeSparseMoeBlock
(
config
=
config
,
linear_method
=
linear_method
)
else
:
self
.
mlp
=
Qwen2MoeMLP
(
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
class
Qwen2MoeModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
self
.
vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
Qwen2MoeDecoderLayer
(
config
,
layer_idx
,
linear_method
=
linear_method
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
attn_metadata
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
Qwen2MoeForCausalLM
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
Qwen2MoeModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
cache_dir
,
load_format
,
revision
,
fall_back_to_pt
=
False
):
if
"rotary_emb.inv_freq"
in
name
:
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_expert."
in
name
)
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_expert."
in
name
)
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/stablelm.py
View file @
7c4f76e3
# coding=utf-8
# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team.
# All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -16,24 +17,25 @@
...
@@ -16,24 +17,25 @@
# This code is based off the following work:
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights."""
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.
model_executor.input_metadata
import
Input
Metadata
from
vllm.
attention
import
Attention
,
Attention
Metadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -41,8 +43,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
...
@@ -41,8 +43,6 @@ from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator
)
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
StablelmMLP
(
nn
.
Module
):
class
StablelmMLP
(
nn
.
Module
):
...
@@ -102,9 +102,9 @@ class StablelmAttention(nn.Module):
...
@@ -102,9 +102,9 @@ class StablelmAttention(nn.Module):
self
.
kv_size
=
self
.
num_key_value_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_key_value_heads
*
self
.
head_dim
self
.
qkv_bias
=
getattr
(
config
,
"use_qkv_bias"
,
False
)
self
.
qkv_bias
=
getattr
(
config
,
"use_qkv_bias"
,
False
)
if
(
self
.
head_dim
*
self
.
num_heads
*
tp_size
)
!=
self
.
hidden_size
:
if
(
self
.
head_dim
*
self
.
num_heads
*
tp_size
)
!=
self
.
hidden_size
:
raise
ValueError
(
raise
ValueError
(
f
"hidden_size must be divisible by num_heads "
f
"hidden_size must be divisible by num_heads
(got `hidden_size`:
{
self
.
hidden_size
}
"
f
"
(got `hidden_size`:
{
self
.
hidden_size
}
"
f
" and `num_heads`:
{
self
.
num_heads
}
)."
)
f
" and `num_heads`:
{
self
.
num_heads
}
)."
)
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
hidden_size
,
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
hidden_size
,
self
.
head_dim
,
self
.
head_dim
,
...
@@ -122,23 +122,22 @@ class StablelmAttention(nn.Module):
...
@@ -122,23 +122,22 @@ class StablelmAttention(nn.Module):
max_position
=
self
.
config
.
max_position_embeddings
,
max_position
=
self
.
config
.
max_position_embeddings
,
base
=
self
.
config
.
rope_theta
,
base
=
self
.
config
.
rope_theta
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
scaling
,
self
.
scaling
,
num_kv_heads
=
self
.
num_key_value_heads
)
num_kv_heads
=
self
.
num_key_value_heads
)
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -163,8 +162,8 @@ class StablelmDecoderLayer(nn.Module):
...
@@ -163,8 +162,8 @@ class StablelmDecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
# Self Attention
residual
=
hidden_states
residual
=
hidden_states
...
@@ -173,7 +172,7 @@ class StablelmDecoderLayer(nn.Module):
...
@@ -173,7 +172,7 @@ class StablelmDecoderLayer(nn.Module):
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -192,7 +191,6 @@ class StableLMEpochModel(nn.Module):
...
@@ -192,7 +191,6 @@ class StableLMEpochModel(nn.Module):
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
)
->
None
:
linear_method
:
Optional
[
LinearMethodBase
]
=
None
)
->
None
:
super
().
__init__
()
super
().
__init__
()
# self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
...
@@ -209,8 +207,8 @@ class StableLMEpochModel(nn.Module):
...
@@ -209,8 +207,8 @@ class StableLMEpochModel(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
self
.
embed_tokens
(
input_ids
)
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
...
@@ -219,7 +217,7 @@ class StableLMEpochModel(nn.Module):
...
@@ -219,7 +217,7 @@ class StableLMEpochModel(nn.Module):
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
],
input
_metadata
,
attn
_metadata
,
)
)
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
return
hidden_states
...
@@ -237,26 +235,32 @@ class StablelmForCausalLM(nn.Module):
...
@@ -237,26 +235,32 @@ class StablelmForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
StableLMEpochModel
(
config
,
linear_method
)
self
.
model
=
StableLMEpochModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/starcoder2.py
View file @
7c4f76e3
...
@@ -18,37 +18,30 @@
...
@@ -18,37 +18,30 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
""" PyTorch Starcoder2 model."""
""" PyTorch Starcoder2 model."""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
Starcoder2Config
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
,
DEFAULT_VOCAB_PADDING_SIZE
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
get_tensor_model_parallel_world_size
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
try
:
from
transformers
import
Starcoder2Config
except
ImportError
:
# fallback to PretrainedConfig
# NOTE: Please install transformers from source or use transformers>=4.39.0
from
transformers
import
PretrainedConfig
as
Starcoder2Config
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
Starcoder2Attention
(
nn
.
Module
):
class
Starcoder2Attention
(
nn
.
Module
):
...
@@ -103,7 +96,7 @@ class Starcoder2Attention(nn.Module):
...
@@ -103,7 +96,7 @@ class Starcoder2Attention(nn.Module):
base
=
int
(
self
.
rope_theta
),
base
=
int
(
self
.
rope_theta
),
is_neox_style
=
True
,
is_neox_style
=
True
,
)
)
self
.
attn
=
Paged
Attention
(
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
self
.
scaling
,
self
.
scaling
,
...
@@ -115,14 +108,13 @@ class Starcoder2Attention(nn.Module):
...
@@ -115,14 +108,13 @@ class Starcoder2Attention(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
return
output
...
@@ -145,8 +137,9 @@ class Starcoder2MLP(nn.Module):
...
@@ -145,8 +137,9 @@ class Starcoder2MLP(nn.Module):
bias
=
config
.
use_bias
,
bias
=
config
.
use_bias
,
linear_method
=
linear_method
,
linear_method
=
linear_method
,
)
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
,
quant_config
=
getattr
(
linear_method
,
"quant_config"
,
None
)
intermediate_size
=
config
.
intermediate_size
)
self
.
act
=
get_act_fn
(
config
.
hidden_act
,
quant_config
,
config
.
intermediate_size
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
c_fc
(
hidden_states
)
hidden_states
,
_
=
self
.
c_fc
(
hidden_states
)
...
@@ -174,8 +167,8 @@ class Starcoder2DecoderLayer(nn.Module):
...
@@ -174,8 +167,8 @@ class Starcoder2DecoderLayer(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
kv_cache
:
torch
.
Tensor
,
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# Self Attention
# Self Attention
residual
=
hidden_states
residual
=
hidden_states
...
@@ -184,7 +177,7 @@ class Starcoder2DecoderLayer(nn.Module):
...
@@ -184,7 +177,7 @@ class Starcoder2DecoderLayer(nn.Module):
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
kv_cache
=
kv_cache
,
input
_metadata
=
input
_metadata
,
attn
_metadata
=
attn
_metadata
,
)
)
hidden_states
=
residual
+
hidden_states
hidden_states
=
residual
+
hidden_states
...
@@ -220,14 +213,14 @@ class Starcoder2Model(nn.Module):
...
@@ -220,14 +213,14 @@ class Starcoder2Model(nn.Module):
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
self
.
embed_tokens
(
input_ids
)
for
i
in
range
(
len
(
self
.
layers
)):
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
hidden_states
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
input
_metadata
)
attn
_metadata
)
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
return
hidden_states
...
@@ -253,26 +246,33 @@ class Starcoder2ForCausalLM(nn.Module):
...
@@ -253,26 +246,33 @@ class Starcoder2ForCausalLM(nn.Module):
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
)
)
self
.
lm_head_weight
=
self
.
lm_head
.
weight
self
.
lm_head_weight
=
self
.
lm_head
.
weight
self
.
sampler
=
Sampler
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
KVCache
],
kv_caches
:
List
[
torch
.
Tensor
],
input
_metadata
:
Input
Metadata
,
attn
_metadata
:
Attention
Metadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
input
_metadata
)
attn
_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
Optional
[
torch
.
Tensor
],
logit
s
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/xverse.py
0 → 100644
View file @
7c4f76e3
# coding=utf-8
# Adapted from
# https://huggingface.co/xverse/XVERSE-7B/blob/main/modeling_xverse.py
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Xverse model compatible with HuggingFace weights."""
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
class
XverseMLP
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
intermediate_size
:
int
,
hidden_act
:
str
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
hidden_size
,
[
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
self
.
down_proj
=
RowParallelLinear
(
intermediate_size
,
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
if
hidden_act
!=
"silu"
:
raise
ValueError
(
f
"Unsupported activation:
{
hidden_act
}
. "
"Only silu is supported for now."
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
):
gate
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
XverseAttention
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
num_heads
:
int
,
num_kv_heads
:
int
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
max_position_embeddings
:
int
=
8192
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
bias
:
bool
=
False
,
sliding_window
:
Optional
[
int
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
num_heads
assert
self
.
total_num_heads
%
tp_size
==
0
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_kv_heads
=
num_kv_heads
# partition the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_kv_heads
%
tp_size
==
0
self
.
num_kv_heads
=
max
(
1
,
self
.
total_num_kv_heads
//
tp_size
)
self
.
head_dim
=
hidden_size
//
self
.
total_num_heads
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_kv_heads
*
self
.
head_dim
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
rope_theta
=
rope_theta
self
.
max_position_embeddings
=
max_position_embeddings
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
bias
=
bias
,
linear_method
=
linear_method
,
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
bias
=
bias
,
linear_method
=
linear_method
,
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
=
max_position_embeddings
,
base
=
rope_theta
,
rope_scaling
=
rope_scaling
,
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
sliding_window
=
sliding_window
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
attn_output
=
self
.
attn
(
q
,
k
,
v
,
kv_cache
,
attn_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class
XverseDecoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
max_position_embeddings
=
getattr
(
config
,
"max_position_embeddings"
,
8192
)
sliding_window
=
getattr
(
config
,
"sliding_window"
,
None
)
self
.
self_attn
=
XverseAttention
(
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_kv_heads
=
getattr
(
config
,
"num_key_value_heads"
,
config
.
num_attention_heads
),
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
,
max_position_embeddings
=
max_position_embeddings
,
linear_method
=
linear_method
,
bias
=
getattr
(
config
,
"bias"
,
False
),
sliding_window
=
sliding_window
,
)
self
.
mlp
=
XverseMLP
(
hidden_size
=
self
.
hidden_size
,
intermediate_size
=
config
.
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
linear_method
=
linear_method
,
)
self
.
input_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
post_attention_layernorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
residual
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# Self Attention
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
self_attn
(
positions
=
positions
,
hidden_states
=
hidden_states
,
kv_cache
=
kv_cache
,
attn_metadata
=
attn_metadata
,
)
# Fully Connected
hidden_states
,
residual
=
self
.
post_attention_layernorm
(
hidden_states
,
residual
)
hidden_states
=
self
.
mlp
(
hidden_states
)
return
hidden_states
,
residual
class
XverseModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
padding_idx
=
config
.
pad_token_id
lora_vocab
=
(
lora_config
.
lora_extra_vocab_size
*
(
lora_config
.
max_loras
or
1
))
if
lora_config
else
0
self
.
vocab_size
=
config
.
vocab_size
+
lora_vocab
self
.
org_vocab_size
=
config
.
vocab_size
self
.
embed_tokens
=
VocabParallelEmbedding
(
self
.
vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
self
.
layers
=
nn
.
ModuleList
([
XverseDecoderLayer
(
config
,
linear_method
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
attn_metadata
,
residual
,
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
class
XverseForCausalLM
(
nn
.
Module
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
,
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
,
],
}
# LoRA specific attributes
supported_lora_modules
=
[
"qkv_proj"
,
"o_proj"
,
"gate_up_proj"
,
"down_proj"
,
"embed_tokens"
,
"lm_head"
,
]
embedding_modules
=
{
"embed_tokens"
:
"input_embeddings"
,
"lm_head"
:
"output_embeddings"
,
}
embedding_padding_modules
=
[
"lm_head"
]
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
lora_config
=
None
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
model
=
XverseModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
load_format
:
str
=
"auto"
,
revision
:
Optional
[
str
]
=
None
):
stacked_params_mapping
=
[
(
"qkv_proj"
,
"q_proj"
,
"q"
),
(
"qkv_proj"
,
"k_proj"
,
"k"
),
(
"qkv_proj"
,
"v_proj"
,
"v"
),
(
"gate_up_proj"
,
"gate_proj"
,
0
),
(
"gate_up_proj"
,
"up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
for
name
,
loaded_weight
in
hf_model_weights_iterator
(
model_name_or_path
,
cache_dir
,
load_format
,
revision
):
if
(
"rotary_emb.inv_freq"
in
name
or
"rotary_emb.cos_cached"
in
name
or
"rotary_emb.sin_cached"
in
name
):
continue
for
(
param_name
,
weight_name
,
shard_id
)
in
stacked_params_mapping
:
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/neuron_model_loader.py
View file @
7c4f76e3
"""Utilities for selecting and loading models."""
"""Utilities for selecting and loading neuron models."""
from
typing
import
Type
import
importlib
import
os
from
typing
import
Optional
,
Type
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
transformers
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.config
import
ModelConfig
,
DeviceConfig
from
vllm.config
import
ModelConfig
,
ParallelConfig
,
SchedulerConfig
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
TORCH_DTYPE_TO_NEURON_AMP
=
{
TORCH_DTYPE_TO_NEURON_AMP
=
{
"auto"
:
"f32"
,
"auto"
:
"f32"
,
...
@@ -20,30 +26,95 @@ TORCH_DTYPE_TO_NEURON_AMP = {
...
@@ -20,30 +26,95 @@ TORCH_DTYPE_TO_NEURON_AMP = {
torch
.
float32
:
"f32"
,
torch
.
float32
:
"f32"
,
}
}
# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS
=
{
"LlamaForCausalLM"
:
(
"transformers_neuronx.llama.model"
,
"LlamaForSampling"
,
"LlamaForCausalLM"
),
"MistralForCausalLM"
:
(
"transformers_neuronx.mistral.model"
,
"MistralForSampling"
,
"MistralForCausalLM"
)
}
class
NeuronCasualLM
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
model
=
None
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
,
logits_as_input
=
True
)
self
.
sampler
=
Sampler
()
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
input_block_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
logits
=
self
.
model
(
input_ids
,
cache_ids
=
positions
,
start_ids
=
input_block_ids
)
return
logits
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
None
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
self
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
return
next_tokens
def
load_weights
(
self
,
model_name_or_path
:
str
,
**
kwargs
):
arch
=
_get_model_architecture
(
self
.
config
)
neuronx_module_path
,
neuronx_model_cls
,
hf_model_cls
=
(
_NEURON_SUPPORTED_MODELS
[
arch
])
neuronx_module
=
importlib
.
import_module
(
neuronx_module_path
)
neuronx_model_cls
=
getattr
(
neuronx_module
,
neuronx_model_cls
)
split_model_dir
=
f
"
{
model_name_or_path
}
-split"
if
os
.
path
.
isdir
(
os
.
path
.
join
(
model_name_or_path
,
"pytorch_model.bin"
)):
split_model_dir
=
model_name_or_path
elif
not
os
.
path
.
exists
(
f
"
{
model_name_or_path
}
-split"
):
hf_model_cls
=
getattr
(
transformers
,
hf_model_cls
)
from
transformers_neuronx.module
import
save_pretrained_split
hf_model
=
hf_model_cls
.
from_pretrained
(
model_name_or_path
,
low_cpu_mem_usage
=
True
)
save_pretrained_split
(
hf_model
,
f
"
{
model_name_or_path
}
-split"
)
self
.
model
=
neuronx_model_cls
.
from_pretrained
(
split_model_dir
,
**
kwargs
)
self
.
model
.
to_neuron
()
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
Type
[
nn
.
Module
]:
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
Type
[
nn
.
Module
]:
architectures
=
getattr
(
config
,
"architectures"
,
[])
architectures
=
getattr
(
config
,
"architectures"
,
[])
for
arch
in
architectures
:
for
arch
in
architectures
:
model_cls
=
ModelRegistry
.
load_model_cls
(
arch
)
if
arch
in
_NEURON_SUPPORTED_MODELS
:
if
model_cls
is
not
None
:
return
arch
return
model_cls
raise
ValueError
(
raise
ValueError
(
f
"Model architectures
{
architectures
}
are not supported for now. "
f
"Model architectures
{
architectures
}
are not supported on Neuron "
f
"Supported architectures:
{
ModelRegistry
.
get_supported_archs
()
}
"
)
f
"for now. Supported architectures: "
f
"
{
list
(
_NEURON_SUPPORTED_MODELS
.
keys
())
}
"
)
def
get_model
(
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
**
kwargs
)
->
nn
.
Module
:
from
transformers_neuronx.config
import
NeuronConfig
,
ContinuousBatchingConfig
parallel_config
=
kwargs
.
get
(
"parallel_config"
)
scheduler_config
=
kwargs
.
get
(
"scheduler_config"
)
model_class
=
_get_model_architecture
(
model_config
.
hf_config
)
def
get_neuron_model
(
model_config
:
ModelConfig
,
linear_method
=
None
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
)
->
nn
.
Module
:
from
transformers_neuronx.config
import
(
ContinuousBatchingConfig
,
NeuronConfig
)
# Create a model instance.
# Create a model instance.
model
=
model_class
(
model_config
.
hf_config
,
linear_method
)
model
=
NeuronCasualLM
(
model_config
.
hf_config
)
continuous_batching_config
=
ContinuousBatchingConfig
(
continuous_batching_config
=
ContinuousBatchingConfig
(
batch_size_for_shared_caches
=
scheduler_config
.
max_num_seqs
)
batch_size_for_shared_caches
=
scheduler_config
.
max_num_seqs
)
...
@@ -53,10 +124,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
...
@@ -53,10 +124,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
# Load the weights from the cached or downloaded files.
# Load the weights from the cached or downloaded files.
model
.
load_weights
(
model
.
load_weights
(
model_config
.
model
,
model_config
.
model
,
model_config
.
download_dir
,
tp_degree
=
parallel_config
.
tensor_parallel_size
,
model_config
.
load_format
,
model_config
.
revision
,
tp_degree
=
parallel_config
.
neuron_tp_degree
,
amp
=
TORCH_DTYPE_TO_NEURON_AMP
[
model_config
.
dtype
],
amp
=
TORCH_DTYPE_TO_NEURON_AMP
[
model_config
.
dtype
],
neuron_config
=
neuron_config
,
neuron_config
=
neuron_config
,
context_length_estimate
=
[
scheduler_config
.
max_model_len
],
context_length_estimate
=
[
scheduler_config
.
max_model_len
],
...
...
vllm/model_executor/parallel_utils/communication_op.py
View file @
7c4f76e3
...
@@ -4,14 +4,12 @@ from typing import Any, Dict, List, Optional, Union
...
@@ -4,14 +4,12 @@ from typing import Any, Dict, List, Optional, Union
import
torch
import
torch
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
vllm.model_executor.parallel_utils
import
cupy_utils
from
vllm.model_executor.parallel_utils
import
pynccl_utils
from
vllm.model_executor.parallel_utils.custom_all_reduce
import
(
custom_all_reduce
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
is_pynccl_enabled_for_all_reduce
)
get_tensor_model_parallel_group
,
is_cupy_nccl_enabled_for_all_reduce
,
)
from
vllm.model_executor.parallel_utils.custom_all_reduce
import
custom_all_reduce
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -24,7 +22,7 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
...
@@ -24,7 +22,7 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
and GPU topology.
and GPU topology.
TLDR: always assume this function modifies its input, but use the return
TLDR: always assume this function modifies its input, but use the return
value as the output.
value as the output.
"""
"""
# Bypass the function if we are using only 1 GPU.
# Bypass the function if we are using only 1 GPU.
if
get_tensor_model_parallel_world_size
()
==
1
:
if
get_tensor_model_parallel_world_size
()
==
1
:
...
@@ -32,9 +30,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
...
@@ -32,9 +30,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
out
=
custom_all_reduce
(
input_
)
out
=
custom_all_reduce
(
input_
)
if
out
is
not
None
:
if
out
is
not
None
:
return
out
return
out
if
is_
cu
py
_
nccl_enabled_for_all_reduce
():
if
is_pynccl_enabled_for_all_reduce
():
# TODO: support multiple parallel groups.
# TODO: support multiple parallel groups.
cu
py_utils
.
all_reduce
(
input_
)
py
nccl
_utils
.
all_reduce
(
input_
)
else
:
else
:
torch
.
distributed
.
all_reduce
(
input_
,
torch
.
distributed
.
all_reduce
(
input_
,
group
=
get_tensor_model_parallel_group
())
group
=
get_tensor_model_parallel_group
())
...
@@ -176,7 +174,7 @@ def broadcast_tensor_dict(
...
@@ -176,7 +174,7 @@ def broadcast_tensor_dict(
for
key
,
value
in
metadata_list
:
for
key
,
value
in
metadata_list
:
if
isinstance
(
value
,
TensorMetadata
):
if
isinstance
(
value
,
TensorMetadata
):
tensor
=
tensor_dict
[
key
]
tensor
=
tensor_dict
[
key
]
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
)
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
group
=
group
)
else
:
else
:
recv_metadata_list
=
[
None
]
recv_metadata_list
=
[
None
]
torch
.
distributed
.
broadcast_object_list
(
recv_metadata_list
,
torch
.
distributed
.
broadcast_object_list
(
recv_metadata_list
,
...
...
vllm/model_executor/parallel_utils/cupy_utils.py
deleted
100644 → 0
View file @
2da0dd3e
"""CuPy utilities for all-reduce.
We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing
CUDA graphs, because torch.distributed.all_reduce causes errors when capturing
CUDA graphs.
NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8.
TODO: Remove this file when torch.distributed.all_reduce is fixed.
"""
import
contextlib
import
torch
from
torch.distributed
import
ReduceOp
try
:
import
cupy
from
cupy.cuda
import
nccl
from
cupyx.distributed
import
NCCLBackend
except
ImportError
as
e
:
cupy
=
e
nccl
=
None
class
NCCLBackend
:
...
_OP_MAPPING
=
{
ReduceOp
.
SUM
:
"sum"
,
ReduceOp
.
PRODUCT
:
"prod"
,
ReduceOp
.
MIN
:
"min"
,
ReduceOp
.
MAX
:
"max"
,
}
class
NCCLBackendWithBFloat16
(
NCCLBackend
):
# This is enough to add bfloat16 support for most operations,
# but broadcast will fail (will require changes in compiled
# cupy code).
def
_get_nccl_dtype_and_count
(
self
,
array
,
count
=
None
):
nccl_dtype
,
count
=
super
().
_get_nccl_dtype_and_count
(
array
,
count
)
torch_dtype
=
getattr
(
array
,
"_torch_dtype"
,
None
)
if
torch_dtype
is
torch
.
bfloat16
:
nccl_dtype
=
nccl
.
NCCL_BFLOAT16
return
nccl_dtype
,
count
def
barrier
(
self
)
->
None
:
raise
RuntimeError
(
"Currently, CuPy NCCL barrier is not supported since the TCP "
"store is immediately stopped after the initialization."
)
_NCCL_BACKEND
=
None
_WORLD_SIZE
=
0
def
is_initialized
()
->
bool
:
"""Returns whether the NCCL backend is initialized."""
return
_NCCL_BACKEND
is
not
None
@
contextlib
.
contextmanager
def
set_cupy_stream
(
stream
:
torch
.
cuda
.
Stream
):
"""Set the cuda stream for communication"""
cupy_stream
=
cupy
.
cuda
.
ExternalStream
(
stream
.
cuda_stream
,
stream
.
device_index
)
with
cupy_stream
:
yield
def
init_process_group
(
world_size
:
int
,
rank
:
int
,
host
:
str
,
port
:
int
)
->
None
:
"""Initializes the CuPy NCCL backend.
# TODO: handle NCCL timeouts.
"""
assert
not
is_initialized
()
if
isinstance
(
cupy
,
Exception
):
raise
ImportError
(
"NCCLBackend is not available. Please install cupy."
)
from
cupy
# TODO(woosuk): Create TP and PP process groups for CuPy.
global
_NCCL_BACKEND
global
_WORLD_SIZE
assert
world_size
>
0
,
f
"
{
world_size
=
}
should be a positive integer"
assert
0
<=
rank
<
world_size
,
(
f
"
{
rank
=
}
should be a integer between [0,
{
world_size
}
)"
)
cupy
.
cuda
.
runtime
.
setDevice
(
torch
.
cuda
.
current_device
())
_NCCL_BACKEND
=
NCCLBackendWithBFloat16
(
world_size
,
rank
,
host
,
port
)
_WORLD_SIZE
=
world_size
# Stop the TCP store to prevent the deadlock issues at termination time.
# FIXME(woosuk): This is hacky. Find a more robust solution.
if
rank
==
0
and
hasattr
(
_NCCL_BACKEND
,
"_store"
):
_NCCL_BACKEND
.
_store
.
stop
()
def
all_reduce
(
input_
:
torch
.
Tensor
,
op
=
ReduceOp
.
SUM
)
->
None
:
"""All-reduces the input tensor across the process group."""
assert
input_
.
is_cuda
,
f
"
{
input_
}
should be a cuda tensor"
# Hack to support bfloat16
torch_dtype
=
input_
.
dtype
if
torch_dtype
is
torch
.
bfloat16
:
# We need to view as float16, otherwise
# cupy will fail. This will not change
# the underlying data.
input_
=
input_
.
view
(
torch
.
float16
)
cupy_input
=
cupy
.
asarray
(
input_
)
cupy_input
.
_torch_dtype
=
torch_dtype
# pylint: disable=protected-access
_NCCL_BACKEND
.
all_reduce
(
in_array
=
cupy_input
,
out_array
=
cupy_input
,
op
=
_OP_MAPPING
[
op
])
def
destroy_process_group
()
->
None
:
"""Destroys the NCCL backend."""
global
_NCCL_BACKEND
global
_WORLD_SIZE
_NCCL_BACKEND
=
None
_WORLD_SIZE
=
0
def
get_world_size
()
->
int
:
"""Returns the world size."""
return
_WORLD_SIZE
def
get_nccl_backend
():
return
_NCCL_BACKEND
vllm/model_executor/parallel_utils/custom_all_reduce.py
View file @
7c4f76e3
...
@@ -6,11 +6,12 @@ import torch.distributed as dist
...
@@ -6,11 +6,12 @@ import torch.distributed as dist
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_
world_size
,
get_tensor_model_parallel_
rank
)
get_tensor_model_parallel_
rank
,
get_tensor_model_parallel_
world_size
)
try
:
try
:
from
vllm._C
import
custom_ar
import
pynvml
import
pynvml
from
vllm._C
import
custom_ar
except
ImportError
:
except
ImportError
:
# For AMD GPUs
# For AMD GPUs
custom_ar
=
None
custom_ar
=
None
...
@@ -37,16 +38,23 @@ def init_custom_ar() -> None:
...
@@ -37,16 +38,23 @@ def init_custom_ar() -> None:
logger
.
warn
(
logger
.
warn
(
"Custom allreduce is disabled due to an unsupported world size: "
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify"
"%d. Supported world sizes: %s. To silence this warning, specify"
"disable_custom_all_reduce=True explicitly."
,
world_size
,
"
disable_custom_all_reduce=True explicitly."
,
world_size
,
str
(
_SUPPORTED_WORLD_SIZES
))
str
(
_SUPPORTED_WORLD_SIZES
))
return
return
if
not
_can_p2p
(
rank
,
world_size
):
if
not
_can_p2p
(
rank
,
world_size
):
logger
.
warn
(
logger
.
warn
(
"Custom allreduce is disabled because your platform lacks GPU P2P"
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability. To silence this warning, specify"
" capability or P2P test failed. To silence this warning, specify"
"disable_custom_all_reduce=True explicitly."
)
" disable_custom_all_reduce=True explicitly."
)
return
full_nvlink
=
_is_full_nvlink
(
rank
,
world_size
)
if
world_size
>
2
and
not
full_nvlink
:
logger
.
warn
(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly."
)
return
return
_CA_HANDLE
=
CustomAllreduce
(
rank
,
world_size
)
_CA_HANDLE
=
CustomAllreduce
(
rank
,
world_size
,
full_nvlink
)
def
begin_capture
()
->
None
:
def
begin_capture
()
->
None
:
...
@@ -134,18 +142,48 @@ def _is_full_nvlink(rank, world_size):
...
@@ -134,18 +142,48 @@ def _is_full_nvlink(rank, world_size):
def
_can_p2p
(
rank
:
int
,
world_size
:
int
)
->
bool
:
def
_can_p2p
(
rank
:
int
,
world_size
:
int
)
->
bool
:
num_dev
=
torch
.
cuda
.
device_count
()
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if
num_dev
<
world_size
:
logger
.
warn
(
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set."
)
return
False
for
i
in
range
(
world_size
):
for
i
in
range
(
world_size
):
if
i
==
rank
:
if
i
==
rank
:
continue
continue
if
not
torch
.
cuda
.
can_device_access_peer
(
rank
,
i
):
if
not
torch
.
cuda
.
can_device_access_peer
(
rank
,
i
):
return
False
return
False
# on some platforms, P2P support might be buggy and we need
# additional checks. See also:
# https://github.com/vllm-project/vllm/issues/2728
if
not
_can_actually_p2p
(
rank
,
i
):
return
False
return
True
return
True
# code partly borrowed from
# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
# License: MIT
def
_can_actually_p2p
(
idx_a
,
idx_b
):
dev_i
=
f
"cuda:
{
idx_a
}
"
dev_j
=
f
"cuda:
{
idx_b
}
"
a
=
torch
.
randn
(
5
,
device
=
dev_i
)
+
123.0
b
=
a
.
to
(
dev_j
)
c
=
b
.
to
(
dev_i
)
return
torch
.
all
(
a
==
c
)
class
CustomAllreduce
:
class
CustomAllreduce
:
# max_size: max supported allreduce size
# max_size: max supported allreduce size
def
__init__
(
self
,
rank
,
world_size
,
max_size
=
8192
*
1024
)
->
None
:
def
__init__
(
self
,
rank
,
world_size
,
full_nvlink
,
max_size
=
8192
*
1024
)
->
None
:
# buffers memory are owned by this Python class and passed to C++
# buffers memory are owned by this Python class and passed to C++
# meta data composes of two parts: meta data for synchronization
# meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate
# (256 bytes) and a temporary buffer for storing intermediate
...
@@ -167,11 +205,10 @@ class CustomAllreduce:
...
@@ -167,11 +205,10 @@ class CustomAllreduce:
self
.
max_size
=
max_size
self
.
max_size
=
max_size
self
.
world_size
=
world_size
self
.
world_size
=
world_size
handles
,
offsets
=
self
.
_get_ipc_meta
(
self
.
meta
)
handles
,
offsets
=
self
.
_get_ipc_meta
(
self
.
meta
)
self
.
full_nvlink
=
_is_
full_nvlink
(
rank
,
world_size
)
self
.
full_nvlink
=
full_nvlink
self
.
_ptr
=
custom_ar
.
init_custom_ar
(
self
.
meta
,
self
.
rank_data
,
self
.
_ptr
=
custom_ar
.
init_custom_ar
(
self
.
meta
,
self
.
rank_data
,
handles
,
offsets
,
rank
,
handles
,
offsets
,
rank
,
self
.
full_nvlink
)
self
.
full_nvlink
)
self
.
fast_cond
=
self
.
full_nvlink
or
world_size
<=
2
self
.
register_buffer
(
self
.
buffer
)
self
.
register_buffer
(
self
.
buffer
)
def
_get_ipc_meta
(
self
,
inp
:
torch
.
Tensor
):
def
_get_ipc_meta
(
self
,
inp
:
torch
.
Tensor
):
...
...
vllm/model_executor/parallel_utils/parallel_state.py
View file @
7c4f76e3
...
@@ -7,7 +7,7 @@ import contextlib
...
@@ -7,7 +7,7 @@ import contextlib
import
torch
import
torch
from
vllm.model_executor.parallel_utils
import
cu
py_utils
from
vllm.model_executor.parallel_utils
import
py
nccl
_utils
# Tensor model parallel group that the current rank belongs to.
# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP
=
None
_TENSOR_MODEL_PARALLEL_GROUP
=
None
...
@@ -210,36 +210,36 @@ def destroy_model_parallel():
...
@@ -210,36 +210,36 @@ def destroy_model_parallel():
global
_PIPELINE_GLOBAL_RANKS
global
_PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS
=
None
_PIPELINE_GLOBAL_RANKS
=
None
# Destroy the
cu
py states if any.
# Destroy the py
nccl
states if any.
cu
py_utils
.
destroy_process_group
()
py
nccl
_utils
.
destroy_process_group
()
# Whether to use
cu
py for nccl all reduce.
# Whether to use py
nccl
for nccl all reduce.
# We use
cu
py for all reduce when using CUDA graph, because torch.distributed
# We use py
nccl
for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph.
# is not well supported by CUDA graph.
_ENABLE_
CU
PY_FOR_ALL_REDUCE
=
False
_ENABLE_PY
NCCL
_FOR_ALL_REDUCE
=
False
@
contextlib
.
contextmanager
@
contextlib
.
contextmanager
def
with_
cu
py
_
nccl_for_all_reduce
():
def
with_pynccl_for_all_reduce
():
"""use
CuPy
nccl instead of torch.distributed for all reduce"""
"""use
py
nccl instead of torch.distributed for all reduce"""
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
if
tp_size
==
1
:
if
tp_size
==
1
:
# No-op.
# No-op.
# NOTE(woosuk): We don't initialize
CuPy
when tp_size is 1.
# NOTE(woosuk): We don't initialize
pynccl
when tp_size is 1.
yield
yield
else
:
else
:
global
_ENABLE_
CU
PY_FOR_ALL_REDUCE
global
_ENABLE_PY
NCCL
_FOR_ALL_REDUCE
old
=
_ENABLE_
CU
PY_FOR_ALL_REDUCE
old
=
_ENABLE_PY
NCCL
_FOR_ALL_REDUCE
_ENABLE_
CU
PY_FOR_ALL_REDUCE
=
True
_ENABLE_PY
NCCL
_FOR_ALL_REDUCE
=
True
stream
=
torch
.
cuda
.
current_stream
()
stream
=
torch
.
cuda
.
current_stream
()
with
cu
py_utils
.
set_
cu
py_stream
(
stream
):
with
py
nccl
_utils
.
set_py
nccl
_stream
(
stream
):
yield
yield
_ENABLE_
CU
PY_FOR_ALL_REDUCE
=
old
_ENABLE_PY
NCCL
_FOR_ALL_REDUCE
=
old
def
is_
cu
py
_
nccl_enabled_for_all_reduce
():
def
is_pynccl_enabled_for_all_reduce
():
"""check if
CuPy
nccl is enabled for all reduce"""
"""check if
py
nccl is enabled for all reduce"""
global
_ENABLE_
CU
PY_FOR_ALL_REDUCE
global
_ENABLE_PY
NCCL
_FOR_ALL_REDUCE
return
_ENABLE_
CU
PY_FOR_ALL_REDUCE
return
_ENABLE_PY
NCCL
_FOR_ALL_REDUCE
vllm/model_executor/parallel_utils/pynccl.py
0 → 100644
View file @
7c4f76e3
# This file is a pure Python wrapper for the NCCL library.
# The main purpose is to use NCCL combined with CUDA graph.
# Before writing this script, we tried the following approach:
# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself
# often gets stuck when initializing the NCCL communicator.
# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce`
# contains many other potential cuda APIs, that are not allowed during
# capturing the CUDA graph. For further details, please check
# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ .
#
# Another rejected idea is to write a C/C++ binding for NCCL. It is usually
# doable, but we often encounter issues related with nccl versions, and need
# to switch between different versions of NCCL. See
# https://github.com/NVIDIA/nccl/issues/1234 for more details.
# A C/C++ binding is not flexible enough to handle this. It requires
# recompilation of the code every time we want to switch between different
# versions. This current implementation, with a **pure** Python wrapper, is
# more flexible. We can easily switch between different versions of NCCL by
# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file`
# variable in the code.
import
ctypes
import
datetime
import
os
# ===================== import region =====================
import
torch
import
torch.distributed
as
dist
from
torch.distributed
import
ReduceOp
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
so_file
=
os
.
environ
.
get
(
"VLLM_NCCL_SO_PATH"
,
""
)
# manually load the nccl library
if
so_file
:
logger
.
info
(
f
"Loading nccl from environment variable VLLM_NCCL_SO_PATH=
{
so_file
}
"
)
else
:
if
torch
.
version
.
cuda
is
not
None
:
so_file
=
"libnccl.so.2"
elif
torch
.
version
.
hip
is
not
None
:
so_file
=
"librccl.so.1"
else
:
raise
ValueError
(
"NCCL only supports CUDA and ROCm backends."
)
logger
.
debug
(
f
"Loading nccl from library
{
so_file
}
"
)
try
:
nccl
=
ctypes
.
CDLL
(
so_file
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to load NCCL library from
{
so_file
}
."
"It is expected if you are not running on NVIDIA/AMD GPUs."
"Otherwise please set the environment variable VLLM_NCCL_SO_PATH"
" to point to the correct nccl library path."
)
raise
e
# === export types and functions from nccl to Python ===
# for the original nccl definition, please check
# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in
ncclResult_t
=
ctypes
.
c_int
# equivalent to c declaration:
# ncclResult_t ncclGetVersion(int *version);
_c_ncclGetVersion
=
nccl
.
ncclGetVersion
_c_ncclGetVersion
.
restype
=
ctypes
.
c_int
_c_ncclGetVersion
.
argtypes
=
[
ctypes
.
POINTER
(
ctypes
.
c_int
)]
def
ncclGetVersion
()
->
str
:
version
=
ctypes
.
c_int
()
result
=
_c_ncclGetVersion
(
ctypes
.
byref
(
version
))
assert
result
==
0
# something like 21903 --> "2.19.3"
version_str
=
str
(
version
.
value
)
major
=
version_str
[
0
].
lstrip
(
"0"
)
minor
=
version_str
[
1
:
3
].
lstrip
(
"0"
)
patch
=
version_str
[
3
:].
lstrip
(
"0"
)
return
f
"
{
major
}
.
{
minor
}
.
{
patch
}
"
class
NcclUniqueId
(
ctypes
.
Structure
):
_fields_
=
[(
"internal"
,
ctypes
.
c_byte
*
128
)]
# equivalent to c declaration:
# ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId);
_c_ncclGetUniqueId
=
nccl
.
ncclGetUniqueId
_c_ncclGetUniqueId
.
restype
=
ctypes
.
c_int
_c_ncclGetUniqueId
.
argtypes
=
[
ctypes
.
POINTER
(
NcclUniqueId
)]
def
ncclGetUniqueId
()
->
NcclUniqueId
:
unique_id
=
NcclUniqueId
()
result
=
_c_ncclGetUniqueId
(
ctypes
.
byref
(
unique_id
))
assert
result
==
0
return
unique_id
# equivalent to c declaration:
# ncclResult_t ncclCommInitRank(
# ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
# note that ncclComm_t is a pointer type, so the first argument
# is a pointer to a pointer
_c_ncclCommInitRank
=
nccl
.
ncclCommInitRank
_c_ncclCommInitRank
.
restype
=
ctypes
.
c_int
_c_ncclCommInitRank
.
argtypes
=
[
ctypes
.
POINTER
(
ctypes
.
c_void_p
),
ctypes
.
c_int
,
NcclUniqueId
,
ctypes
.
c_int
]
# enums
class
ncclDataType_t
(
ctypes
.
c_int
):
ncclInt8
=
0
ncclChar
=
0
ncclUint8
=
1
ncclInt32
=
2
ncclInt
=
2
ncclUint32
=
3
ncclInt64
=
4
ncclUint64
=
5
ncclFloat16
=
6
ncclHalf
=
6
ncclFloat32
=
7
ncclFloat
=
7
ncclFloat64
=
8
ncclDouble
=
8
ncclBfloat16
=
9
ncclNumTypes
=
10
@
classmethod
def
from_torch
(
cls
,
dtype
:
torch
.
dtype
)
->
'ncclDataType_t'
:
if
dtype
==
torch
.
int8
:
return
cls
.
ncclInt8
if
dtype
==
torch
.
uint8
:
return
cls
.
ncclUint8
if
dtype
==
torch
.
int32
:
return
cls
.
ncclInt32
if
dtype
==
torch
.
int64
:
return
cls
.
ncclInt64
if
dtype
==
torch
.
float16
:
return
cls
.
ncclFloat16
if
dtype
==
torch
.
float32
:
return
cls
.
ncclFloat32
if
dtype
==
torch
.
float64
:
return
cls
.
ncclFloat64
if
dtype
==
torch
.
bfloat16
:
return
cls
.
ncclBfloat16
raise
ValueError
(
f
"Unsupported dtype:
{
dtype
}
"
)
class
ncclRedOp_t
(
ctypes
.
c_int
):
ncclSum
=
0
ncclProd
=
1
ncclMax
=
2
ncclMin
=
3
ncclAvg
=
4
ncclNumOps
=
5
@
classmethod
def
from_torch
(
cls
,
op
:
ReduceOp
)
->
'ncclRedOp_t'
:
if
op
==
ReduceOp
.
SUM
:
return
cls
.
ncclSum
if
op
==
ReduceOp
.
PRODUCT
:
return
cls
.
ncclProd
if
op
==
ReduceOp
.
MAX
:
return
cls
.
ncclMax
if
op
==
ReduceOp
.
MIN
:
return
cls
.
ncclMin
if
op
==
ReduceOp
.
AVG
:
return
cls
.
ncclAvg
raise
ValueError
(
f
"Unsupported op:
{
op
}
"
)
# equivalent to c declaration:
# ncclResult_t ncclAllReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# udaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument is a pointer
_c_ncclAllReduce
=
nccl
.
ncclAllReduce
_c_ncclAllReduce
.
restype
=
ctypes
.
c_int
_c_ncclAllReduce
.
argtypes
=
[
ctypes
.
c_void_p
,
ctypes
.
c_void_p
,
ctypes
.
c_size_t
,
ncclDataType_t
,
ncclRedOp_t
,
ctypes
.
c_void_p
,
ctypes
.
c_void_p
]
# equivalent to c declaration:
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
_c_ncclCommDestroy
=
nccl
.
ncclCommDestroy
_c_ncclCommDestroy
.
restype
=
ctypes
.
c_int
_c_ncclCommDestroy
.
argtypes
=
[
ctypes
.
c_void_p
]
class
NCCLCommunicator
:
def
__init__
(
self
,
backend
=
None
,
init_method
=
None
,
timeout
=
datetime
.
timedelta
(
seconds
=
10
),
world_size
:
int
=
-
1
,
rank
:
int
=
-
1
,
store
=
None
,
group_name
:
str
=
""
,
pg_options
=
None
,
local_rank
:
int
=
-
1
,
):
if
not
dist
.
is_initialized
():
backend
=
backend
or
"nccl"
assert
backend
==
'nccl'
,
(
"only use nccl backend for starting the NCCL communicator"
)
dist
.
init_process_group
(
backend
=
backend
,
init_method
=
init_method
,
timeout
=
timeout
,
world_size
=
world_size
,
rank
=
rank
,
store
=
store
,
group_name
=
group_name
,
pg_options
=
pg_options
)
self
.
rank
=
dist
.
get_rank
()
self
.
world_size
=
dist
.
get_world_size
()
if
local_rank
==
-
1
:
local_rank
=
self
.
rank
self
.
local_rank
=
local_rank
torch
.
cuda
.
set_device
(
local_rank
)
if
rank
==
0
:
self
.
unique_id
=
ncclGetUniqueId
()
else
:
self
.
unique_id
=
NcclUniqueId
()
tensor
=
torch
.
ByteTensor
(
list
(
self
.
unique_id
.
internal
)).
cuda
(
local_rank
)
dist
.
broadcast
(
tensor
,
src
=
0
)
byte_list
=
tensor
.
cpu
().
tolist
()
for
i
,
byte
in
enumerate
(
byte_list
):
self
.
unique_id
.
internal
[
i
]
=
byte
self
.
comm
=
ctypes
.
c_void_p
()
result
=
_c_ncclCommInitRank
(
ctypes
.
byref
(
self
.
comm
),
world_size
,
self
.
unique_id
,
rank
)
assert
result
==
0
self
.
stream
=
torch
.
cuda
.
Stream
(
device
=
f
"cuda:
{
local_rank
}
"
)
def
all_reduce
(
self
,
tensor
:
torch
.
Tensor
,
op
:
ReduceOp
=
ReduceOp
.
SUM
,
stream
=
None
):
if
stream
is
None
:
stream
=
self
.
stream
result
=
_c_ncclAllReduce
(
ctypes
.
c_void_p
(
tensor
.
data_ptr
()),
ctypes
.
c_void_p
(
tensor
.
data_ptr
()),
tensor
.
numel
(),
ncclDataType_t
.
from_torch
(
tensor
.
dtype
),
ncclRedOp_t
.
from_torch
(
op
),
self
.
comm
,
ctypes
.
c_void_p
(
stream
.
cuda_stream
))
assert
result
==
0
def
__del__
(
self
):
# `dist` module might have been already destroyed
if
hasattr
(
dist
,
'destroy_process_group'
):
dist
.
destroy_process_group
()
_c_ncclCommDestroy
(
self
.
comm
)
vllm/model_executor/parallel_utils/pynccl_utils.py
0 → 100644
View file @
7c4f76e3
import
contextlib
from
typing
import
Optional
import
torch
from
torch.distributed
import
ReduceOp
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
try
:
from
vllm.model_executor.parallel_utils.pynccl
import
(
NCCLCommunicator
,
ncclGetVersion
)
except
Exception
as
e
:
# in non-NVIDIA environments, we can't import the nccl module
# e.g. when running on machines with AMD GPUs
logger
.
info
(
f
"Failed to import NCCL library:
{
e
}
"
)
logger
.
info
(
"It is expected if you are not running on NVIDIA GPUs."
)
pass
comm
:
Optional
[
"NCCLCommunicator"
]
=
None
def
is_initialized
()
->
bool
:
"""Returns whether the NCCL backend is initialized."""
return
comm
is
not
None
@
contextlib
.
contextmanager
def
set_pynccl_stream
(
stream
:
torch
.
cuda
.
Stream
):
"""Set the cuda stream for communication"""
try
:
comm
.
stream
=
stream
yield
finally
:
pass
def
init_process_group
(
world_size
:
int
,
rank
:
int
,
init_method
:
str
,
local_rank
:
int
=
-
1
)
->
None
:
assert
not
is_initialized
()
global
comm
logger
.
info
(
f
"vLLM is using nccl==
{
ncclGetVersion
()
}
"
)
comm
=
NCCLCommunicator
(
init_method
=
init_method
,
world_size
=
world_size
,
local_rank
=
local_rank
,
rank
=
rank
)
def
all_reduce
(
input_
:
torch
.
Tensor
,
op
=
ReduceOp
.
SUM
)
->
None
:
"""All-reduces the input tensor across the process group."""
assert
input_
.
is_cuda
,
f
"
{
input_
}
should be a cuda tensor"
comm
.
all_reduce
(
input_
,
op
)
def
destroy_process_group
()
->
None
:
global
comm
comm
=
None
def
get_world_size
()
->
int
:
"""Returns the world size."""
return
comm
.
world_size
def
get_nccl_backend
():
return
comm
vllm/model_executor/sampling_metadata.py
View file @
7c4f76e3
import
random
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
vllm.model_executor.layers.ops.sample
import
get_num_triton_sampler_splits
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
from
vllm.utils
import
i
n_wsl
,
is_neuron
from
vllm.utils
import
i
s_pin_memory_available
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
_SEED_0_REPLACEMENT
=
3403598558
class
SamplingMetadata
:
class
SamplingMetadata
:
...
@@ -67,14 +70,28 @@ class SamplingTensors:
...
@@ -67,14 +70,28 @@ class SamplingTensors:
presence_penalties
:
torch
.
Tensor
presence_penalties
:
torch
.
Tensor
frequency_penalties
:
torch
.
Tensor
frequency_penalties
:
torch
.
Tensor
repetition_penalties
:
torch
.
Tensor
repetition_penalties
:
torch
.
Tensor
sampling_seeds
:
torch
.
Tensor
sample_indices
:
torch
.
Tensor
extra_seeds
:
Optional
[
torch
.
Tensor
]
prompt_tokens
:
torch
.
Tensor
prompt_tokens
:
torch
.
Tensor
output_tokens
:
torch
.
Tensor
output_tokens
:
torch
.
Tensor
@
classmethod
@
classmethod
def
from_sampling_metadata
(
def
from_sampling_metadata
(
cls
,
sampling_metadata
:
"SamplingMetadata"
,
vocab_size
:
int
,
cls
,
device
:
torch
.
device
,
sampling_metadata
:
"SamplingMetadata"
,
dtype
:
torch
.
dtype
)
->
Tuple
[
"SamplingTensors"
,
bool
,
bool
,
bool
]:
vocab_size
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
,
*
,
extra_seeds_to_generate
:
int
=
0
,
extra_entropy
:
Optional
[
Tuple
[
int
,
...]]
=
None
)
->
Tuple
[
"SamplingTensors"
,
bool
,
bool
,
bool
]:
"""
extra_seeds_to_generate: extra seeds to generate using the
user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds.
"""
prompt_tokens
:
List
[
List
[
int
]]
=
[]
prompt_tokens
:
List
[
List
[
int
]]
=
[]
output_tokens
:
List
[
List
[
int
]]
=
[]
output_tokens
:
List
[
List
[
int
]]
=
[]
top_ks
:
List
[
int
]
=
[]
top_ks
:
List
[
int
]
=
[]
...
@@ -84,9 +101,18 @@ class SamplingTensors:
...
@@ -84,9 +101,18 @@ class SamplingTensors:
presence_penalties
:
List
[
float
]
=
[]
presence_penalties
:
List
[
float
]
=
[]
frequency_penalties
:
List
[
float
]
=
[]
frequency_penalties
:
List
[
float
]
=
[]
repetition_penalties
:
List
[
float
]
=
[]
repetition_penalties
:
List
[
float
]
=
[]
sampling_seeds
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
prompt_best_of
:
List
[
int
]
=
[]
do_penalties
=
False
do_penalties
=
False
do_top_p_top_k
=
False
do_top_p_top_k
=
False
do_min_p
=
False
do_min_p
=
False
# We need one base seed per Triton slice.
seeds_to_generate
=
(
extra_seeds_to_generate
+
get_num_triton_sampler_splits
(
vocab_size
))
sample_indices_start_idx
=
0
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
seq_ids
,
sampling_params
=
seq_group
seq_ids
,
sampling_params
=
seq_group
temperature
=
sampling_params
.
temperature
temperature
=
sampling_params
.
temperature
...
@@ -95,6 +121,10 @@ class SamplingTensors:
...
@@ -95,6 +121,10 @@ class SamplingTensors:
r
=
sampling_params
.
repetition_penalty
r
=
sampling_params
.
repetition_penalty
top_p
=
sampling_params
.
top_p
top_p
=
sampling_params
.
top_p
min_p
=
sampling_params
.
min_p
min_p
=
sampling_params
.
min_p
seed
=
sampling_params
.
seed
is_greedy
=
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
# k should not be greater than the vocab size.
# k should not be greater than the vocab size.
top_k
=
min
(
sampling_params
.
top_k
,
vocab_size
)
top_k
=
min
(
sampling_params
.
top_k
,
vocab_size
)
top_k
=
vocab_size
if
top_k
==
-
1
else
top_k
top_k
=
vocab_size
if
top_k
==
-
1
else
top_k
...
@@ -112,9 +142,11 @@ class SamplingTensors:
...
@@ -112,9 +142,11 @@ class SamplingTensors:
or
abs
(
f
)
>=
_SAMPLING_EPS
or
abs
(
f
)
>=
_SAMPLING_EPS
or
abs
(
r
-
1.0
)
>=
_SAMPLING_EPS
):
or
abs
(
r
-
1.0
)
>=
_SAMPLING_EPS
):
do_penalties
=
True
do_penalties
=
True
if
(
i
<
sampling_metadata
.
num_prompts
if
(
i
<
sampling_metadata
.
num_prompts
and
sampling_params
.
prompt_logprobs
is
not
None
):
and
sampling_params
.
prompt_logprobs
is
not
None
):
# For tokens in the prompt that we only need to get their logprobs
# For tokens in the prompt that we only need to get
# their logprobs
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
temperatures
+=
[
temperature
]
*
(
prompt_len
-
1
)
temperatures
+=
[
temperature
]
*
(
prompt_len
-
1
)
top_ps
+=
[
top_p
]
*
(
prompt_len
-
1
)
top_ps
+=
[
top_p
]
*
(
prompt_len
-
1
)
...
@@ -137,10 +169,34 @@ class SamplingTensors:
...
@@ -137,10 +169,34 @@ class SamplingTensors:
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
is_prompt
=
i
<
sampling_metadata
.
num_prompts
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
prompt_len
=
sampling_metadata
.
prompt_lens
[
i
]
if
sampling_params
.
prompt_logprobs
is
not
None
:
# NOTE: the sampling position is the last token
# in the prompt
sample_indices_start_idx
+=
prompt_len
-
1
for
seq_id
in
seq_ids
:
seq_data
=
sampling_metadata
.
seq_data
[
seq_id
]
extra_entropy
=
extra_entropy
or
()
seq_seeds
=
cls
.
_get_sequence_seeds
(
seed
,
seq_data
.
get_len
(),
*
extra_entropy
,
seq_id
,
seeds_to_generate
=
seeds_to_generate
,
is_greedy
=
is_greedy
)
sampling_seeds
.
append
(
seq_seeds
)
sample_indices
.
append
(
sample_indices_start_idx
)
sample_indices_start_idx
+=
1
sampling_tensors
=
SamplingTensors
.
from_lists
(
sampling_tensors
=
SamplingTensors
.
from_lists
(
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
frequency_penalties
,
repetition_penalties
,
prompt_tokens
,
frequency_penalties
,
repetition_penalties
,
sampling_seeds
,
output_tokens
,
vocab_size
,
device
,
dtype
)
sample_indices
,
prompt_tokens
,
output_tokens
,
vocab_size
,
extra_seeds_to_generate
,
device
,
dtype
)
return
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
do_min_p
)
return
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
do_min_p
)
@
classmethod
@
classmethod
...
@@ -149,13 +205,14 @@ class SamplingTensors:
...
@@ -149,13 +205,14 @@ class SamplingTensors:
presence_penalties
:
List
[
float
],
presence_penalties
:
List
[
float
],
frequency_penalties
:
List
[
float
],
frequency_penalties
:
List
[
float
],
repetition_penalties
:
List
[
float
],
repetition_penalties
:
List
[
float
],
sampling_seeds
:
List
[
int
],
sample_indices
:
List
[
int
],
prompt_tokens
:
List
[
List
[
int
]],
prompt_tokens
:
List
[
List
[
int
]],
output_tokens
:
List
[
List
[
int
]],
vocab_size
:
int
,
output_tokens
:
List
[
List
[
int
]],
vocab_size
:
int
,
device
:
torch
.
device
,
extra_seeds_to_generate
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
)
->
"SamplingTensors"
:
dtype
:
torch
.
dtype
)
->
"SamplingTensors"
:
# Note that the performance will be very bad without
# Note that the performance will be very bad without
# pinned memory.
# pinned memory.
pin_memory
=
not
in_wsl
()
and
not
is_neuron
()
pin_memory
=
is_pin_memory_available
()
prompt_max_len
=
max
(
len
(
tokens
)
for
tokens
in
prompt_tokens
)
prompt_max_len
=
max
(
len
(
tokens
)
for
tokens
in
prompt_tokens
)
prompt_padded_tokens
=
[
prompt_padded_tokens
=
[
tokens
+
[
vocab_size
]
*
(
prompt_max_len
-
len
(
tokens
))
tokens
+
[
vocab_size
]
*
(
prompt_max_len
-
len
(
tokens
))
...
@@ -209,6 +266,12 @@ class SamplingTensors:
...
@@ -209,6 +266,12 @@ class SamplingTensors:
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
)
)
sample_indices_t
=
torch
.
tensor
(
sample_indices
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
)
prompt_tensor
=
torch
.
tensor
(
prompt_tensor
=
torch
.
tensor
(
prompt_padded_tokens
,
prompt_padded_tokens
,
device
=
"cpu"
,
device
=
"cpu"
,
...
@@ -221,8 +284,28 @@ class SamplingTensors:
...
@@ -221,8 +284,28 @@ class SamplingTensors:
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
)
)
# need to transpose and make contiguous to
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
sampling_seeds_t
=
torch
.
tensor
(
sampling_seeds
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
).
T
.
contiguous
()
# Because the memory is pinned, we can do non-blocking
# Because the memory is pinned, we can do non-blocking
# transfer to device.
# transfer to device.
# How many seeds the sample operation itself will need.
num_base_seeds
=
sampling_seeds_t
.
shape
[
0
]
-
extra_seeds_to_generate
sampling_seeds_gpu
=
sampling_seeds_t
.
to
(
device
=
device
,
non_blocking
=
True
)
extra_seeds_gpu
=
sampling_seeds_gpu
[
num_base_seeds
:]
if
not
extra_seeds_gpu
.
numel
():
extra_seeds_gpu
=
None
sampling_seeds_gpu
=
sampling_seeds_gpu
[:
num_base_seeds
]
return
cls
(
return
cls
(
temperatures
=
temperatures_t
.
to
(
device
=
device
,
non_blocking
=
True
),
temperatures
=
temperatures_t
.
to
(
device
=
device
,
non_blocking
=
True
),
top_ps
=
top_ps_t
.
to
(
device
=
device
,
non_blocking
=
True
),
top_ps
=
top_ps_t
.
to
(
device
=
device
,
non_blocking
=
True
),
...
@@ -236,4 +319,38 @@ class SamplingTensors:
...
@@ -236,4 +319,38 @@ class SamplingTensors:
non_blocking
=
True
),
non_blocking
=
True
),
prompt_tokens
=
prompt_tensor
.
to
(
device
=
device
,
non_blocking
=
True
),
prompt_tokens
=
prompt_tensor
.
to
(
device
=
device
,
non_blocking
=
True
),
output_tokens
=
output_tensor
.
to
(
device
=
device
,
non_blocking
=
True
),
output_tokens
=
output_tensor
.
to
(
device
=
device
,
non_blocking
=
True
),
sampling_seeds
=
sampling_seeds_gpu
,
sample_indices
=
sample_indices_t
.
to
(
device
=
device
,
non_blocking
=
True
),
extra_seeds
=
extra_seeds_gpu
,
)
)
@
staticmethod
def
_get_sequence_seeds
(
seed
:
int
,
*
extra_entropy
:
int
,
seeds_to_generate
:
int
,
is_greedy
:
bool
,
):
"""Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
if
not
is_greedy
:
if
seed
is
None
:
randint_fn
=
random
.
randint
else
:
generator
=
random
.
Random
(
str
((
seed
,
)
+
extra_entropy
))
randint_fn
=
generator
.
randint
lo
,
hi
=
torch
.
iinfo
(
torch
.
long
).
min
,
torch
.
iinfo
(
torch
.
long
).
max
# If the user/random sets seed = 0 but request should
# have sampling, we need to change it to something
# else. We use a constant in that case.
# This way we don't need to create and load a bool
# matrix in the sampling kernel, which reduces CPU
# overhead and latency.
seq_seeds
=
[
randint_fn
(
lo
,
hi
)
or
_SEED_0_REPLACEMENT
for
_
in
range
(
seeds_to_generate
)
]
else
:
# For the kernel, seed == 0 means greedy decoding.
seq_seeds
=
[
0
]
*
seeds_to_generate
return
seq_seeds
vllm/model_executor/utils.py
View file @
7c4f76e3
"""Utils for model executor."""
"""Utils for model executor."""
import
random
import
random
import
importlib
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.config
import
DeviceConfig
,
ModelConfig
DEVICE_TO_MODEL_LOADER_MAP
=
{
"cuda"
:
"model_loader"
,
"neuron"
:
"neuron_model_loader"
,
}
def
set_random_seed
(
seed
:
int
)
->
None
:
def
set_random_seed
(
seed
:
int
)
->
None
:
random
.
seed
(
seed
)
random
.
seed
(
seed
)
...
@@ -41,12 +33,3 @@ def set_weight_attrs(
...
@@ -41,12 +33,3 @@ def set_weight_attrs(
assert
not
hasattr
(
assert
not
hasattr
(
weight
,
key
),
(
f
"Overwriting existing tensor attribute:
{
key
}
"
)
weight
,
key
),
(
f
"Overwriting existing tensor attribute:
{
key
}
"
)
setattr
(
weight
,
key
,
value
)
setattr
(
weight
,
key
,
value
)
def
get_model
(
model_config
:
ModelConfig
,
device_config
:
DeviceConfig
,
**
kwargs
)
->
torch
.
nn
.
Module
:
model_loader_module
=
DEVICE_TO_MODEL_LOADER_MAP
[
device_config
.
device_type
]
imported_model_loader
=
importlib
.
import_module
(
f
"vllm.model_executor.
{
model_loader_module
}
"
)
get_model_fn
=
imported_model_loader
.
get_model
return
get_model_fn
(
model_config
,
device_config
,
**
kwargs
)
vllm/model_executor/weight_utils.py
View file @
7c4f76e3
"""Utilities for downloading and initializing model weights."""
"""Utilities for downloading and initializing model weights."""
import
filelock
import
glob
import
fnmatch
import
fnmatch
import
glob
import
hashlib
import
json
import
json
import
os
import
os
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Any
,
Iterator
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Iterator
,
List
,
Optional
,
Tuple
from
huggingface_hub
import
snapshot_download
,
HfFileSystem
import
filelock
import
numpy
as
np
import
numpy
as
np
from
safetensors.torch
import
load_file
,
save_file
,
safe_open
import
torch
import
torch
from
huggingface_hub
import
HfFileSystem
,
snapshot_download
from
safetensors.torch
import
load_file
,
safe_open
,
save_file
from
tqdm.auto
import
tqdm
from
tqdm.auto
import
tqdm
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
(
get_q
uantization
_c
onfig
,
from
vllm.model_executor.layers.quantization
import
(
Q
uantization
C
onfig
,
Q
uantization
C
onfig
)
get_q
uantization
_c
onfig
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# use system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir
=
os
.
environ
.
get
(
'TMPDIR'
)
or
os
.
environ
.
get
(
'TEMP'
)
or
os
.
environ
.
get
(
'TMP'
)
or
"/tmp/"
class
Disabledtqdm
(
tqdm
):
class
Disabledtqdm
(
tqdm
):
...
@@ -28,9 +36,15 @@ class Disabledtqdm(tqdm):
...
@@ -28,9 +36,15 @@ class Disabledtqdm(tqdm):
def
get_lock
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
):
def
get_lock
(
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
):
lock_dir
=
cache_dir
if
cache_dir
is
not
None
else
"/tmp"
lock_dir
=
cache_dir
or
temp_dir
lock_file_name
=
model_name_or_path
.
replace
(
"/"
,
"-"
)
+
".lock"
os
.
makedirs
(
os
.
path
.
dirname
(
lock_dir
),
exist_ok
=
True
)
lock
=
filelock
.
FileLock
(
os
.
path
.
join
(
lock_dir
,
lock_file_name
))
model_name
=
model_name_or_path
.
replace
(
"/"
,
"-"
)
hash_name
=
hashlib
.
sha256
(
model_name
.
encode
()).
hexdigest
()
# add hash to avoid conflict with old users' lock files
lock_file_name
=
hash_name
+
model_name
+
".lock"
# mode 0o666 is required for the filelock to be shared across users
lock
=
filelock
.
FileLock
(
os
.
path
.
join
(
lock_dir
,
lock_file_name
),
mode
=
0o666
)
return
lock
return
lock
...
...
Prev
1
…
11
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment