Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1230263e
Unverified
Commit
1230263e
authored
Sep 11, 2024
by
Isotr0py
Committed by
GitHub
Sep 11, 2024
Browse files
[Bugfix] Fix InternVL2 vision embeddings process with pipeline parallel (#8299)
parent
e497b8ae
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
10 additions
and
3 deletions
+10
-3
tests/distributed/test_pipeline_parallel.py
tests/distributed/test_pipeline_parallel.py
+8
-2
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+2
-1
No files found.
tests/distributed/test_pipeline_parallel.py
View file @
1230263e
...
@@ -32,7 +32,9 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
...
@@ -32,7 +32,9 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
(
1
,
4
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
1
,
4
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
1
,
0
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
0
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
0
,
1
,
0
,
"meta-llama/Meta-Llama-3-8B"
,
"ray"
),
(
2
,
2
,
1
,
1
,
1
,
"internlm/internlm2_5-7b-chat"
,
"ray"
),
(
1
,
2
,
1
,
1
,
1
,
"OpenGVLab/InternVL2-1B"
,
"ray"
),
(
1
,
2
,
1
,
1
,
1
,
"OpenGVLab/InternVL2-2B"
,
"ray"
),
(
1
,
2
,
1
,
0
,
1
,
"OpenGVLab/InternVL2-4B"
,
"ray"
),
],
],
)
)
@
fork_new_process_for_each_test
@
fork_new_process_for_each_test
...
@@ -46,6 +48,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
...
@@ -46,6 +48,8 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
# use half precision for speed and memory savings in CI environment
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"--dtype"
,
"float16"
,
"float16"
,
"--max-model-len"
,
"8192"
,
"--pipeline-parallel-size"
,
"--pipeline-parallel-size"
,
str
(
PP_SIZE
),
str
(
PP_SIZE
),
"--tensor-parallel-size"
,
"--tensor-parallel-size"
,
...
@@ -62,7 +66,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
...
@@ -62,7 +66,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
tp_args
=
[
tp_args
=
[
# use half precision for speed and memory savings in CI environment
# use half precision for speed and memory savings in CI environment
"--dtype"
,
"--dtype"
,
"bfloat16"
,
"float16"
,
"--max-model-len"
,
"8192"
,
"--tensor-parallel-size"
,
"--tensor-parallel-size"
,
str
(
max
(
TP_SIZE
,
2
)),
# We only use 2 GPUs in the CI.
str
(
max
(
TP_SIZE
,
2
)),
# We only use 2 GPUs in the CI.
"--distributed-executor-backend"
,
"--distributed-executor-backend"
,
...
...
vllm/model_executor/models/internvl.py
View file @
1230263e
...
@@ -17,6 +17,7 @@ from transformers import PretrainedConfig
...
@@ -17,6 +17,7 @@ from transformers import PretrainedConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.distributed
import
get_pp_group
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
...
@@ -480,7 +481,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -480,7 +481,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
**
kwargs
:
object
,
**
kwargs
:
object
,
)
->
SamplerOutput
:
)
->
SamplerOutput
:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
if
image_input
is
not
None
and
get_pp_group
().
is_first_rank
:
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
input_ids
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment