Unverified Commit 142b69f2 authored by Kevin Ko's avatar Kevin Ko Committed by GitHub
Browse files

Add layer_idx to CrossAttention of GPT2 model (#15730)

* Add layer_idx to CrossAttention

* Add layer_idx to crossattention of ImageGPT model
parent 86119c11
...@@ -374,7 +374,7 @@ class GPT2Block(nn.Module): ...@@ -374,7 +374,7 @@ class GPT2Block(nn.Module):
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if config.add_cross_attention: if config.add_cross_attention:
self.crossattention = GPT2Attention(config, is_cross_attention=True) self.crossattention = GPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config) self.mlp = GPT2MLP(inner_dim, config)
......
...@@ -423,7 +423,7 @@ class ImageGPTBlock(nn.Module): ...@@ -423,7 +423,7 @@ class ImageGPTBlock(nn.Module):
self.ln_2 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_2 = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if config.add_cross_attention: if config.add_cross_attention:
self.crossattention = ImageGPTAttention(config, is_cross_attention=True) self.crossattention = ImageGPTAttention(config, is_cross_attention=True, layer_idx=layer_idx)
self.ln_cross_attn = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.ln_cross_attn = ImageGPTLayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = ImageGPTMLP(inner_dim, config) self.mlp = ImageGPTMLP(inner_dim, config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment