Unverified commit 8363cd09, authored by Jan Kaniecki and committed by GitHub
Browse files

[Bugfix] Adjust mllama to regional compilation (#15112)


Signed-off-by: Jan Kaniecki <jkaniecki@habana.ai>
parent 6c5a3195
...@@ -1070,8 +1070,8 @@ class MllamaTextModel(nn.Module): ...@@ -1070,8 +1070,8 @@ class MllamaTextModel(nn.Module):
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds hidden_states = inputs_embeds
for decoder_layer in self.layers: for idx, decoder_layer in enumerate(self.layers):
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): if idx in self.cross_attention_layers:
if not skip_cross_attention: if not skip_cross_attention:
hidden_states = decoder_layer( hidden_states = decoder_layer(
hidden_states=hidden_states, hidden_states=hidden_states,
...@@ -1081,16 +1081,13 @@ class MllamaTextModel(nn.Module): ...@@ -1081,16 +1081,13 @@ class MllamaTextModel(nn.Module):
full_text_row_masked_out_mask= full_text_row_masked_out_mask=
full_text_row_masked_out_mask, full_text_row_masked_out_mask,
) )
elif isinstance(decoder_layer, LlamaDecoderLayer): else:
hidden_states, residual = decoder_layer( hidden_states, residual = decoder_layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=None, residual=None,
) )
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
else:
raise ValueError(
f"Unknown decoder layer type {type(decoder_layer)}")
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment