"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "6fed066b4cc834c0a43a2ce4ca65cd0512076c01"
Unverified Commit 757eafcf authored by Jared Wen's avatar Jared Wen Committed by GitHub
Browse files

[bug-fix] GLM OCR Patch Merger context_dim (#37962)


Signed-off-by: default avatarJaredforReal <w13431838023@gmail.com>
parent dcdc1458
...@@ -38,7 +38,10 @@ import torch.nn as nn ...@@ -38,7 +38,10 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange from einops import rearrange
from transformers import BatchFeature, Glm4vProcessor from transformers import BatchFeature, Glm4vProcessor
from transformers.models.glm4v.configuration_glm4v import Glm4vVisionConfig from transformers.models.glm4v.configuration_glm4v import (
Glm4vTextConfig,
Glm4vVisionConfig,
)
from transformers.models.glm4v.image_processing_glm4v import ( from transformers.models.glm4v.image_processing_glm4v import (
Glm4vImageProcessor, Glm4vImageProcessor,
smart_resize, smart_resize,
...@@ -604,6 +607,7 @@ class Glm4vVisionEmbeddings(nn.Module): ...@@ -604,6 +607,7 @@ class Glm4vVisionEmbeddings(nn.Module):
class Glm4vVisionTransformer(nn.Module): class Glm4vVisionTransformer(nn.Module):
def __init__( def __init__(
self, self,
text_config: Glm4vTextConfig,
vision_config: Glm4vVisionConfig, vision_config: Glm4vVisionConfig,
norm_eps: float = 1e-6, norm_eps: float = 1e-6,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
...@@ -1424,6 +1428,7 @@ class Glm4vForConditionalGeneration( ...@@ -1424,6 +1428,7 @@ class Glm4vForConditionalGeneration(
with self._mark_tower_model(vllm_config, {"image", "video"}): with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = Glm4vVisionTransformer( self.visual = Glm4vVisionTransformer(
config.text_config,
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5), norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config, quant_config=quant_config,
......
...@@ -35,7 +35,10 @@ import torch.nn as nn ...@@ -35,7 +35,10 @@ import torch.nn as nn
from einops import rearrange from einops import rearrange
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers.models.glm_ocr.configuration_glm_ocr import GlmOcrVisionConfig from transformers.models.glm_ocr.configuration_glm_ocr import (
GlmOcrTextConfig,
GlmOcrVisionConfig,
)
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state from vllm.distributed import get_tensor_model_parallel_world_size, parallel_state
...@@ -250,12 +253,13 @@ class GlmOcrPatchMerger(Glm4vPatchMerger): ...@@ -250,12 +253,13 @@ class GlmOcrPatchMerger(Glm4vPatchMerger):
class GlmOcrVisionTransformer(Glm4vVisionTransformer): class GlmOcrVisionTransformer(Glm4vVisionTransformer):
def __init__( def __init__(
self, self,
text_config: "GlmOcrTextConfig",
vision_config: "GlmOcrVisionConfig", vision_config: "GlmOcrVisionConfig",
norm_eps: float = 1e-5, norm_eps: float = 1e-5,
quant_config: QuantizationConfig | None = None, quant_config: QuantizationConfig | None = None,
prefix: str = "", prefix: str = "",
) -> None: ) -> None:
super().__init__(vision_config, norm_eps, quant_config, prefix) super().__init__(text_config, vision_config, norm_eps, quant_config, prefix)
del self.post_conv_layernorm del self.post_conv_layernorm
del self.embeddings del self.embeddings
...@@ -301,7 +305,7 @@ class GlmOcrVisionTransformer(Glm4vVisionTransformer): ...@@ -301,7 +305,7 @@ class GlmOcrVisionTransformer(Glm4vVisionTransformer):
) )
self.merger = GlmOcrPatchMerger( self.merger = GlmOcrPatchMerger(
d_model=vision_config.out_hidden_size, d_model=vision_config.out_hidden_size,
context_dim=vision_config.out_hidden_size * vision_config.in_channels, context_dim=text_config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
bias=False, bias=False,
prefix=f"{prefix}.merger", prefix=f"{prefix}.merger",
...@@ -383,6 +387,7 @@ class GlmOcrForConditionalGeneration(Glm4vForConditionalGeneration): ...@@ -383,6 +387,7 @@ class GlmOcrForConditionalGeneration(Glm4vForConditionalGeneration):
with self._mark_tower_model(vllm_config, {"image", "video"}): with self._mark_tower_model(vllm_config, {"image", "video"}):
self.visual = GlmOcrVisionTransformer( self.visual = GlmOcrVisionTransformer(
config.text_config,
config.vision_config, config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5), norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config, quant_config=quant_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment