Unverified Commit e0ac72b7 authored by Jaesun Park's avatar Jaesun Park Committed by GitHub
Browse files

Fix PerceiverMLP and test (#16405)


Co-authored-by: default avatarJaesun Park <jaesun.park1@navercorp.com>
parent 473709fc
......@@ -420,7 +420,7 @@ class PerceiverMLP(nn.Module):
self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
self.dense2 = nn.Linear(input_size, input_size)
self.dense2 = nn.Linear(widening_factor * input_size, input_size)
def forward(self, hidden_states):
hidden_states = self.dense1(hidden_states)
......
......@@ -82,6 +82,8 @@ class PerceiverModelTester:
num_self_attends_per_block=2,
num_self_attention_heads=1,
num_cross_attention_heads=1,
self_attention_widening_factor=4,
cross_attention_widening_factor=4,
is_training=True,
use_input_mask=True,
use_labels=True,
......@@ -109,6 +111,8 @@ class PerceiverModelTester:
self.num_self_attends_per_block = num_self_attends_per_block
self.num_self_attention_heads = num_self_attention_heads
self.num_cross_attention_heads = num_cross_attention_heads
self.self_attention_widening_factor = self_attention_widening_factor
self.cross_attention_widening_factor = cross_attention_widening_factor
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
......@@ -174,10 +178,14 @@ class PerceiverModelTester:
return PerceiverConfig(
num_latents=self.num_latents,
d_latents=self.d_latents,
qk_channels=self.d_latents,
v_channels=self.d_latents,
num_blocks=self.num_blocks,
num_self_attends_per_block=self.num_self_attends_per_block,
num_self_attention_heads=self.num_self_attention_heads,
num_cross_attention_heads=self.num_cross_attention_heads,
self_attention_widening_factor=self.self_attention_widening_factor,
cross_attention_widening_factor=self.cross_attention_widening_factor,
vocab_size=self.vocab_size,
hidden_act=self.hidden_act,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment