Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
69672f11
Unverified
Commit
69672f11
authored
Jul 14, 2024
by
youkaichao
Committed by
GitHub
Jul 14, 2024
Browse files
[core][distributed] simplify code to support pipeline parallel (#6406)
parent
44874a0b
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
107 additions
and
61 deletions
+107
-61
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-3
tests/basic_correctness/test_basic_correctness.py
tests/basic_correctness/test_basic_correctness.py
+8
-3
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+20
-27
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+22
-28
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+56
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
69672f11
...
@@ -46,9 +46,7 @@ steps:
...
@@ -46,9 +46,7 @@ steps:
fast_check
:
true
fast_check
:
true
commands
:
commands
:
-
pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
-
pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
-
VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
-
pytest -v -s basic_correctness/test_basic_correctness.py
-
VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
-
VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py
-
VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
-
VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
-
VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
-
VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
-
VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
-
VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
...
...
tests/basic_correctness/test_basic_correctness.py
View file @
69672f11
...
@@ -28,10 +28,8 @@ def test_vllm_gc_ed():
...
@@ -28,10 +28,8 @@ def test_vllm_gc_ed():
assert
weak_llm
()
is
None
assert
weak_llm
()
is
None
@
pytest
.
mark
.
skipif
(
is_hip
()
and
os
.
getenv
(
"VLLM_ATTENTION_BACKEND"
)
==
"FLASHINFER"
,
reason
=
"Flashinfer does not support ROCm/HIP."
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"XFORMERS"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
,
True
])
...
@@ -40,10 +38,17 @@ def test_models(
...
@@ -40,10 +38,17 @@ def test_models(
vllm_runner
,
vllm_runner
,
example_prompts
,
example_prompts
,
model
:
str
,
model
:
str
,
backend
:
str
,
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
enforce_eager
:
bool
,
enforce_eager
:
bool
,
)
->
None
:
)
->
None
:
if
backend
==
"FLASHINFER"
and
is_hip
():
pytest
.
skip
(
"Flashinfer does not support ROCm/HIP."
)
os
.
environ
[
"VLLM_ATTENTION_BACKEND"
]
=
backend
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
...
...
vllm/model_executor/models/gpt2.py
View file @
69672f11
...
@@ -27,7 +27,6 @@ from vllm.attention import Attention, AttentionMetadata
...
@@ -27,7 +27,6 @@ from vllm.attention import Attention, AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
get_pp_group
,
get_tensor_model_parallel_world_size
)
get_pp_group
,
get_tensor_model_parallel_world_size
)
from
vllm.distributed.utils
import
get_pp_indices
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
...
@@ -42,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -42,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.utils
import
is_pp_missing_parameter
,
make_layers
class
GPT2Attention
(
nn
.
Module
):
class
GPT2Attention
(
nn
.
Module
):
...
@@ -183,18 +184,9 @@ class GPT2Model(nn.Module):
...
@@ -183,18 +184,9 @@ class GPT2Model(nn.Module):
self
.
embed_dim
=
config
.
hidden_size
self
.
embed_dim
=
config
.
hidden_size
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_dim
)
self
.
wte
=
VocabParallelEmbedding
(
config
.
vocab_size
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
wpe
=
nn
.
Embedding
(
config
.
max_position_embeddings
,
self
.
embed_dim
)
self
.
start_layer
,
self
.
end_layer
=
get_pp_indice
s
(
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layer
s
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
get_pp_group
().
rank_in_group
,
lambda
:
GPT2Block
(
config
,
cache_config
,
quant_config
))
get_pp_group
().
world_size
)
self
.
h
=
nn
.
ModuleList
(
[
nn
.
Identity
()
for
_
in
range
(
self
.
start_layer
)]
+
[
GPT2Block
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
self
.
start_layer
,
self
.
end_layer
)
]
+
[
nn
.
Identity
()
for
_
in
range
(
self
.
end_layer
,
config
.
num_hidden_layers
)
])
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
def
forward
(
...
@@ -291,19 +283,20 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -291,19 +283,20 @@ class GPT2LMHeadModel(nn.Module):
continue
continue
if
not
name
.
startswith
(
"transformer."
):
if
not
name
.
startswith
(
"transformer."
):
name
=
"transformer."
+
name
name
=
"transformer."
+
name
try
:
param
=
params_dict
[
name
]
if
is_pp_missing_parameter
(
name
,
self
):
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for
conv1d_weight_name
in
[
"c_attn"
,
"c_proj"
,
"c_fc"
]:
if
conv1d_weight_name
not
in
name
:
continue
if
not
name
.
endswith
(
".weight"
):
continue
loaded_weight
=
loaded_weight
.
t
()
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
except
KeyError
:
continue
continue
param
=
params_dict
[
name
]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for
conv1d_weight_name
in
[
"c_attn"
,
"c_proj"
,
"c_fc"
]:
if
conv1d_weight_name
not
in
name
:
continue
if
not
name
.
endswith
(
".weight"
):
continue
loaded_weight
=
loaded_weight
.
t
()
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
vllm/model_executor/models/llama.py
View file @
69672f11
...
@@ -29,8 +29,7 @@ from transformers import LlamaConfig
...
@@ -29,8 +29,7 @@ from transformers import LlamaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_pp_group
,
get_pp_indices
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
...
@@ -51,6 +50,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
...
@@ -51,6 +50,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from
vllm.utils
import
is_hip
,
print_warning_once
from
vllm.utils
import
is_hip
,
print_warning_once
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
class
LlamaMLP
(
nn
.
Module
):
class
LlamaMLP
(
nn
.
Module
):
...
@@ -262,20 +262,11 @@ class LlamaModel(nn.Module):
...
@@ -262,20 +262,11 @@ class LlamaModel(nn.Module):
config
.
hidden_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
org_num_embeddings
=
config
.
vocab_size
,
)
)
self
.
start_layer
,
self
.
end_layer
=
get_pp_indice
s
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layer
s
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
get_pp_group
().
rank_in_group
,
lambda
:
LlamaDecoderLayer
(
config
=
config
,
get_pp_group
().
world_size
)
cache_config
=
cache_config
,
self
.
layers
=
nn
.
ModuleList
(
quant_config
=
quant_config
))
[
nn
.
Identity
()
for
_
in
range
(
self
.
start_layer
)]
+
[
LlamaDecoderLayer
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
for
_
in
range
(
self
.
start_layer
,
self
.
end_layer
)
]
+
[
nn
.
Identity
()
for
_
in
range
(
self
.
end_layer
,
config
.
num_hidden_layers
)
])
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -455,12 +446,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -455,12 +446,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
try
:
param
=
params_dict
[
name
]
if
is_pp_missing_parameter
(
name
,
self
):
weight_loader
=
param
.
weight_loader
continue
weight_loader
(
param
,
loaded_weight
,
shard_id
)
except
KeyError
:
param
=
params_dict
[
name
]
pass
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
...
@@ -479,13 +472,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -479,13 +472,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
continue
continue
else
:
else
:
name
=
remapped_kv_scale_name
name
=
remapped_kv_scale_name
try
:
param
=
params_dict
[
name
]
if
is_pp_missing_parameter
(
name
,
self
):
weight_loader
=
getattr
(
param
,
"weight_loader"
,
continue
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
param
=
params_dict
[
name
]
except
KeyError
:
weight_loader
=
getattr
(
param
,
"weight_loader"
,
pass
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# If this function is called, it should always initialize KV cache scale
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# factors (or else raise an exception). Thus, handled exceptions should
...
...
vllm/model_executor/models/utils.py
View file @
69672f11
from
typing
import
Callable
,
Dict
,
List
,
Tuple
import
torch
import
torch
from
vllm.multimodal
import
BatchedTensors
from
vllm.multimodal
import
BatchedTensors
...
@@ -39,3 +41,57 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
...
@@ -39,3 +41,57 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds
[
mask
]
=
torch
.
cat
(
vision_embeddings
)
inputs_embeds
[
mask
]
=
torch
.
cat
(
vision_embeddings
)
return
inputs_embeds
return
inputs_embeds
class PPMissingLayer(torch.nn.Identity):
    """Stand-in module for a layer that this pipeline-parallel rank does not own.

    Being an ``Identity`` subclass, it hands its input back untouched and
    holds no parameters. The dedicated type lets other code detect the
    placeholder with ``isinstance``.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Constructor arguments are accepted only for call-site compatibility;
        # none of them are forwarded to the base class.
        super().__init__()
def make_layers(
    num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Build the layer stack for the current pipeline-parallel rank.

    Args:
        num_hidden_layers: Total number of layers in the full model.
        layer_fn: Zero-argument factory invoked once per layer owned by
            this rank.

    Returns:
        ``(start_layer, end_layer, modules)``. Indices in
        ``[start_layer, end_layer)`` hold layers produced by ``layer_fn``;
        every other index holds a ``PPMissingLayer`` placeholder.
    """
    # Function-scope imports, as in the original
    # (presumably to avoid an import cycle -- TODO confirm).
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_rank = get_pp_group().rank_in_group
    pp_world_size = get_pp_group().world_size
    start_layer, end_layer = get_pp_indices(num_hidden_layers, pp_rank,
                                            pp_world_size)

    # Assemble the three segments in order: leading placeholders, the layers
    # owned by this rank, trailing placeholders.
    stack = []
    stack.extend(PPMissingLayer() for _ in range(start_layer))
    stack.extend(layer_fn() for _ in range(start_layer, end_layer))
    stack.extend(PPMissingLayer()
                 for _ in range(end_layer, num_hidden_layers))
    modules = torch.nn.ModuleList(stack)

    return start_layer, end_layer, modules
# NOTE: don't use lru_cache here because it can prevent garbage collection
# Maps id(model) -> names of its PPMissingLayer submodules. Entries are
# removed when the model is garbage collected (see the finalizer below), so a
# recycled id can never return another model's list.
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """Get the names of the missing layers in a pipeline parallel model.

    Args:
        model: Module whose submodules are scanned for ``PPMissingLayer``
            placeholders.

    Returns:
        The ``named_modules()`` names of every ``PPMissingLayer`` found.
        The list is cached per model instance, and the same list object is
        returned on later calls.
    """
    import weakref

    model_id = id(model)
    cached = _model_to_pp_missing_layer_names.get(model_id)
    if cached is not None:
        return cached

    missing_layer_names = [
        name for name, module in model.named_modules()
        if isinstance(module, PPMissingLayer)
    ]
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names

    # The cache key is id(model), which CPython may hand to a different
    # object once this model is freed. Drop the entry when the model is
    # collected so a later model cannot inherit a stale result. The finalizer
    # only references the dict and the integer key, never the model itself,
    # so it does not keep the model alive.
    weakref.finalize(model, _model_to_pp_missing_layer_names.pop, model_id,
                     None)

    return missing_layer_names
def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    Args:
        name: Dotted parameter name, e.g. ``"model.layers.3.mlp.weight"``.
        model: Model whose placeholder layers are looked up via
            ``get_pp_missing_layer_names``.

    Returns:
        ``True`` if ``name`` refers to a missing layer or to something
        nested inside one, ``False`` otherwise.
    """
    for missing_layer_name in get_pp_missing_layer_names(model):
        # An empty name means the root module itself is a placeholder, in
        # which case every parameter is missing.
        if not missing_layer_name:
            return True
        # Match on a module boundary only. A bare prefix test would make a
        # missing "layers.1" also swallow "layers.10" ... "layers.19", which
        # may be present on this rank.
        if (name == missing_layer_name
                or name.startswith(missing_layer_name + ".")):
            return True
    return False
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment