Unverified commit 9e7e5baa, authored by Yuxuan Zhang and committed by GitHub
Browse files

[Model] Add missing prefix to glm4_1v (#22716)


Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com>
parent d16aa3da
...@@ -453,25 +453,30 @@ class Glm4vPatchMerger(nn.Module): ...@@ -453,25 +453,30 @@ class Glm4vPatchMerger(nn.Module):
context_dim: int, context_dim: int,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
bias: bool = False, bias: bool = False,
prefix: str = "",
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = d_model self.hidden_size = d_model
self.proj = ColumnParallelLinear(self.hidden_size, self.proj = ColumnParallelLinear(self.hidden_size,
self.hidden_size, self.hidden_size,
bias=bias, bias=bias,
gather_output=True) gather_output=True,
quant_config=quant_config,
prefix=f"{prefix}.proj")
self.post_projection_norm = nn.LayerNorm(self.hidden_size) self.post_projection_norm = nn.LayerNorm(self.hidden_size)
self.gate_up_proj = MergedColumnParallelLinear( self.gate_up_proj = MergedColumnParallelLinear(
input_size=self.hidden_size, input_size=self.hidden_size,
output_sizes=[context_dim] * 2, output_sizes=[context_dim] * 2,
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
) )
self.down_proj = RowParallelLinear( self.down_proj = RowParallelLinear(
context_dim, context_dim,
self.hidden_size, self.hidden_size,
bias=bias, bias=bias,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.down_proj",
) )
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
self.extra_activation_func = nn.GELU() self.extra_activation_func = nn.GELU()
...@@ -661,6 +666,7 @@ class Glm4vVisionTransformer(nn.Module): ...@@ -661,6 +666,7 @@ class Glm4vVisionTransformer(nn.Module):
context_dim=vision_config.intermediate_size, context_dim=vision_config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
bias=False, bias=False,
prefix=f"{prefix}.merger",
) )
self.embeddings = Glm4vVisionEmbeddings(vision_config) self.embeddings = Glm4vVisionEmbeddings(vision_config)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment