"vscode:/vscode.git/clone" did not exist on "cba5af22eb5ca96975b03556e41f4dac95a3db73"
Unverified commit 094fbdac authored by Lifu Huang, committed by GitHub

Fix incorrect LoRA weight loading for fused gate_up_proj (#6734)

parent 888cb175
@@ -680,8 +680,8 @@ register_conv_template(
 register_conv_template(
     Conversation(
         name="phi-4-mm",
-        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
-        system_template="<|system|>{system_message}<|end|>",
+        system_message="",
+        system_template="{system_message}",
         roles=("<|user|>", "<|assistant|>"),
         sep_style=SeparatorStyle.NO_COLON_SINGLE,
         sep="<|end|>",
...
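For context, a minimal sketch of how the updated template renders, assuming a simplified stand-in for SGLang's Conversation rendering under SeparatorStyle.NO_COLON_SINGLE (the real logic lives in SGLang's conversation module): with system_message set to "" and system_template reduced to "{system_message}", no system block is emitted at all, and each turn is the bare role tag, the text, and the <|end|> separator.

# Hypothetical, simplified renderer approximating NO_COLON_SINGLE behavior;
# not SGLang's actual Conversation implementation.
def render(system_template: str, system_message: str, turns) -> str:
    sep = "<|end|>"
    prompt = system_template.format(system_message=system_message)
    for role_tag, text in turns:
        prompt += f"{role_tag}{text}{sep}"
    return prompt

# With the new template, the empty system message contributes nothing:
print(render("{system_message}", "", [("<|user|>", "Hi"), ("<|assistant|>", "Hello!")]))
# -> <|user|>Hi<|end|><|assistant|>Hello!<|end|>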
@@ -209,4 +209,12 @@ class LoRAAdapter(nn.Module):
                 gate_up_name = weight_name
                 if "lora_A" in weight_name:
                     weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
-                # else: "lora_B" is already stacked, no operations is needed.
+                else:
+                    output_dim = weights[gate_up_name].shape[0] // 2
+                    weights[gate_up_name] = torch.stack(
+                        [
+                            weights[gate_up_name][:output_dim, :],
+                            weights[gate_up_name][output_dim:, :],
+                        ],
+                        dim=0,
+                    )
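A standalone sketch of why the new lora_B branch matters (shapes are illustrative assumptions, not taken from the repository): for a fused gate_up_proj, lora_A is shared between the gate and up projections and can simply be repeated, but a checkpoint's lora_B carries both halves concatenated along the output dimension, so it must be split and stacked into a (2, output_dim, r) layout rather than left as-is.

import torch

r, in_dim, out_dim = 8, 64, 128  # illustrative LoRA rank and projection sizes

lora_A = torch.randn(r, in_dim)        # shared by gate and up projections
lora_B = torch.randn(2 * out_dim, r)   # gate half and up half fused along dim 0

# lora_A: repeat once per fused projection, as the existing branch does.
stacked_A = lora_A.repeat(2, 1)        # (2 * r, in_dim)

# lora_B: split the fused output dimension and stack, mirroring the fix above.
half = lora_B.shape[0] // 2
stacked_B = torch.stack([lora_B[:half], lora_B[half:]], dim=0)  # (2, out_dim, r)

assert stacked_A.shape == (2 * r, in_dim)
assert stacked_B.shape == (2, out_dim, r)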
@@ -296,23 +296,30 @@ class Idefics2VisionTransformer(nn.Module):
     def compute_cu_seqlens(
         self,
         tgt_sizes: Optional[torch.Tensor] = None,
-        atch_attention_mask: Optional[torch.BoolTensor] = None,
+        input_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # shape: (batch_size,)
         if tgt_sizes is not None:
-            patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+            seqlen = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+        elif input_embeds is not None:
+            seqlen = torch.full(
+                size=(input_embeds.shape[0],),
+                fill_value=input_embeds.shape[1],
+                dtype=torch.int32,
+                device=input_embeds.device,
+            )
         else:
-            patch_len = atch_attention_mask[:, :, 0].sum(dim=1) * atch_attention_mask[
-                :, 0, :
-            ].sum(dim=1)
+            raise ValueError(
+                "Either `tgt_sizes` or `input_embeds` must be provided to compute cu_seqlens."
+            )
         cu_seqlens = torch.cat(
             [
-                torch.tensor([0], device=patch_len.device, dtype=torch.int32),
-                torch.cumsum(patch_len, dim=0, dtype=torch.int32),
+                torch.tensor([0], device=seqlen.device, dtype=torch.int32),
+                torch.cumsum(seqlen, dim=0, dtype=torch.int32),
             ],
             dim=0,
-        ).to(patch_len.device)
+        ).to(seqlen.device)
         return cu_seqlens

     def forward(
@@ -326,7 +333,7 @@ class Idefics2VisionTransformer(nn.Module):
             patch_attention_mask=patch_attention_mask,
             tgt_sizes=tgt_sizes,
         )
-        cu_seqlens = self.compute_cu_seqlens(tgt_sizes, patch_attention_mask)
+        cu_seqlens = self.compute_cu_seqlens(tgt_sizes, hidden_states)
         encoder_outputs = self.encoder(
             hidden_states,
             cu_seqlens=cu_seqlens,
...
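To see what the new input_embeds path produces, here is a self-contained sketch (the function name and tensor shapes are assumptions for illustration): when every sequence in the batch has the same padded length, the cumulative sequence-length vector is just 0 followed by multiples of that length.

import torch

def cu_seqlens_from_embeds(input_embeds: torch.Tensor) -> torch.Tensor:
    # input_embeds: (batch_size, seq_len, hidden); every row shares one length.
    seqlen = torch.full(
        size=(input_embeds.shape[0],),
        fill_value=input_embeds.shape[1],
        dtype=torch.int32,
        device=input_embeds.device,
    )
    return torch.cat(
        [
            torch.tensor([0], device=seqlen.device, dtype=torch.int32),
            torch.cumsum(seqlen, dim=0, dtype=torch.int32),
        ],
        dim=0,
    )

embeds = torch.zeros(3, 5, 16)         # 3 sequences of 5 patches each
print(cu_seqlens_from_embeds(embeds))  # tensor([ 0,  5, 10, 15], dtype=torch.int32)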
@@ -451,8 +451,8 @@ class Phi4MMForCausalLM(nn.Module):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

-    def should_apply_lora(self, module_name: str) -> Optional[str]:
-        return self.lora_pattern.match(module_name)
+    def should_apply_lora(self, module_name: str) -> bool:
+        return bool(self.lora_pattern.match(module_name))

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
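A small sketch of the should_apply_lora signature fix (the compiled pattern below is a hypothetical stand-in; the real lora_pattern is defined on the model class): re.Pattern.match returns Optional[re.Match], so wrapping it in bool() makes the returned value agree with the declared -> bool return type.

import re

# Hypothetical pattern for illustration only.
lora_pattern = re.compile(r"^model\.layers\.\d+\.(self_attn|mlp)\.")

def should_apply_lora(module_name: str) -> bool:
    # re.match returns Optional[re.Match]; bool() collapses it to True/False.
    return bool(lora_pattern.match(module_name))

print(should_apply_lora("model.layers.0.self_attn.qkv_proj"))  # True
print(should_apply_lora("model.embed_tokens"))                 # False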