Unverified commit b58416f7, authored by Zhiteng Li, committed by GitHub

[NN] Refine the API of GraphormerLayer (#5565)

Add a separate attn_dropout argument so that the dropout applied to the
attention weights can be configured independently of the feed-forward
dropout (previously both were driven by the single dropout argument).
Co-authored-by: rudongyu <ru_dongyu@outlook.com>
parent 6247b006
@@ -30,6 +30,8 @@ class GraphormerLayer(nn.Module):
         afterwards. Default: False.
     dropout : float, optional
         Dropout probability. Default: 0.1.
+    attn_dropout : float, optional
+        Attention dropout probability. Default: 0.1.
     activation : callable activation layer, optional
         Activation function. Default: nn.ReLU().
@@ -60,6 +62,7 @@ class GraphormerLayer(nn.Module):
         attn_bias_type="add",
         norm_first=False,
         dropout=0.1,
+        attn_dropout=0.1,
         activation=nn.ReLU(),
     ):
         super().__init__()
@@ -70,7 +73,7 @@ class GraphormerLayer(nn.Module):
             feat_size=feat_size,
             num_heads=num_heads,
             attn_bias_type=attn_bias_type,
-            attn_drop=dropout,
+            attn_drop=attn_dropout,
         )
         self.ffn = nn.Sequential(
             nn.Linear(feat_size, hidden_size),
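For orientation, a minimal usage sketch of the refined constructor. The `dgl.nn` import path, the tensor shapes, and the attention-bias layout are assumptions inferred from the test diff below; only the `attn_dropout` argument itself is guaranteed by this commit.

```python
import torch as th
from dgl.nn import GraphormerLayer  # import path assumed

batch_size, num_nodes, feat_size, num_heads = 2, 16, 32, 8

net = GraphormerLayer(
    feat_size=feat_size,
    hidden_size=64,
    num_heads=num_heads,
    dropout=0.1,       # dropout used elsewhere in the layer (e.g. the FFN)
    attn_dropout=0.1,  # new: dropout applied to the attention weights only
)

# Shapes assumed for illustration: node features batched as
# (batch, num_nodes, feat_size), one bias value per head and node pair.
nfeat = th.rand(batch_size, num_nodes, feat_size)
attn_bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
out = net(nfeat, attn_bias)
print(out.shape)  # expected: (batch_size, num_nodes, feat_size)
```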
@@ -2322,6 +2322,7 @@ def test_GraphormerLayer(attn_bias_type, norm_first):
         attn_bias_type=attn_bias_type,
         norm_first=norm_first,
         dropout=0.1,
+        attn_dropout=0.1,
         activation=th.nn.ReLU(),
     )
     out = net(nfeat, attn_bias, attn_mask)
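One migration note for downstream users (an editorial observation, not part of the commit): before this change, `attn_drop` was tied to `dropout`, so callers using a non-default `dropout` must now pass `attn_dropout` explicitly to keep the old behavior. A sketch, reusing the assumed import from above:

```python
from dgl.nn import GraphormerLayer  # import path assumed, as above

# Previously, GraphormerLayer(..., dropout=0.3) also applied p=0.3 to the
# attention weights (attn_drop=dropout). After this commit, the equivalent
# configuration sets both arguments explicitly:
net = GraphormerLayer(
    feat_size=32,
    hidden_size=64,
    num_heads=8,
    dropout=0.3,
    attn_dropout=0.3,  # no longer tied to dropout; defaults to 0.1
)
```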