norm / vllm · Commits · ba0bfd40

Unverified commit ba0bfd40, authored Oct 02, 2023 by Zhuohan Li and committed via GitHub on Oct 02, 2023.

TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)

Parent: 84e4e37d
Changes (41 in the full commit) — showing 20 changed files with 654 additions and 1102 deletions (+654 / -1102).
vllm/model_executor/models/bloom.py                              +16   -13
vllm/model_executor/models/falcon.py                             +15   -16
vllm/model_executor/models/gpt2.py                               +27   -22
vllm/model_executor/models/gpt_bigcode.py                        +34   -29
vllm/model_executor/models/gpt_j.py                              +34   -27
vllm/model_executor/models/gpt_neox.py                           +29   -23
vllm/model_executor/models/internlm.py                           +24   -20
vllm/model_executor/models/llama.py                              +4    -8
vllm/model_executor/models/mistral.py                            +4    -8
vllm/model_executor/models/mpt.py                                +19   -17
vllm/model_executor/models/opt.py                                +28   -23
vllm/model_executor/models/qwen.py                               +5    -9
vllm/model_executor/parallel_utils/__init__.py                   +0    -7
vllm/model_executor/parallel_utils/communication_op.py           +47   -0
vllm/model_executor/parallel_utils/layers.py                     +303  -0
vllm/model_executor/parallel_utils/parallel_state.py             +56   -376
vllm/model_executor/parallel_utils/tensor_parallel/__init__.py   +0    -50
vllm/model_executor/parallel_utils/tensor_parallel/mappings.py   +0    -281
vllm/model_executor/parallel_utils/tensor_parallel/random.py     +0    -164
vllm/model_executor/parallel_utils/utils.py                      +9    -9
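
Taken together, the per-model diffs below all make the same two edits: the parallel layers and the vocab embedding are now imported from vllm.model_executor.parallel_utils.layers (the old tensor_parallel package is deleted), and the Megatron-era keyword arguments such as perform_initialization=False drop out of every constructor call, since the rewritten layers no longer accept them (the old code only ever asserted they were False). A rough sketch of the new call shape, mirroring the updated call sites below; it assumes a vLLM checkout at this commit with an initialized tensor-parallel process group, so it is illustrative rather than standalone-runnable:

# Sketch only: mirrors the refactored call sites in the diffs below.
from vllm.model_executor.parallel_utils.layers import (
    ColumnParallelLinear, RowParallelLinear)

hidden_size = 4096  # example size, not from the diff

qkv_proj = ColumnParallelLinear(
    hidden_size,
    3 * hidden_size,
    bias=True,
    gather_output=False,       # keep the output sharded for attention
)
out_proj = RowParallelLinear(
    hidden_size,
    hidden_size,
    bias=True,
    input_is_parallel=True,    # consumes the sharded attention output
)
# Before this commit, every such call also passed perform_initialization=False
# (and the layers accepted init_method / stride / use_cpu_initialization).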
vllm/model_executor/models/bloom.py

@@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -85,14 +86,12 @@ class BloomAttention(nn.Module):
             3 * self.hidden_size,
             bias=True,
             gather_output=False,
-            perform_initialization=False,
         )
         self.dense = RowParallelLinear(
             self.hidden_size,
             self.hidden_size,
             bias=True,
             input_is_parallel=True,
-            perform_initialization=False,
         )

         # Create the alibi slopes and slice them.

@@ -129,15 +128,17 @@ class BloomMLP(nn.Module):
     def __init__(self, config: BloomConfig):
         super().__init__()
         hidden_size = config.hidden_size
-        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
-                                                  4 * hidden_size,
-                                                  gather_output=False,
-                                                  perform_initialization=False)
+        self.dense_h_to_4h = ColumnParallelLinear(
+            hidden_size,
+            4 * hidden_size,
+            gather_output=False,
+        )
         self.act = get_act_fn("gelu")
-        self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
-                                               hidden_size,
-                                               input_is_parallel=True,
-                                               perform_initialization=False)
+        self.dense_4h_to_h = RowParallelLinear(
+            4 * hidden_size,
+            hidden_size,
+            input_is_parallel=True,
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.dense_h_to_4h(x)

@@ -208,7 +209,9 @@ class BloomModel(nn.Module):
         # Embedding + LN Embedding
-        self.word_embeddings = VocabParallelEmbedding(
-            config.vocab_size, self.embed_dim, perform_initialization=False)
+        self.word_embeddings = VocabParallelEmbedding(
+            config.vocab_size,
+            self.embed_dim,
+        )
         self.word_embeddings_layernorm = nn.LayerNorm(
             self.embed_dim, eps=config.layer_norm_epsilon)
vllm/model_executor/models/falcon.py

@@ -36,9 +36,11 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
-    reduce_from_tensor_model_parallel_region)
+from vllm.model_executor.parallel_utils.layers import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.communication_op import (
+    tensor_model_parallel_all_reduce)
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import RWConfig

@@ -109,7 +111,6 @@ class FalconAttention(nn.Module):
                 self.head_dim,
                 bias=config.bias,
                 gather_output=False,
-                perform_initialization=False,
                 skip_bias_add=True,
             )
         elif self.multi_query:

@@ -120,7 +121,6 @@ class FalconAttention(nn.Module):
                 self.total_num_heads * self.head_dim,
                 bias=config.bias,
                 gather_output=False,
-                perform_initialization=False,
                 skip_bias_add=True,
             )
             self.key_value = FalconLinear(self.hidden_size,

@@ -135,7 +135,6 @@ class FalconAttention(nn.Module):
                 self.head_dim,
                 bias=config.bias,
                 gather_output=False,
-                perform_initialization=False,
                 skip_bias_add=True,
             )

@@ -151,7 +150,6 @@ class FalconAttention(nn.Module):
             self.hidden_size,
             bias=config.bias,
             input_is_parallel=True,
-            perform_initialization=False,
             skip_bias_add=True,
             reduce_results=self.reduce_row_parallel_results)

@@ -231,7 +229,6 @@ class FalconMLP(nn.Module):
             4 * hidden_size,
             bias=config.bias,
             gather_output=False,
-            perform_initialization=False,
             skip_bias_add=True)
         self.act = nn.GELU()
         self.reduce_row_parallel_results = not (config.new_decoder_architecture

@@ -241,7 +238,6 @@ class FalconMLP(nn.Module):
             hidden_size,
             bias=config.bias,
             input_is_parallel=True,
-            perform_initialization=False,
             skip_bias_add=True,
             reduce_results=self.reduce_row_parallel_results)

@@ -325,7 +321,7 @@ class FalconDecoderLayer(nn.Module):
             # only one all-reduce operator to reduce the results from
             # both MLP and Attention layers.
             mlp_output += attention_output
-            mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
+            mlp_output = tensor_model_parallel_all_reduce(mlp_output)

         if attention_bias is not None:
             mlp_output += attention_bias
         if mlp_bias is not None:

@@ -347,7 +343,9 @@ class FalconModel(nn.Module):
         # Embedding + LN Embedding
-        self.word_embeddings = VocabParallelEmbedding(
-            config.vocab_size, self.embed_dim, perform_initialization=False)
+        self.word_embeddings = VocabParallelEmbedding(
+            config.vocab_size,
+            self.embed_dim,
+        )

         # Transformer blocks
         self.h = nn.ModuleList([

@@ -389,11 +387,12 @@ class FalconForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.transformer = FalconModel(config)
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.lm_head = ColumnParallelLinear(
+            config.hidden_size,
+            config.vocab_size,
+            bias=False,
+            gather_output=False,
+        )
         self.sampler = Sampler(config.vocab_size)

     def forward(
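
The FalconDecoderLayer hunk above keeps the existing optimization of summing the attention and MLP partial outputs before a single collective; only the collective call changes from reduce_from_tensor_model_parallel_region to the new tensor_model_parallel_all_reduce. A minimal CPU-only sketch of why folding the addition under one all-reduce is exact (a plain shard sum stands in for the all-reduce; no process group or vLLM code is used):

import torch

torch.manual_seed(0)
tp_size = 2
mlp_shards = [torch.randn(4, 8) for _ in range(tp_size)]   # per-rank partial MLP outputs
attn_shards = [torch.randn(4, 8) for _ in range(tp_size)]  # per-rank partial attention outputs

def fake_all_reduce(shards):
    # Stand-in for torch.distributed.all_reduce over the tensor-parallel group.
    return sum(shards)

# Two collectives: all-reduce each partial result, then add.
two_reductions = fake_all_reduce(mlp_shards) + fake_all_reduce(attn_shards)

# One collective: add the partial results on each rank, then a single all-reduce.
fused = fake_all_reduce([m + a for m, a in zip(mlp_shards, attn_shards)])

assert torch.allclose(two_reductions, fused)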
vllm/model_executor/models/gpt2.py

@@ -36,8 +36,9 @@ from vllm.model_executor.weight_utils import (
     load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -56,16 +57,18 @@ class GPT2Attention(nn.Module):
         self.head_dim = self.hidden_size // total_num_heads
         self.scale = self.head_dim**-0.5

-        self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                           3 * self.hidden_size,
-                                           bias=True,
-                                           gather_output=False,
-                                           perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size,
-                                        self.hidden_size,
-                                        bias=True,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_attn = ColumnParallelLinear(
+            self.hidden_size,
+            3 * self.hidden_size,
+            bias=True,
+            gather_output=False,
+        )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            input_is_parallel=True,
+        )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scale=self.scale)

@@ -95,16 +98,18 @@ class GPT2MLP(nn.Module):
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        self.c_fc = ColumnParallelLinear(hidden_size,
-                                         intermediate_size,
-                                         bias=True,
-                                         gather_output=False,
-                                         perform_initialization=False)
-        self.c_proj = RowParallelLinear(intermediate_size,
-                                        hidden_size,
-                                        bias=True,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            gather_output=False,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            input_is_parallel=True,
+        )
         self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
vllm/model_executor/models/gpt_bigcode.py

@@ -37,8 +37,9 @@ from vllm.model_executor.weight_utils import (
     load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -62,29 +63,31 @@ class GPTBigCodeAttention(nn.Module):
         if self.multi_query:
             self.num_kv_heads = 1
             self.kv_dim = self.head_dim
-            self.c_attn_q = ColumnParallelLinear(self.hidden_size,
-                                                 self.hidden_size,
-                                                 bias=True,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
+            self.c_attn_q = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=True,
+                gather_output=False,
+            )
             self.c_attn_kv = nn.Linear(self.hidden_size,
                                        2 * self.kv_dim,
                                        bias=True)
         else:
             self.num_kv_heads = self.num_heads
             self.kv_dim = self.num_kv_heads * self.head_dim
-            self.c_attn = ColumnParallelLinear(self.hidden_size,
-                                               self.hidden_size +
-                                               2 * self.kv_dim,
-                                               bias=True,
-                                               gather_output=False,
-                                               perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size,
-                                        self.hidden_size,
-                                        bias=True,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+            self.c_attn = ColumnParallelLinear(
+                self.hidden_size,
+                self.hidden_size + 2 * self.kv_dim,
+                bias=True,
+                gather_output=False,
+            )
+        self.c_proj = RowParallelLinear(
+            self.hidden_size,
+            self.hidden_size,
+            bias=True,
+            input_is_parallel=True,
+        )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scale=self.scale,

@@ -124,16 +127,18 @@ class GPTBigMLP(nn.Module):
     ):
         super().__init__()
         hidden_size = config.hidden_size
-        self.c_fc = ColumnParallelLinear(hidden_size,
-                                         intermediate_size,
-                                         bias=True,
-                                         gather_output=False,
-                                         perform_initialization=False)
-        self.c_proj = RowParallelLinear(intermediate_size,
-                                        hidden_size,
-                                        bias=True,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+        self.c_fc = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=True,
+            gather_output=False,
+        )
+        self.c_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=True,
+            input_is_parallel=True,
+        )
         self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
vllm/model_executor/models/gpt_j.py

@@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -49,16 +50,18 @@ class GPTJAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.total_num_heads

-        self.qkv_proj = ColumnParallelLinear(config.hidden_size,
-                                             3 * config.hidden_size,
-                                             bias=False,
-                                             gather_output=False,
-                                             perform_initialization=False)
-        self.out_proj = RowParallelLinear(config.hidden_size,
-                                          config.hidden_size,
-                                          bias=False,
-                                          input_is_parallel=True,
-                                          perform_initialization=False)
+        self.qkv_proj = ColumnParallelLinear(
+            config.hidden_size,
+            3 * config.hidden_size,
+            bias=False,
+            gather_output=False,
+        )
+        self.out_proj = RowParallelLinear(
+            config.hidden_size,
+            config.hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )
         tp_world_size = get_tensor_model_parallel_world_size()
         assert self.total_num_heads % tp_world_size == 0

@@ -102,14 +105,16 @@ class GPTJMLP(nn.Module):
     def __init__(self, intermediate_size: int, config: GPTJConfig):
         super().__init__()
         hidden_size = config.n_embd
-        self.fc_in = ColumnParallelLinear(hidden_size,
-                                          intermediate_size,
-                                          gather_output=False,
-                                          perform_initialization=False)
-        self.fc_out = RowParallelLinear(intermediate_size,
-                                        hidden_size,
-                                        input_is_parallel=True,
-                                        perform_initialization=False)
+        self.fc_in = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            gather_output=False,
+        )
+        self.fc_out = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            input_is_parallel=True,
+        )
         self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

@@ -159,9 +164,10 @@ class GPTJModel(nn.Module):
         super().__init__()
         self.config = config
         self.embed_dim = config.n_embd
-        self.wte = VocabParallelEmbedding(config.vocab_size,
-                                          self.embed_dim,
-                                          perform_initialization=False)
+        self.wte = VocabParallelEmbedding(
+            config.vocab_size,
+            self.embed_dim,
+        )
         self.h = nn.ModuleList(
             [GPTJBlock(config) for _ in range(config.n_layer)])
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

@@ -199,10 +205,11 @@ class GPTJForCausalLM(nn.Module):
         self.config = config
         assert not config.tie_word_embeddings
         self.transformer = GPTJModel(config)
-        self.lm_head = ColumnParallelLinear(config.n_embd,
-                                            config.vocab_size,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.lm_head = ColumnParallelLinear(
+            config.n_embd,
+            config.vocab_size,
+            gather_output=False,
+        )
         self.sampler = Sampler(config.vocab_size)

     def forward(
vllm/model_executor/models/gpt_neox.py

@@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -59,11 +60,12 @@ class GPTNeoXAttention(nn.Module):
             config.hidden_size,
             3 * config.hidden_size,
             gather_output=False,
-            perform_initialization=False)
-        self.dense = RowParallelLinear(config.hidden_size,
-                                       config.hidden_size,
-                                       input_is_parallel=True,
-                                       perform_initialization=False)
+        )
+        self.dense = RowParallelLinear(
+            config.hidden_size,
+            config.hidden_size,
+            input_is_parallel=True,
+        )
         scaling = self.head_size**-0.5
         rotary_dim = int(self.head_size * config.rotary_pct)

@@ -100,14 +102,16 @@ class GPTNeoXMLP(nn.Module):
     def __init__(self, config: GPTNeoXConfig):
         super().__init__()
-        self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
-                                                  config.intermediate_size,
-                                                  gather_output=False,
-                                                  perform_initialization=False)
-        self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
-                                               config.hidden_size,
-                                               input_is_parallel=True,
-                                               perform_initialization=False)
+        self.dense_h_to_4h = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            gather_output=False,
+        )
+        self.dense_4h_to_h = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            input_is_parallel=True,
+        )
         self.act = get_act_fn(config.hidden_act)

     def forward(self, hidden_states):

@@ -169,9 +173,10 @@ class GPTNeoXModel(nn.Module):
         super().__init__()
         self.config = config

-        self.embed_in = VocabParallelEmbedding(config.vocab_size,
-                                               config.hidden_size,
-                                               perform_initialization=False)
+        self.embed_in = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList(
             [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size,

@@ -209,11 +214,12 @@ class GPTNeoXForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.gpt_neox = GPTNeoXModel(config)
-        self.embed_out = ColumnParallelLinear(config.hidden_size,
-                                              config.vocab_size,
-                                              bias=False,
-                                              gather_output=False,
-                                              perform_initialization=False)
+        self.embed_out = ColumnParallelLinear(
+            config.hidden_size,
+            config.vocab_size,
+            bias=False,
+            gather_output=False,
+        )
         self.sampler = Sampler(config.vocab_size)

     def forward(
vllm/model_executor/models/internlm.py

@@ -12,8 +12,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
+from vllm.model_executor.parallel_utils.layers import (
+    ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
 from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_padded_tensor_parallel_vocab,
                                               load_tensor_parallel_weights)

@@ -31,16 +32,18 @@ class InternLMMLP(nn.Module):
         hidden_act: str,
     ):
         super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(hidden_size,
-                                                 2 * intermediate_size,
-                                                 bias=False,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.gate_up_proj = ColumnParallelLinear(
+            hidden_size,
+            2 * intermediate_size,
+            bias=False,
+            gather_output=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")

@@ -80,14 +83,12 @@ class InternLMAttention(nn.Module):
             3 * self.total_num_heads * self.head_dim,
             bias=True,
             gather_output=False,
-            perform_initialization=False,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=True,
             input_is_parallel=True,
-            perform_initialization=False,
         )
         self.attn = PagedAttentionWithRoPE(self.num_heads,

@@ -176,7 +177,9 @@ class InternLMModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.embed_tokens = VocabParallelEmbedding(
-            vocab_size, config.hidden_size, perform_initialization=False)
+        self.embed_tokens = VocabParallelEmbedding(
+            vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList([
             InternLMDecoderLayer(config)
             for _ in range(config.num_hidden_layers)

@@ -216,11 +219,12 @@ class InternLMForCausalLM(nn.Module):
         self.config = config
         self.model = InternLMModel(config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.lm_head = ColumnParallelLinear(
+            config.hidden_size,
+            vocab_size,
+            bias=False,
+            gather_output=False,
+        )
         self.sampler = Sampler(config.vocab_size)

     def forward(
vllm/model_executor/models/llama.py

@@ -39,8 +39,7 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.quantized_linear import ParallelLinear
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding)
+from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
 from vllm.model_executor.quantization_utils import QuantizationConfig
 from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                               hf_model_weights_iterator,

@@ -64,13 +63,11 @@ class LlamaMLP(nn.Module):
             2 * intermediate_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=quant_config)
         self.down_proj = ParallelLinear.row(intermediate_size,
                                             hidden_size,
                                             bias=False,
                                             input_is_parallel=True,
-                                            perform_initialization=False,
                                             quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "

@@ -127,7 +124,6 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=quant_config,
         )
         self.o_proj = ParallelLinear.row(

@@ -135,7 +131,6 @@ class LlamaAttention(nn.Module):
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            perform_initialization=False,
             quant_config=quant_config,
         )
         self.attn = PagedAttentionWithRoPE(

@@ -241,7 +236,9 @@ class LlamaModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.embed_tokens = VocabParallelEmbedding(
-            vocab_size, config.hidden_size, perform_initialization=False)
+        self.embed_tokens = VocabParallelEmbedding(
+            vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)

@@ -291,7 +288,6 @@ class LlamaForCausalLM(nn.Module):
             vocab_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=None)
         self.sampler = Sampler(config.vocab_size)
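
Llama (and Mistral, next) route their linear layers through ParallelLinear.column / ParallelLinear.row from vllm.model_executor.layers.quantized_linear so that a quant_config can select a quantized implementation; this diff only strips the removed perform_initialization keyword from those calls. A hedged sketch of the resulting call shape, using the argument names exactly as they appear in the hunks above (building the layers for real still requires vLLM at this commit and an initialized tensor-parallel group; the sizes are illustrative):

# Sketch only: mirrors the Llama/Mistral call sites in this diff.
from vllm.model_executor.layers.quantized_linear import ParallelLinear

hidden_size, intermediate_size = 4096, 11008  # example sizes, not from the diff
quant_config = None  # or a QuantizationConfig from vllm.model_executor.quantization_utils

gate_up_proj = ParallelLinear.column(hidden_size,
                                     2 * intermediate_size,
                                     bias=False,
                                     gather_output=False,
                                     quant_config=quant_config)
down_proj = ParallelLinear.row(intermediate_size,
                               hidden_size,
                               bias=False,
                               input_is_parallel=True,
                               quant_config=quant_config)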
vllm/model_executor/models/mistral.py

@@ -38,8 +38,7 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.quantized_linear import ParallelLinear
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding)
+from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
 from vllm.model_executor.quantization_utils import QuantizationConfig
 from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                               hf_model_weights_iterator,

@@ -64,13 +63,11 @@ class MistralMLP(nn.Module):
             2 * intermediate_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=quant_config)
         self.down_proj = ParallelLinear.row(intermediate_size,
                                             hidden_size,
                                             bias=False,
                                             input_is_parallel=True,
-                                            perform_initialization=False,
                                             quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "

@@ -116,7 +113,6 @@ class MistralAttention(nn.Module):
             self.head_dim,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=quant_config,
         )
         self.o_proj = ParallelLinear.row(

@@ -124,7 +120,6 @@ class MistralAttention(nn.Module):
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            perform_initialization=False,
             quant_config=quant_config,
         )
         self.attn = PagedAttentionWithRoPE(self.num_heads,

@@ -225,7 +220,9 @@ class MistralModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.embed_tokens = VocabParallelEmbedding(
-            vocab_size, config.hidden_size, perform_initialization=False)
+        self.embed_tokens = VocabParallelEmbedding(
+            vocab_size,
+            config.hidden_size,
+        )
         self.layers = nn.ModuleList([
             MistralDecoderLayer(config, quant_config)
             for _ in range(config.num_hidden_layers)

@@ -275,7 +272,6 @@ class MistralForCausalLM(nn.Module):
             vocab_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
             quant_config=None)
         self.sampler = Sampler(config.vocab_size)
vllm/model_executor/models/mpt.py

@@ -15,8 +15,9 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs.mpt import MPTConfig

@@ -53,7 +54,6 @@ class MPTAttention(nn.Module):
             3 * self.d_model,
             bias=not config.no_bias,
             gather_output=False,
-            perform_initialization=False,
         )
         if self.qk_ln:
             self.q_ln = nn.LayerNorm(self.d_model)

@@ -63,7 +63,6 @@ class MPTAttention(nn.Module):
             self.d_model,
             bias=not config.no_bias,
             input_is_parallel=True,
-            perform_initialization=False,
         )
         tp_world_size = get_tensor_model_parallel_world_size()

@@ -113,17 +112,19 @@ class MPTMLP(nn.Module):
         hidden_size = config.d_model
         expansion_ratio = config.expansion_ratio
         intermediate_size = expansion_ratio * hidden_size
-        self.up_proj = ColumnParallelLinear(hidden_size,
-                                            intermediate_size,
-                                            bias=not config.no_bias,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=not config.no_bias,
+            gather_output=False,
+        )
         self.act = get_act_fn("gelu")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=not config.no_bias,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=not config.no_bias,
+            input_is_parallel=True,
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.up_proj(x)

@@ -172,9 +173,10 @@ class MPTModel(nn.Module):
         assert config.embedding_fraction == 1.0
         assert config.norm_type == "low_precision_layernorm"

-        self.wte = VocabParallelEmbedding(config.vocab_size,
-                                          config.d_model,
-                                          perform_initialization=False)
+        self.wte = VocabParallelEmbedding(
+            config.vocab_size,
+            config.d_model,
+        )
         self.blocks = nn.ModuleList(
             [MPTBlock(config) for _ in range(config.n_layers)])
         self.norm_f = nn.LayerNorm(config.d_model)
vllm/model_executor/models/opt.py

@@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                               load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
 from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -73,16 +74,18 @@ class OPTAttention(nn.Module):
         self.head_dim = embed_dim // total_num_heads
         self.scaling = self.head_dim**-0.5

-        self.qkv_proj = ColumnParallelLinear(embed_dim,
-                                             3 * embed_dim,
-                                             bias=bias,
-                                             gather_output=False,
-                                             perform_initialization=False)
-        self.out_proj = RowParallelLinear(embed_dim,
-                                          embed_dim,
-                                          bias=bias,
-                                          input_is_parallel=True,
-                                          perform_initialization=False)
+        self.qkv_proj = ColumnParallelLinear(
+            embed_dim,
+            3 * embed_dim,
+            bias=bias,
+            gather_output=False,
+        )
+        self.out_proj = RowParallelLinear(
+            embed_dim,
+            embed_dim,
+            bias=bias,
+            input_is_parallel=True,
+        )
         self.attn = PagedAttention(self.num_heads,
                                    self.head_dim,
                                    scale=self.scaling)

@@ -120,16 +123,18 @@ class OPTDecoderLayer(nn.Module):
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim,
             elementwise_affine=config.layer_norm_elementwise_affine)
-        self.fc1 = ColumnParallelLinear(self.embed_dim,
-                                        config.ffn_dim,
-                                        bias=config.enable_bias,
-                                        gather_output=False,
-                                        perform_initialization=False)
-        self.fc2 = RowParallelLinear(config.ffn_dim,
-                                     self.embed_dim,
-                                     bias=config.enable_bias,
-                                     input_is_parallel=True,
-                                     perform_initialization=False)
+        self.fc1 = ColumnParallelLinear(
+            self.embed_dim,
+            config.ffn_dim,
+            bias=config.enable_bias,
+            gather_output=False,
+        )
+        self.fc2 = RowParallelLinear(
+            config.ffn_dim,
+            self.embed_dim,
+            bias=config.enable_bias,
+            input_is_parallel=True,
+        )
         self.final_layer_norm = nn.LayerNorm(
             self.embed_dim,
             elementwise_affine=config.layer_norm_elementwise_affine)

@@ -182,7 +187,7 @@ class OPTDecoder(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
-            config.word_embed_proj_dim,
-            perform_initialization=False)
+            config.word_embed_proj_dim,
+        )
         # Positional embeddings are replicated (not sharded).
         self.embed_positions = OPTLearnedPositionalEmbedding(
             config.max_position_embeddings, config.hidden_size)
vllm/model_executor/models/qwen.py

@@ -28,7 +28,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.parallel_utils.tensor_parallel import (
+from vllm.model_executor.parallel_utils.layers import (
     VocabParallelEmbedding,
     ColumnParallelLinear,
     RowParallelLinear,

@@ -53,14 +53,12 @@ class QWenMLP(nn.Module):
             2 * intermediate_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            perform_initialization=False,
         )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "

@@ -98,14 +96,12 @@ class QWenAttention(nn.Module):
             3 * hidden_size,
             bias=True,
             gather_output=False,
-            perform_initialization=False,
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            perform_initialization=False,
         )
         self.scaling = self.head_dim**-0.5
         self.attn = PagedAttentionWithRoPE(

@@ -190,9 +186,10 @@ class QWenModel(nn.Module):
         self.vocab_size = config.vocab_size
         vocab_size = ((config.vocab_size + 63) // 64) * 64

-        self.wte = VocabParallelEmbedding(vocab_size,
-                                          config.hidden_size,
-                                          perform_initialization=False)
+        self.wte = VocabParallelEmbedding(
+            vocab_size,
+            config.hidden_size,
+        )
         self.h = nn.ModuleList(
             [QWenBlock(config) for _ in range(config.num_hidden_layers)])
         self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -235,7 +232,6 @@ class QWenLMHeadModel(nn.Module):
             vocab_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
         )
         self.sampler = Sampler(config.vocab_size)
vllm/model_executor/parallel_utils/__init__.py (all contents removed, +0 / -7)

-import vllm.model_executor.parallel_utils.parallel_state
-import vllm.model_executor.parallel_utils.tensor_parallel
-
-__all__ = [
-    "parallel_state",
-    "tensor_parallel",
-]
vllm/model_executor/parallel_utils/communication_op.py (new file, mode 100644)

import torch

from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
)


def tensor_model_parallel_all_reduce(input_):
    """All-reduce the input tensor across model parallel group.

    Note: This operation is applied in-place on the input tensor.
    """
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_
    # All-reduce.
    torch.distributed.all_reduce(input_,
                                 group=get_tensor_model_parallel_group())
    return input_


def tensor_model_parallel_all_gather(input_, dim=-1):
    """All-gather the input tensor across model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_
    assert -input_.dim() <= dim < input_.dim(), (
        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    input_size = input_.size()
    # Allocate output tensor.
    output_tensor = torch.empty((world_size, ) + input_size,
                                dtype=input_.dtype,
                                device=input_.device)
    # All-gather.
    torch.distributed.all_gather_into_tensor(
        output_tensor, input_, group=get_tensor_model_parallel_group())
    # Reshape
    output_tensor = output_tensor.movedim(0, dim)
    output_tensor = output_tensor.reshape(input_size[:dim] +
                                          (world_size * input_size[dim], ) +
                                          input_size[dim + 1:])
    return output_tensor
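
The reshape at the end of tensor_model_parallel_all_gather is what turns the (world_size, *input_shape) gather buffer into a tensor whose size along dim is multiplied by world_size. A CPU-only sketch of that bookkeeping, with torch.stack standing in for all_gather_into_tensor (no process group involved):

import torch

world_size, dim = 4, -1
input_ = torch.randn(2, 3, 5)                      # per-rank shard
shards = [input_.clone() for _ in range(world_size)]

# What all_gather_into_tensor produces: (world_size, *input_shape).
output_tensor = torch.stack(shards, dim=0)

# The reshape logic from tensor_model_parallel_all_gather.
if dim < 0:
    dim += input_.dim()
input_size = input_.size()
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
                                      (world_size * input_size[dim], ) +
                                      input_size[dim + 1:])

assert output_tensor.shape == (2, 3, 5 * world_size)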
vllm/model_executor/parallel_utils/tensor_parallel/layers.py → vllm/model_executor/parallel_utils/layers.py

 # Copyright 2023 The vLLM team.
-# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
+# Adapted from
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 # Parts of the code here are adapted from PyTorch

@@ -8,60 +9,22 @@ from typing import Optional

 import torch
 import torch.nn.functional as F
-import torch.nn.init as init
 from torch.nn.parameter import Parameter

 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from .mappings import (
-    gather_from_tensor_model_parallel_region,
-    reduce_from_tensor_model_parallel_region,
-    scatter_to_tensor_model_parallel_region,
-)
-from .utils import (
+from vllm.model_executor.quantization_utils import QuantizationConfig
+from vllm.model_executor.parallel_utils.communication_op import (
+    tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
+from vllm.model_executor.parallel_utils.utils import (
     divide,
     VocabUtility,
     split_tensor_along_last_dim,
 )

-_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
-                                      'partition_dim': -1,
-                                      'partition_stride': 1}
-
-def param_is_not_tensor_parallel_duplicate(param):
-    return (hasattr(param, 'tensor_model_parallel') and
-            param.tensor_model_parallel) or (
-                get_tensor_model_parallel_rank() == 0)
-
-def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
-    # Make sure the attributes are not set.
-    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
-        assert not hasattr(tensor, attribute)
-    # Set the attributes.
-    setattr(tensor, 'tensor_model_parallel', is_parallel)
-    setattr(tensor, 'partition_dim', dim)
-    setattr(tensor, 'partition_stride', stride)
-
-def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
-    def maybe_set(attribute, value):
-        if not hasattr(tensor, attribute):
-            setattr(tensor, attribute, value)
-    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
-        maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
-
-def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
-    def maybe_copy(attribute):
-        if hasattr(source_tensor, attribute):
-            setattr(destination_tensor, attribute,
-                    getattr(source_tensor, attribute))
-    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
-        maybe_copy(attribute)

 class VocabParallelEmbedding(torch.nn.Module):
     """Embedding parallelized in the vocabulary dimension.
@@ -71,22 +34,14 @@ class VocabParallelEmbedding(torch.nn.Module):
     Arguments:
         num_embeddings: vocabulary size.
         embedding_dim: size of hidden state.
-
-    Keyword Arguments:
-        init_method: method to initialize weights.
-        params_dtype
-        use_cpu_initialization
-        perform_initialization
+        params_dtype: type of the parameters.
     """

-    def __init__(self, num_embeddings: int, embedding_dim: int, *,
-                 init_method=init.xavier_normal_,
-                 params_dtype: torch.dtype = None,
-                 use_cpu_initialization: bool = False,
-                 perform_initialization: bool = False):
-        super(VocabParallelEmbedding, self).__init__()
-        assert not perform_initialization
-        assert not use_cpu_initialization
-
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 params_dtype: Optional[torch.dtype] = None):
+        super().__init__()

         # Keep the input dimensions.
         self.num_embeddings = num_embeddings

@@ -94,46 +49,39 @@ class VocabParallelEmbedding(torch.nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

-        # Set the defaults for compatibility.
-        self.padding_idx = None
-        self.max_norm = None
-        self.norm_type = 2.
-        self.scale_grad_by_freq = False
-        self.sparse = False
-        self._weight = None
-        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
         # TODO: Handle vocab padding here.
         # Divide the weight matrix along the vocaburaly dimension.
-        self.vocab_start_index, self.vocab_end_index = \
-            VocabUtility.vocab_range_from_global_vocab_size(
-                self.num_embeddings, get_tensor_model_parallel_rank(),
-                self.tensor_model_parallel_size)
-        self.num_embeddings_per_partition = self.vocab_end_index - \
-            self.vocab_start_index
+        self.vocab_start_index, self.vocab_end_index = (
+            VocabUtility.vocab_range_from_global_vocab_size(
+                self.num_embeddings, get_tensor_model_parallel_rank(),
+                self.tp_size))
+        self.num_embeddings_per_partition = (self.vocab_end_index -
+                                             self.vocab_start_index)
         self.weight = Parameter(
             torch.empty(self.num_embeddings_per_partition,
                         self.embedding_dim,
                         device=torch.cuda.current_device(),
                         dtype=params_dtype))

     def forward(self, input_):
-        if self.tensor_model_parallel_size > 1:
+        if self.tp_size > 1:
             # Build the mask.
-            input_mask = (input_ < self.vocab_start_index) | \
-                         (input_ >= self.vocab_end_index)
+            input_mask = ((input_ < self.vocab_start_index) |
+                          (input_ >= self.vocab_end_index))
             # Mask the input.
             masked_input = input_.clone() - self.vocab_start_index
             masked_input[input_mask] = 0
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = F.embedding(masked_input, self.weight,
-                                      self.padding_idx, self.max_norm,
-                                      self.norm_type, self.scale_grad_by_freq,
-                                      self.sparse)
+        output_parallel = F.embedding(masked_input, self.weight)
         # Mask the output embedding.
-        if self.tensor_model_parallel_size > 1:
+        if self.tp_size > 1:
             output_parallel[input_mask, :] = 0.0
         # Reduce across all the model parallel GPUs.
-        output = reduce_from_tensor_model_parallel_region(output_parallel)
+        output = tensor_model_parallel_all_reduce(output_parallel)
         return output
@@ -152,40 +100,32 @@ class ColumnParallelLinear(torch.nn.Module):
         gather_output: If true, call all-gather on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y_i = XA_i
-        init_method: method to initialize weights. Note that bias is always set
-                     to zero.
-        stride: For the strided linear layers.
-        keep_master_weight_for_test: This was added for testing and should be
-                                     set to False. It returns the master weights
-                                     used for initialization.
-        skip_bias_add: This was added to enable performance optimations where bias
-                       can be fused with other elementwise operations. we skip
-                       adding bias but instead return it.
-        params_dtype:
-        use_cpu_initialization:
+        skip_bias_add: This was added to enable performance optimizations where
+                       bias can be fused with other element-wise operations. we
+                       skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
     """

-    def __init__(self, input_size, output_size, *,
-                 bias=True, gather_output=True,
-                 init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False,
-                 skip_bias_add=False, params_dtype=None,
-                 use_cpu_initialization=False,
-                 perform_initialization=False,
-                 quant_config=None,
-                 ):
-        super(ColumnParallelLinear, self).__init__()
-        assert not perform_initialization
-        assert not use_cpu_initialization
-
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        gather_output: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()

         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
-        self.world_size = get_tensor_model_parallel_world_size()
-        self.output_size_per_partition = divide(output_size, self.world_size)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.output_size_per_partition = divide(output_size, self.tp_size)
         self.skip_bias_add = skip_bias_add
         self.quant_config = quant_config

@@ -198,21 +138,19 @@ class ColumnParallelLinear(torch.nn.Module):
         self.create_weights(params_dtype)

         if bias:
-            self.bias = Parameter(torch.empty(
-                self.output_size_per_partition,
-                device=torch.cuda.current_device(),
-                dtype=params_dtype))
-            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
-            # Always initialize bias to zero.
-            with torch.no_grad():
-                self.bias.zero_()
+            self.bias = Parameter(
+                torch.empty(self.output_size_per_partition,
+                            device=torch.cuda.current_device(),
+                            dtype=params_dtype))
         else:
             self.register_parameter('bias', None)

     def create_weights(self, dtype: torch.dtype) -> None:
         self.weight = Parameter(
             torch.empty(self.output_size_per_partition,
                         self.input_size,
                         device=torch.cuda.current_device(),
                         dtype=dtype))

     def apply_weights(
         self,

@@ -225,7 +163,7 @@ class ColumnParallelLinear(torch.nn.Module):
         """Forward of ColumnParallelLinear

         Args:
-            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+            input_: Tensor whose last dimension is `input_size`.

         Returns:
             - output

@@ -238,7 +176,7 @@ class ColumnParallelLinear(torch.nn.Module):
         output_parallel = self.apply_weights(input_parallel, bias)
         if self.gather_output:
             # All-gather across the partitions.
-            output = gather_from_tensor_model_parallel_region(output_parallel)
+            output = tensor_model_parallel_all_gather(output_parallel)
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None

@@ -266,36 +204,25 @@ class RowParallelLinear(torch.nn.Module):
         input_is_parallel: If true, we assume that the input is already
                            split across the GPUs and we do not split
                            again.
-        init_method: method to initialize weights. Note that bias is always set
-                     to zero.
-        stride: For the strided linear layers.
-        keep_master_weight_for_test: This was added for testing and should be
-                                     set to False. It returns the master weights
-                                     used for initialization.
-        skip_bias_add: This was added to enable performance optimization where bias
-                       can be fused with other elementwise operations. We skip
-                       adding bias but instead return it.
-        params_dtype:
-        use_cpu_initialization:
-        perform_initialization:
-        reduce_results:
+        skip_bias_add: This was added to enable performance optimization where
+                       bias can be fused with other element-wise operations.
+                       We skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configuration.
     """

-    def __init__(self, input_size, output_size, *,
-                 bias=True, input_is_parallel=False,
-                 init_method=init.xavier_normal_, stride=1,
-                 keep_master_weight_for_test=False,
-                 skip_bias_add=False, params_dtype=None,
-                 use_cpu_initialization=False,
-                 perform_initialization=False,
-                 reduce_results=True,
-                 quant_config=None,
-                 ):
-        super(RowParallelLinear, self).__init__()
-        assert not perform_initialization
-        assert not use_cpu_initialization
-
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()

         # Keep input parameters
         self.input_size = input_size
         self.output_size = output_size

@@ -305,21 +232,22 @@ class RowParallelLinear(torch.nn.Module):
             params_dtype = torch.get_default_dtype()

         # Divide the weight matrix along the last dimension.
-        self.world_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, self.world_size)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.tp_size)
         self.skip_bias_add = skip_bias_add
         self.quant_config = quant_config

         self.create_weights(params_dtype)

         if not reduce_results and (bias and not skip_bias_add):
-            raise ValueError("When not reduce the results, adding bias to the "
-                             "results can lead to incorrect results")
+            raise ValueError('When not reduce the results, adding bias to the '
+                             'results can lead to incorrect results')

         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size,
                             device=torch.cuda.current_device(),
                             dtype=params_dtype))
             # Always initialize bias to zero.
             with torch.no_grad():

@@ -328,9 +256,11 @@ class RowParallelLinear(torch.nn.Module):
             self.register_parameter('bias', None)

     def create_weights(self, dtype: torch.dtype) -> None:
-        self.weight = Parameter(
-            torch.empty(self.output_size, self.input_size_per_partition,
-                        device=torch.cuda.current_device(), dtype=dtype))
+        self.weight = Parameter(
+            torch.empty(self.output_size,
+                        self.input_size_per_partition,
+                        device=torch.cuda.current_device(),
+                        dtype=dtype))

     def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
         return F.linear(x, self.weight)

@@ -339,7 +269,9 @@ class RowParallelLinear(torch.nn.Module):
         """Forward of RowParallelLinear

         Args:
-            input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
+            input_: tensor whose last dimension is `input_size`. If
+                    `input_is_parallel` is set, then the last dimension
+                    is `input_size // tp_size`.

         Returns:
             - output

@@ -349,11 +281,16 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            input_parallel = scatter_to_tensor_model_parallel_region(input_)
+            # TODO: simplify code below
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
         # Matrix multiply.
         output_parallel = self.apply_weights(input_parallel)
-        if self.reduce_results and self.world_size > 1:
-            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        if self.reduce_results and self.tp_size > 1:
+            output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
             output_ = output_parallel
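
Throughout the model files, ColumnParallelLinear with gather_output=False feeds RowParallelLinear with input_is_parallel=True, so a two-layer block needs only the single all-reduce at the end of RowParallelLinear.forward shown above. A self-contained pure-PyTorch check of that decomposition (CPU only; an explicit shard sum stands in for tensor_model_parallel_all_reduce, and the vLLM classes themselves are not used):

import torch

torch.manual_seed(0)
tp_size, hidden, inter = 2, 8, 16
x = torch.randn(4, hidden)
w1 = torch.randn(inter, hidden)   # first linear, split along the output dim
w2 = torch.randn(hidden, inter)   # second linear, split along the input dim

# Reference: the unsharded MLP.
ref = torch.relu(x @ w1.t()) @ w2.t()

# Tensor-parallel version: each rank keeps a slice of the intermediate size.
w1_shards = w1.chunk(tp_size, dim=0)         # ColumnParallelLinear weight shards
w2_shards = w2.chunk(tp_size, dim=1)         # RowParallelLinear weight shards
partials = [
    torch.relu(x @ w1_i.t()) @ w2_i.t()      # gather_output=False -> stays sharded
    for w1_i, w2_i in zip(w1_shards, w2_shards)
]
out = sum(partials)                          # stands in for the final all-reduce

assert torch.allclose(ref, out, atol=1e-5)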
vllm/model_executor/parallel_utils/parallel_state.py

(The diff for this file is collapsed in the page view and is not reproduced here: +56 / -376.)
vllm/model_executor/parallel_utils/tensor_parallel/__init__.py (deleted, 100644 → 0; previous content at 84e4e37d)

from .layers import (
    ColumnParallelLinear,
    RowParallelLinear,
    VocabParallelEmbedding,
    set_tensor_model_parallel_attributes,
    set_defaults_if_not_set_tensor_model_parallel_attributes,
    copy_tensor_model_parallel_attributes,
    param_is_not_tensor_parallel_duplicate,
)

from .mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_tensor_model_parallel_region,
    gather_from_sequence_parallel_region,
    reduce_from_tensor_model_parallel_region,
    scatter_to_tensor_model_parallel_region,
    scatter_to_sequence_parallel_region,
)

from .random import (
    get_cuda_rng_tracker,
    model_parallel_cuda_manual_seed,
)

from .utils import (
    split_tensor_along_last_dim,
)

__all__ = [
    # layers.py
    "ColumnParallelLinear",
    "RowParallelLinear",
    "VocabParallelEmbedding",
    "set_tensor_model_parallel_attributes",
    "set_defaults_if_not_set_tensor_model_parallel_attributes",
    "copy_tensor_model_parallel_attributes",
    "param_is_not_tensor_parallel_duplicate",
    # mappings.py
    "copy_to_tensor_model_parallel_region",
    "gather_from_tensor_model_parallel_region",
    "gather_from_sequence_parallel_region",
    "reduce_from_tensor_model_parallel_region",
    "scatter_to_tensor_model_parallel_region",
    "scatter_to_sequence_parallel_region",
    # random.py
    "get_cuda_rng_tracker",
    "model_parallel_cuda_manual_seed",
    # utils.py
    "split_tensor_along_last_dim",
]
vllm/model_executor/parallel_utils/tensor_parallel/mappings.py (deleted, 100644 → 0; previous content at 84e4e37d)

# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import torch

from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
)
from .utils import split_tensor_along_last_dim


def _reduce(input_):
    """All-reduce the input tensor across model parallel group."""

    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_,
                                 group=get_tensor_model_parallel_group())

    return input_


def _split_along_last_dim(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()

    return output


def _split_along_first_dim(input_):
    """Split the tensor along its first dimension and keep the
    corresponding slice."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along first dimension.
    dim_size = input_.size()[0]
    assert dim_size % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"
    local_dim_size = dim_size // world_size
    rank = get_tensor_model_parallel_rank()
    dim_offset = rank * local_dim_size

    output = input_[dim_offset:dim_offset + local_dim_size].contiguous()

    return output
def _gather_along_last_dim(input_):
    """Gather tensors and concatinate along the last dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_,
                                 group=get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output


def _gather_along_first_dim(input_):
    """Gather tensors and concatinate along the first dimension."""

    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    dim_size[0] = dim_size[0] * world_size

    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device())
    torch.distributed._all_gather_base(output, input_.contiguous(),
                                       group=get_tensor_model_parallel_group())

    return output


def _reduce_scatter_along_first_dim(input_):
    """Reduce-scatter the input tensor across model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    dim_size = list(input_.size())
    assert dim_size[0] % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"

    dim_size[0] = dim_size[0] // world_size

    output = torch.empty(dim_size, dtype=input_.dtype,
                         device=torch.cuda.current_device())
    torch.distributed._reduce_scatter_base(
        output, input_.contiguous(), group=get_tensor_model_parallel_group())
    return output
class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_last_dim(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split_along_last_dim(grad_output)


class _ScatterToSequenceParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """Gather the input from the sequence parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_, tensor_parallel_output_grad=True):
        return _gather_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_, tensor_parallel_output_grad=True):
        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad

        # If the computation graph after the gather operation is
        # in the tensor parallel mode, output gradients need to be
        # reduce-scattered, whereas if the computation is duplicated,
        # output gradients need to be scattered.
        if tensor_parallel_output_grad:
            return _reduce_scatter_along_first_dim(grad_output), None
        else:
            return _split_along_first_dim(grad_output), None


class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    """Reduce-scatter the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)
# -----------------
# Helper functions.
# -----------------

def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)


def scatter_to_sequence_parallel_region(input_):
    return _ScatterToSequenceParallelRegion.apply(input_)


def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
    return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)


def reduce_scatter_to_sequence_parallel_region(input_):
    return _ReduceScatterToSequenceParallelRegion.apply(input_)
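These thin wrappers exist so that each communication pattern gets the matching gradient treatment: a gather along the last dimension in the forward pass must hand every rank only its own slice of the incoming gradient in the backward pass. A single-process sketch of that pairing, emulating the gather with torch.cat instead of an actual all-gather:

import torch

world_size, hidden = 2, 4
# Stand-ins for each rank's activation shard.
shards = [torch.randn(3, hidden, requires_grad=True) for _ in range(world_size)]

gathered = torch.cat(shards, dim=-1)     # forward: like _gather_along_last_dim
gathered.sum().backward()

# Backward: each shard receives exactly the slice that _split_along_last_dim
# would give its rank, here a block of ones from the gradient of sum().
for rank, shard in enumerate(shards):
    assert torch.equal(shard.grad, torch.ones(3, hidden))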
vllm/model_executor/parallel_utils/tensor_parallel/random.py
deleted
100644 → 0
View file @
84e4e37d
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import contextlib

import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager

from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
)

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'


def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state
    This function is adapted from the PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for 4+ GPU cases.
    """
    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
        # older PyTorch
        def cb():
            with device_ctx_manager(device):
                _C._cuda_setRNGState(new_state)
    else:
        # newer PyTorch
        if device == -1:
            device = torch.device('cuda')
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device('cuda', device)

        def cb():
            idx = device.index
            if idx is None:
                idx = torch.cuda.current_device()
            default_generator = torch.cuda.default_generators[idx]
            default_generator.set_state(new_state)

    _lazy_call(cb)


class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for bookkeeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Get rng states. Copy the dictionary so we have direct
        pointers to the states, not just a pointer to the dictionary."""
        states = {}
        for name in self.states_:
            states[name] = self.states_[name]
        return states

    def set_states(self, states):
        """Set the rng states. For efficiency purposes, we do not check
        the size of seed for compatibility."""
        self.states_ = states

    def add(self, name, seed):
        """Track the rng state."""
        # Check seed is not already used.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        self.seeds_.add(seed)
        # Check that state is not already defined.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        # Get the current rng state.
        orig_rng_state = torch.cuda.get_rng_state()
        # Set the new state and store it.
        torch.cuda.manual_seed(seed)
        self.states_[name] = torch.cuda.get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state."""
        # Check if we have added the state
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store current rng state.
        orig_cuda_rng_state = torch.cuda.get_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the current rng state for later use.
            self.states_[name] = torch.cuda.get_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker():
    """Get cuda rng tracker."""
    return _CUDA_RNG_STATE_TRACKER


def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is
    initialized. Also, no torch.cuda.manual_seed should be called
    after this function. Basically, this is a replacement for that
    function.
    Two sets of RNG states are tracked:
        default state: This is for data parallelism and is the same among a
                       set of model parallel GPUs but different across
                       different model parallel groups. This is used, for
                       example, for dropout in the non-tensor-model-parallel regions.
        tensor-model-parallel state: This state is different among a set of model
                       parallel GPUs, but the same across data parallel
                       groups. This is used, for example, for dropout in
                       model parallel regions.
    """
    # 2718 is just for fun and any POSITIVE value will work.
    offset = seed + 2718
    tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
    # Data parallel gets the original seed.
    data_parallel_seed = seed

    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
                                tensor_model_parallel_seed)
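The tracker above boils down to a save / re-seed / restore dance around torch.cuda's generator state. A minimal sketch of that underlying pattern (it assumes a CUDA device is available and does not use the deleted class itself):

import torch

if torch.cuda.is_available():
    torch.cuda.manual_seed(0)            # the "data parallel" default stream
    saved = torch.cuda.get_rng_state()

    torch.cuda.manual_seed(1234)         # enter the forked, per-rank stream
    _ = torch.rand(3, device="cuda")     # draws that must not disturb the default
    torch.cuda.set_rng_state(saved)      # leave the fork

    # The default generator is exactly where it was before the forked draws.
    assert torch.equal(torch.cuda.get_rng_state(), saved)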
vllm/model_executor/parallel_utils/tensor_parallel/utils.py → vllm/model_executor/parallel_utils/utils.py
View file @
ba0bfd40
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import List, Sequence
import torch
from typing import List, Sequence


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)
        numerator, denominator)


def divide(numerator, denominator):
...
...
@@ -56,15 +57,14 @@ class VocabUtility:

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size: int, rank, world_size: int) -> Sequence[int]:
            per_partition_vocab_size: int, rank: int) -> Sequence[int]:
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
                                           world_size: int) -> Sequence[int]:
    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
                                           world_size: int) -> Sequence[int]:
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size)
            per_partition_vocab_size, rank)
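A worked example of the partitioning arithmetic after the signature change: world_size is only needed to turn the global vocab size into a per-partition size, after which the range depends on the rank alone. This is a standalone re-implementation for illustration, not an import from vLLM:

def vocab_range(global_vocab_size: int, rank: int, world_size: int):
    assert global_vocab_size % world_size == 0
    per_partition = global_vocab_size // world_size   # divide(global_vocab_size, world_size)
    index_f = rank * per_partition                    # vocab_range_from_per_partition_vocab_size(
    index_l = index_f + per_partition                 #     per_partition, rank)
    return index_f, index_l

# A 32000-token vocabulary sharded across 4 tensor-parallel ranks:
assert vocab_range(32000, rank=0, world_size=4) == (0, 8000)
assert vocab_range(32000, rank=1, world_size=4) == (8000, 16000)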