Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8685ba1a
Unverified
Commit
8685ba1a
authored
Sep 05, 2024
by
manikandan.tm@zucisystems.com
Committed by
GitHub
Sep 05, 2024
Browse files
Inclusion of InternVLChatModel In PP_SUPPORTED_MODELS(Pipeline Parallelism) (#7860)
parent
288a9388
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
90 additions
and
35 deletions
+90
-35
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+22
-16
tests/utils.py
tests/utils.py
+6
-1
vllm/config.py
vllm/config.py
+5
-3
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+38
-14
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+3
-1
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+16
-0
No files found.
tests/distributed/test_pipeline_parallel.py
View file @
8685ba1a
...
@@ -18,23 +18,26 @@ logger = init_logger("test_pipeline_parallel")
...
@@ -18,23 +18,26 @@ logger = init_logger("test_pipeline_parallel")
VLLM_MULTI_NODE
=
os
.
getenv
(
"VLLM_MULTI_NODE"
,
"0"
)
==
"1"
VLLM_MULTI_NODE
=
os
.
getenv
(
"VLLM_MULTI_NODE"
,
"0"
)
==
"1"
@
pytest
.
mark
.
parametrize
((
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
@
pytest
.
mark
.
parametrize
(
(
"TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, "
"MODEL_NAME, DIST_BACKEND"
),
"MODEL_NAME, DIST_BACKEND"
),
[
[
(
2
,
2
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
2
,
2
,
0
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
2
,
2
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
2
,
2
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
3
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
3
,
0
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
0
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
4
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"mp"
),
(
1
,
3
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
3
,
0
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
0
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
0
,
1
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
0
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
])
(
2
,
2
,
1
,
1
,
1
,
"internlm/internlm2_5-7b-chat"
,
"ray"
),
],
)
@
fork_new_process_for_each_test
@
fork_new_process_for_each_test
def
test_compare_tp
(
TP_SIZE
,
PP_SIZE
,
EAGER_MODE
,
CHUNKED_PREFILL
,
MODEL_NAME
,
def
test_compare_tp
(
TP_SIZE
,
PP_SIZE
,
EAGER_MODE
,
CHUNKED_PREFILL
,
DIST_BACKEND
):
TRUST_REMOTE_CODE
,
MODEL_NAME
,
DIST_BACKEND
):
if
VLLM_MULTI_NODE
and
DIST_BACKEND
==
"mp"
:
if
VLLM_MULTI_NODE
and
DIST_BACKEND
==
"mp"
:
pytest
.
skip
(
"Skipping multi-node pipeline parallel test for "
pytest
.
skip
(
"Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend"
)
"multiprocessing distributed backend"
)
...
@@ -71,6 +74,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
...
@@ -71,6 +74,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
if
EAGER_MODE
:
if
EAGER_MODE
:
pp_args
.
append
(
"--enforce-eager"
)
pp_args
.
append
(
"--enforce-eager"
)
tp_args
.
append
(
"--enforce-eager"
)
tp_args
.
append
(
"--enforce-eager"
)
if
TRUST_REMOTE_CODE
:
pp_args
.
append
(
"--trust-remote-code"
)
tp_args
.
append
(
"--trust-remote-code"
)
pp_env
=
None
pp_env
=
None
if
(
DIST_BACKEND
==
"ray"
and
TP_SIZE
==
2
and
PP_SIZE
==
2
if
(
DIST_BACKEND
==
"ray"
and
TP_SIZE
==
2
and
PP_SIZE
==
2
and
CHUNKED_PREFILL
):
and
CHUNKED_PREFILL
):
...
...
tests/utils.py
View file @
8685ba1a
...
@@ -178,6 +178,11 @@ def compare_two_settings(model: str,
...
@@ -178,6 +178,11 @@ def compare_two_settings(model: str,
env2: The second set of environment variables to pass to the API server.
env2: The second set of environment variables to pass to the API server.
"""
"""
trust_remote_code
=
"--trust-remote-code"
if
trust_remote_code
in
arg1
or
trust_remote_code
in
arg2
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
,
trust_remote_code
=
True
)
else
:
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model
)
prompt
=
"Hello, my name is"
prompt
=
"Hello, my name is"
...
...
vllm/config.py
View file @
8685ba1a
...
@@ -35,18 +35,20 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
...
@@ -35,18 +35,20 @@ _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
=
4096
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS
=
4096
_PP_SUPPORTED_MODELS
=
[
_PP_SUPPORTED_MODELS
=
[
"AquilaModel"
,
"AquilaForCausalLM"
,
"AquilaForCausalLM"
,
"AquilaModel"
,
"DeepseekV2ForCausalLM"
,
"DeepseekV2ForCausalLM"
,
"GPT2LMHeadModel"
,
"InternLM2ForCausalLM"
,
"InternLMForCausalLM"
,
"InternLMForCausalLM"
,
"InternVLChatModel"
,
"JAISLMHeadModel"
,
"JAISLMHeadModel"
,
"LlamaForCausalLM"
,
"LlamaForCausalLM"
,
"LLaMAForCausalLM"
,
"LLaMAForCausalLM"
,
"MistralForCausalLM"
,
"MistralForCausalLM"
,
"Phi3ForCausalLM"
,
"GPT2LMHeadModel"
,
"MixtralForCausalLM"
,
"MixtralForCausalLM"
,
"NemotronForCausalLM"
,
"NemotronForCausalLM"
,
"Phi3ForCausalLM"
,
"Qwen2ForCausalLM"
,
"Qwen2ForCausalLM"
,
"Qwen2MoeForCausalLM"
,
"Qwen2MoeForCausalLM"
,
"QWenLMHeadModel"
,
"QWenLMHeadModel"
,
...
...
vllm/model_executor/models/internlm2.py
View file @
8685ba1a
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
from
functools
import
partial
from
functools
import
partial
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
...
@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
)
tensor_model_parallel_all_gather
)
...
@@ -28,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -28,6 +28,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
.utils
import
(
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
)
class
InternLM2MLP
(
nn
.
Module
):
class
InternLM2MLP
(
nn
.
Module
):
...
@@ -234,6 +237,7 @@ class InternLM2Model(nn.Module):
...
@@ -234,6 +237,7 @@ class InternLM2Model(nn.Module):
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -243,11 +247,15 @@ class InternLM2Model(nn.Module):
...
@@ -243,11 +247,15 @@ class InternLM2Model(nn.Module):
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
InternLMDecoderLayer
(
config
,
cache_config
,
quant_config
)
config
.
num_hidden_layers
,
for
_
in
range
(
config
.
num_hidden_layers
)
lambda
prefix
:
InternLMDecoderLayer
(
config
,
cache_config
,
])
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
(
[
"hidden_states"
,
"residual"
],
config
.
hidden_size
))
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
tok_embeddings
(
input_ids
)
return
self
.
tok_embeddings
(
input_ids
)
...
@@ -260,21 +268,31 @@ class InternLM2Model(nn.Module):
...
@@ -260,21 +268,31 @@ class InternLM2Model(nn.Module):
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
IntermediateTensors
=
None
,
intermediate_tensors
:
IntermediateTensors
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
else
:
else
:
hidden_states
=
self
.
tok_embeddings
(
input_ids
)
hidden_states
=
self
.
tok_embeddings
(
input_ids
)
residual
=
None
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
attn_metadata
,
residual
,
residual
,
)
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
...
@@ -298,6 +316,8 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -298,6 +316,8 @@ class InternLM2ForCausalLM(nn.Module):
self
.
output
.
weight
=
self
.
model
.
tok_embeddings
.
weight
self
.
output
.
weight
=
self
.
model
.
tok_embeddings
.
weight
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
self
.
sampler
=
Sampler
()
self
.
make_empty_intermediate_tensors
=
(
self
.
model
.
make_empty_intermediate_tensors
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -308,7 +328,7 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -308,7 +328,7 @@ class InternLM2ForCausalLM(nn.Module):
intermediate_tensors
:
IntermediateTensors
,
intermediate_tensors
:
IntermediateTensors
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
return
hidden_states
def
compute_logits
(
def
compute_logits
(
...
@@ -345,6 +365,8 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -345,6 +365,8 @@ class InternLM2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
@@ -353,6 +375,8 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -353,6 +375,8 @@ class InternLM2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/internvl.py
View file @
8685ba1a
...
@@ -341,6 +341,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -341,6 +341,8 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
nn
.
Linear
(
llm_hidden_size
,
llm_hidden_size
))
nn
.
Linear
(
llm_hidden_size
,
llm_hidden_size
))
self
.
img_context_token_id
=
None
self
.
img_context_token_id
=
None
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
def
pixel_shuffle
(
self
,
x
,
scale_factor
=
0.5
):
def
pixel_shuffle
(
self
,
x
,
scale_factor
=
0.5
):
n
,
w
,
h
,
c
=
x
.
size
()
n
,
w
,
h
,
c
=
x
.
size
()
...
@@ -461,7 +463,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -461,7 +463,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
positions
,
positions
,
kv_caches
,
kv_caches
,
attn_metadata
,
attn_metadata
,
None
,
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
)
inputs_embeds
=
inputs_embeds
)
return
hidden_states
return
hidden_states
...
...
vllm/model_executor/models/utils.py
View file @
8685ba1a
...
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.model_loader.loader
import
build_model
from
vllm.model_executor.model_loader.loader
import
build_model
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.multimodal.base
import
NestedTensors
from
vllm.multimodal.base
import
NestedTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
@@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
...
@@ -279,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
if
name
.
startswith
(
missing_layer_name
):
if
name
.
startswith
(
missing_layer_name
):
return
True
return
True
return
False
return
False
def
make_empty_intermediate_tensors_factory
(
keys
:
List
[
str
],
hidden_size
:
int
):
def
make_empty_intermediate_tensors
(
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
key
:
torch
.
zeros
((
batch_size
,
hidden_size
),
dtype
=
dtype
,
device
=
device
)
for
key
in
keys
})
return
make_empty_intermediate_tensors
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment