"src/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "6b02babbadce55093b3de0f47a144c5574162f31"
Unverified Commit 5b97a1a2 authored by Krzysztof Sadowski's avatar Krzysztof Sadowski Committed by GitHub
Browse files

add bias argument (#3970)

parent ae7e3db6
...@@ -260,6 +260,8 @@ class HeteroLinear(nn.Module): ...@@ -260,6 +260,8 @@ class HeteroLinear(nn.Module):
Input feature size for heterogeneous inputs. A key can be a string or a tuple of strings. Input feature size for heterogeneous inputs. A key can be a string or a tuple of strings.
out_size : int out_size : int
Output feature size. Output feature size.
bias : bool, optional
If True, learns a bias term. Defaults: ``True``.
Examples Examples
-------- --------
...@@ -276,12 +278,12 @@ class HeteroLinear(nn.Module): ...@@ -276,12 +278,12 @@ class HeteroLinear(nn.Module):
>>> print(out_feats[('user', 'follows', 'user')].shape) >>> print(out_feats[('user', 'follows', 'user')].shape)
torch.Size([3, 3]) torch.Size([3, 3])
""" """
def __init__(self, in_size, out_size): def __init__(self, in_size, out_size, bias=True):
super(HeteroLinear, self).__init__() super(HeteroLinear, self).__init__()
self.linears = nn.ModuleDict() self.linears = nn.ModuleDict()
for typ, typ_in_size in in_size.items(): for typ, typ_in_size in in_size.items():
self.linears[str(typ)] = nn.Linear(typ_in_size, out_size) self.linears[str(typ)] = nn.Linear(typ_in_size, out_size, bias=bias)
def forward(self, feat): def forward(self, feat):
"""Forward function """Forward function
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment