Unverified Commit e0ac72b7 authored by Jaesun Park, committed by GitHub

Fix PerceiverMLP and test (#16405)


Co-authored-by: Jaesun Park <jaesun.park1@navercorp.com>
parent 473709fc
@@ -420,7 +420,7 @@ class PerceiverMLP(nn.Module):
             self.intermediate_act_fn = ACT2FN[config.hidden_act]
         else:
             self.intermediate_act_fn = config.hidden_act
-        self.dense2 = nn.Linear(input_size, input_size)
+        self.dense2 = nn.Linear(widening_factor * input_size, input_size)

     def forward(self, hidden_states):
         hidden_states = self.dense1(hidden_states)
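The one-line change above is the actual fix: `dense1` widens the hidden size from `input_size` to `widening_factor * input_size`, so `dense2` must project from the widened size back down to `input_size`. The previous `nn.Linear(input_size, input_size)` only happened to work when `widening_factor == 1`. A minimal sketch of the corrected shape flow (not the library code; `ToyPerceiverMLP` is an illustrative stand-in):

```python
# Illustrative stand-in for PerceiverMLP; names and arguments are assumptions,
# only the shape arithmetic mirrors the fix.
import torch
from torch import nn


class ToyPerceiverMLP(nn.Module):
    def __init__(self, input_size, widening_factor):
        super().__init__()
        # dense1 widens: input_size -> widening_factor * input_size
        self.dense1 = nn.Linear(input_size, widening_factor * input_size)
        self.act = nn.GELU()
        # Fixed: dense2 projects the widened size back down to input_size.
        # The pre-fix version used nn.Linear(input_size, input_size), which
        # breaks for any widening_factor != 1.
        self.dense2 = nn.Linear(widening_factor * input_size, input_size)

    def forward(self, hidden_states):
        hidden_states = self.dense1(hidden_states)
        hidden_states = self.act(hidden_states)
        return self.dense2(hidden_states)


x = torch.randn(2, 16, 32)                      # (batch, seq, input_size)
mlp = ToyPerceiverMLP(input_size=32, widening_factor=4)
assert mlp(x).shape == x.shape                  # output keeps the input size
```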
@@ -82,6 +82,8 @@ class PerceiverModelTester:
         num_self_attends_per_block=2,
         num_self_attention_heads=1,
         num_cross_attention_heads=1,
+        self_attention_widening_factor=4,
+        cross_attention_widening_factor=4,
         is_training=True,
         use_input_mask=True,
         use_labels=True,
@@ -109,6 +111,8 @@ class PerceiverModelTester:
         self.num_self_attends_per_block = num_self_attends_per_block
         self.num_self_attention_heads = num_self_attention_heads
         self.num_cross_attention_heads = num_cross_attention_heads
+        self.self_attention_widening_factor = self_attention_widening_factor
+        self.cross_attention_widening_factor = cross_attention_widening_factor
         self.is_training = is_training
         self.use_input_mask = use_input_mask
         self.use_labels = use_labels
@@ -174,10 +178,14 @@ class PerceiverModelTester:
         return PerceiverConfig(
             num_latents=self.num_latents,
             d_latents=self.d_latents,
+            qk_channels=self.d_latents,
+            v_channels=self.d_latents,
             num_blocks=self.num_blocks,
             num_self_attends_per_block=self.num_self_attends_per_block,
             num_self_attention_heads=self.num_self_attention_heads,
             num_cross_attention_heads=self.num_cross_attention_heads,
+            self_attention_widening_factor=self.self_attention_widening_factor,
+            cross_attention_widening_factor=self.cross_attention_widening_factor,
             vocab_size=self.vocab_size,
             hidden_act=self.hidden_act,
             attention_probs_dropout_prob=self.attention_probs_dropout_prob,
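The tester changes make the regression visible: the default widening factor of 1 masked the wrong `dense2` shape, so the test config now uses widening factors of 4 and pins `qk_channels`/`v_channels` to `d_latents`. A hedged sketch of the kind of forward pass this exercises, using illustrative sizes rather than the tester's exact values:

```python
# Sketch only: small, made-up dimensions to show that a widening factor > 1
# now runs end to end. With the pre-fix dense2 shape this forward pass would
# raise a matrix-size mismatch.
import torch
from transformers import PerceiverConfig, PerceiverModel

config = PerceiverConfig(
    d_model=32,
    num_latents=8,
    d_latents=16,
    qk_channels=16,
    v_channels=16,
    num_blocks=1,
    num_self_attends_per_block=2,
    num_self_attention_heads=1,
    num_cross_attention_heads=1,
    self_attention_widening_factor=4,   # non-default values surface the old bug
    cross_attention_widening_factor=4,
)
model = PerceiverModel(config).eval()

inputs = torch.randn(2, 10, config.d_model)      # (batch, seq_len, d_model)
with torch.no_grad():
    outputs = model(inputs=inputs)
print(outputs.last_hidden_state.shape)           # (2, num_latents, d_latents)
```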