Unverified commit 98f30b8c, authored by Jee Jee Li, committed by GitHub
Browse files

[Model] Fix Skywork R1V mlp (#26673)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 3cd36660
...@@ -691,7 +691,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -691,7 +691,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
prefix=maybe_prefix(prefix, "language_model"), prefix=maybe_prefix(prefix, "language_model"),
) )
self.mlp1 = self._init_mlp1(config) self.mlp1 = self._init_mlp1(
config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
)
self.img_context_token_id = None self.img_context_token_id = None
self.visual_token_mask = None self.visual_token_mask = None
...@@ -738,7 +740,12 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -738,7 +740,12 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
else: else:
return InternVisionPatchModel(config.vision_config) return InternVisionPatchModel(config.vision_config)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module: def _init_mlp1(
self,
config: PretrainedConfig,
quant_config: QuantizationConfig,
prefix: str = "",
) -> nn.Module:
vit_hidden_size = config.vision_config.hidden_size vit_hidden_size = config.vision_config.hidden_size
llm_hidden_size = config.text_config.hidden_size llm_hidden_size = config.text_config.hidden_size
...@@ -748,9 +755,17 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP): ...@@ -748,9 +755,17 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
vit_hidden_size * int(1 / self.downsample_ratio) ** 2, vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
llm_hidden_size, llm_hidden_size,
return_bias=False, return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.1",
), ),
nn.GELU(), nn.GELU(),
ReplicatedLinear(llm_hidden_size, llm_hidden_size, return_bias=False), ReplicatedLinear(
llm_hidden_size,
llm_hidden_size,
return_bias=False,
quant_config=quant_config,
prefix=f"{prefix}.3",
),
) )
def pixel_shuffle(self, x, scale_factor=0.5): def pixel_shuffle(self, x, scale_factor=0.5):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment