Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b38bc652
Unverified
Commit
b38bc652
authored
Jul 25, 2025
by
Jason Gu
Committed by
GitHub
Jul 24, 2025
Browse files
[Model] Support tensor parallel for timm ViT in Deepseek_vl2 (#21494)
Signed-off-by:
wzqd
<
1057337859@qq.com
>
parent
adaf2c6d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
2 deletions
+38
-2
vllm/model_executor/models/deepseek_vl2.py
vllm/model_executor/models/deepseek_vl2.py
+38
-2
No files found.
vllm/model_executor/models/deepseek_vl2.py
View file @
b38bc652
...
@@ -14,9 +14,11 @@ from einops import rearrange, repeat
...
@@ -14,9 +14,11 @@ from einops import rearrange, repeat
from
transformers
import
BatchFeature
from
transformers
import
BatchFeature
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.model_loader.utils
import
set_default_torch_dtype
from
vllm.model_executor.models.transformers
import
replace_linear_class
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
,
NestedTensors
)
MultiModalKwargs
,
NestedTensors
)
...
@@ -379,6 +381,37 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -379,6 +381,37 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
make_empty_intermediate_tensors
)
def
_get_parent_and_attr
(
self
,
root
:
torch
.
nn
.
Module
,
dotted_name
:
str
):
"""Return (parent_module, final_attr_name) for a dotted module path."""
names
=
dotted_name
.
split
(
'.'
)
parent
=
root
for
n
in
names
[:
-
1
]:
parent
=
getattr
(
parent
,
n
)
return
parent
,
names
[
-
1
]
#patch for timm ViT instance to support tensor parallel
def
patch_vit_for_tp
(
self
,
vit
:
torch
.
nn
.
Module
,
quant_config
:
QuantizationConfig
):
try
:
import
timm
except
ImportError
as
e
:
raise
ImportError
(
"Please install timm"
)
from
e
for
name
,
module
in
vit
.
named_modules
():
if
isinstance
(
module
,
nn
.
Linear
):
parent
,
attr_name
=
self
.
_get_parent_and_attr
(
vit
,
name
)
if
isinstance
(
parent
,
timm
.
layers
.
Mlp
)
and
attr_name
==
"fc1"
:
new_linear
=
replace_linear_class
(
module
,
"colwise"
,
quant_config
)
setattr
(
parent
,
attr_name
,
new_linear
)
elif
isinstance
(
parent
,
timm
.
layers
.
Mlp
)
and
attr_name
==
"fc2"
:
new_linear
=
replace_linear_class
(
module
,
"rowwise"
,
quant_config
)
setattr
(
parent
,
attr_name
,
new_linear
)
return
vit
def
_init_vision_module
(
def
_init_vision_module
(
self
,
self
,
vision_config
:
VisionEncoderConfig
,
vision_config
:
VisionEncoderConfig
,
...
@@ -388,8 +421,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -388,8 +421,8 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# TODO: refactor vision model through timm wrapper from transformers
# TODO: refactor vision model through timm wrapper from transformers
try
:
try
:
import
timm
import
timm
except
ImportError
:
except
ImportError
as
e
:
raise
ImportError
(
"Please install timm"
)
from
ImportError
raise
ImportError
(
"Please install timm"
)
from
e
with
set_default_torch_dtype
(
torch
.
float16
):
with
set_default_torch_dtype
(
torch
.
float16
):
model
=
timm
.
create_model
(
model
=
timm
.
create_model
(
...
@@ -400,6 +433,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -400,6 +433,9 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
dynamic_img_pad
=
True
,
dynamic_img_pad
=
True
,
)
)
if
get_tensor_model_parallel_world_size
()
>
1
:
model
=
self
.
patch_vit_for_tp
(
model
,
quant_config
)
model
=
model
.
to
(
dtype
=
torch
.
get_default_dtype
())
model
=
model
.
to
(
dtype
=
torch
.
get_default_dtype
())
return
model
return
model
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment