Unverified Commit 38de8223 authored by Terry Gao's avatar Terry Gao Committed by GitHub
Browse files

[Model] Add torch.compile support for InternVL vision encoder (#38049)


Signed-off-by: default avatartianrengao <terrygao87@gmail.com>
parent 2bfbdca2
...@@ -296,7 +296,15 @@ def normalize_value(x): ...@@ -296,7 +296,15 @@ def normalize_value(x):
# PretrainedConfig # PretrainedConfig
if hasattr(x, "to_json_string") and callable(x.to_json_string): if hasattr(x, "to_json_string") and callable(x.to_json_string):
return x.to_json_string() try:
return x.to_json_string()
except (TypeError, ValueError):
# to_json_string() may fail for trust-remote-code configs
# with non-JSON-serializable nested objects. Fall back to
# normalizing the dict representation recursively.
if hasattr(x, "to_dict") and callable(x.to_dict):
return normalize_value(x.to_dict())
raise
# Unsupported type: e.g., modules, generators, open files, or objects # Unsupported type: e.g., modules, generators, open files, or objects
# without a stable JSON/UUID representation. Hard-error to avoid # without a stable JSON/UUID representation. Hard-error to avoid
......
...@@ -15,6 +15,10 @@ import torch.nn as nn ...@@ -15,6 +15,10 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.compilation.decorators import (
should_torch_compile_mm_encoder,
support_torch_compile,
)
from vllm.distributed import ( from vllm.distributed import (
divide, divide,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
...@@ -280,6 +284,11 @@ class InternMLP(nn.Module): ...@@ -280,6 +284,11 @@ class InternMLP(nn.Module):
return hidden_states return hidden_states
@support_torch_compile(
dynamic_arg_dims={"hidden_states": 0},
enable_if=should_torch_compile_mm_encoder,
is_encoder=True,
)
class InternVisionEncoderLayer(nn.Module): class InternVisionEncoderLayer(nn.Module):
def __init__( def __init__(
self, self,
...@@ -364,8 +373,8 @@ class InternVisionEncoder(nn.Module): ...@@ -364,8 +373,8 @@ class InternVisionEncoder(nn.Module):
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[ [
self.layer_cls( self.layer_cls(
config, config=config,
quant_config, quant_config=quant_config,
num_dummy_heads=num_dummy_heads, num_dummy_heads=num_dummy_heads,
prefix=f"{prefix}.layers.{layer_idx}", prefix=f"{prefix}.layers.{layer_idx}",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment