Unverified Commit b58416f7 authored by Zhiteng Li's avatar Zhiteng Li Committed by GitHub
Browse files

[NN] Refine the API of GraphormerLayer (#5565)


Co-authored-by: default avatarrudongyu <ru_dongyu@outlook.com>
parent 6247b006
...@@ -30,6 +30,8 @@ class GraphormerLayer(nn.Module): ...@@ -30,6 +30,8 @@ class GraphormerLayer(nn.Module):
afterwards. Default: False. afterwards. Default: False.
dropout : float, optional dropout : float, optional
Dropout probability. Default: 0.1. Dropout probability. Default: 0.1.
attn_dropout : float, optional
Attention dropout probability. Default: 0.1.
activation : callable activation layer, optional activation : callable activation layer, optional
Activation function. Default: nn.ReLU(). Activation function. Default: nn.ReLU().
...@@ -60,6 +62,7 @@ class GraphormerLayer(nn.Module): ...@@ -60,6 +62,7 @@ class GraphormerLayer(nn.Module):
attn_bias_type="add", attn_bias_type="add",
norm_first=False, norm_first=False,
dropout=0.1, dropout=0.1,
attn_dropout=0.1,
activation=nn.ReLU(), activation=nn.ReLU(),
): ):
super().__init__() super().__init__()
...@@ -70,7 +73,7 @@ class GraphormerLayer(nn.Module): ...@@ -70,7 +73,7 @@ class GraphormerLayer(nn.Module):
feat_size=feat_size, feat_size=feat_size,
num_heads=num_heads, num_heads=num_heads,
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
attn_drop=dropout, attn_drop=attn_dropout,
) )
self.ffn = nn.Sequential( self.ffn = nn.Sequential(
nn.Linear(feat_size, hidden_size), nn.Linear(feat_size, hidden_size),
......
...@@ -2322,6 +2322,7 @@ def test_GraphormerLayer(attn_bias_type, norm_first): ...@@ -2322,6 +2322,7 @@ def test_GraphormerLayer(attn_bias_type, norm_first):
attn_bias_type=attn_bias_type, attn_bias_type=attn_bias_type,
norm_first=norm_first, norm_first=norm_first,
dropout=0.1, dropout=0.1,
attn_dropout=0.1,
activation=th.nn.ReLU(), activation=th.nn.ReLU(),
) )
out = net(nfeat, attn_bias, attn_mask) out = net(nfeat, attn_bias, attn_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment