Unverified Commit 98f30b8c authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[Model] Fix Skywork R1V mlp (#26673)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 3cd36660
......@@ -691,7 +691,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
prefix=maybe_prefix(prefix, "language_model"),
)
self.mlp1 = self._init_mlp1(config)
self.mlp1 = self._init_mlp1(
config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
)
self.img_context_token_id = None
self.visual_token_mask = None
......@@ -738,7 +740,12 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else:
return InternVisionPatchModel(config.vision_config)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
def _init_mlp1(
self,
config: PretrainedConfig,
quant_config: QuantizationConfig,
prefix: str = "",
) -> nn.Module:
vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size
......@@ -748,9 +755,17 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
llm_hidden_size,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.1",
),
nn.GELU(),
ReplicatedLinear(llm_hidden_size, llm_hidden_size, return_bias=False),
ReplicatedLinear(
llm_hidden_size,
llm_hidden_size,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.3",
),
)
def pixel_shuffle(self, x, scale_factor=0.5):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment