Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
38de8223
Unverified
Commit
38de8223
authored
Mar 25, 2026
by
Terry Gao
Committed by
GitHub
Mar 25, 2026
Browse files
[Model] Add torch.compile support for InternVL vision encoder (#38049)
Signed-off-by:
tianrengao
<
terrygao87@gmail.com
>
parent
2bfbdca2
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
3 deletions
+20
-3
vllm/config/utils.py
vllm/config/utils.py
+9
-1
vllm/model_executor/models/intern_vit.py
vllm/model_executor/models/intern_vit.py
+11
-2
No files found.
vllm/config/utils.py
View file @
38de8223
...
...
@@ -296,7 +296,15 @@ def normalize_value(x):
# PretrainedConfig
if
hasattr
(
x
,
"to_json_string"
)
and
callable
(
x
.
to_json_string
):
try
:
return
x
.
to_json_string
()
except
(
TypeError
,
ValueError
):
# to_json_string() may fail for trust-remote-code configs
# with non-JSON-serializable nested objects. Fall back to
# normalizing the dict representation recursively.
if
hasattr
(
x
,
"to_dict"
)
and
callable
(
x
.
to_dict
):
return
normalize_value
(
x
.
to_dict
())
raise
# Unsupported type: e.g., modules, generators, open files, or objects
# without a stable JSON/UUID representation. Hard-error to avoid
...
...
vllm/model_executor/models/intern_vit.py
View file @
38de8223
...
...
@@ -15,6 +15,10 @@ import torch.nn as nn
import
torch.nn.functional
as
F
from
transformers
import
PretrainedConfig
from
vllm.compilation.decorators
import
(
should_torch_compile_mm_encoder
,
support_torch_compile
,
)
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
...
...
@@ -280,6 +284,11 @@ class InternMLP(nn.Module):
return
hidden_states
@
support_torch_compile
(
dynamic_arg_dims
=
{
"hidden_states"
:
0
},
enable_if
=
should_torch_compile_mm_encoder
,
is_encoder
=
True
,
)
class
InternVisionEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
...
...
@@ -364,8 +373,8 @@ class InternVisionEncoder(nn.Module):
self
.
layers
=
nn
.
ModuleList
(
[
self
.
layer_cls
(
config
,
quant_config
,
config
=
config
,
quant_config
=
quant_config
,
num_dummy_heads
=
num_dummy_heads
,
prefix
=
f
"
{
prefix
}
.layers.
{
layer_idx
}
"
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment