Unverified commit f57ee565, authored by Jee Jee Li, committed by GitHub
Browse files

[Model] Modify MolmoForCausalLM MLP (#11510)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent dcb1a944
...@@ -464,24 +464,27 @@ class MolmoAttention(nn.Module): ...@@ -464,24 +464,27 @@ class MolmoAttention(nn.Module):
class MolmoMLP(nn.Module): class MolmoMLP(nn.Module):
"""Molmo's LLM mlp.""" """Molmo's LLM mlp."""
def __init__( def __init__(self,
self, config: PretrainedConfig,
config: PretrainedConfig, input_dim: Optional[int] = None,
input_dim: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None,
quant_config: Optional[QuantizationConfig] = None, proj_name: str = "gate_up_proj") -> None:
) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size // 2 self.intermediate_size = config.intermediate_size // 2
# Feed-forward input projection. # Molmo's LLM proj weights are already merged into the disk, while
self.gate_up_proj = MergedColumnParallelLinear( # image_projector proj is separate. If the same proj_name were used, it
input_dim or self.hidden_size, # would create ambiguity and make it difficult to support BNB and LoRA.
[self.intermediate_size] * 2, self.proj_name = proj_name
bias=False, setattr(
quant_config=quant_config, self, proj_name,
) MergedColumnParallelLinear(
input_dim or self.hidden_size,
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
))
# Activation function. # Activation function.
self.act_fn = SiluAndMul() self.act_fn = SiluAndMul()
...@@ -497,7 +500,7 @@ class MolmoMLP(nn.Module): ...@@ -497,7 +500,7 @@ class MolmoMLP(nn.Module):
self, self,
x: torch.Tensor, x: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x) gate_up, _ = getattr(self, self.proj_name)(x)
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x return x
...@@ -520,7 +523,9 @@ class MolmoDecoderLayer(nn.Module): ...@@ -520,7 +523,9 @@ class MolmoDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn") prefix=f"{prefix}.self_attn")
# MLP block. # MLP block.
self.mlp = MolmoMLP(config, quant_config=quant_config) self.mlp = MolmoMLP(config,
quant_config=quant_config,
proj_name="gate_up_proj")
# LayerNorm # LayerNorm
assert config.layer_norm_type == "rms" assert config.layer_norm_type == "rms"
...@@ -616,6 +621,7 @@ class MolmoVisionBackbone(nn.Module): ...@@ -616,6 +621,7 @@ class MolmoVisionBackbone(nn.Module):
config, config,
input_dim=vision_config.image_emb_dim, input_dim=vision_config.image_emb_dim,
quant_config=quant_config, quant_config=quant_config,
proj_name="merged_linear",
) )
image_dim = vision_config.image_emb_dim * len(self.vit_layers) image_dim = vision_config.image_emb_dim * len(self.vit_layers)
...@@ -714,8 +720,8 @@ class MolmoVisionBackbone(nn.Module): ...@@ -714,8 +720,8 @@ class MolmoVisionBackbone(nn.Module):
torch.Tensor]]) -> Set[str]: torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0), ("merged_linear", "gate_proj", 0),
("gate_up_proj", "up_proj", 1), ("merged_linear", "up_proj", 1),
] ]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set() loaded_params: Set[str] = set()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment