Unverified Commit 535aa3d3 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Doc]modify the wrong code in initialize the hetero_graph negative sampler (#2382) (#2411)



* modify the wrong code in initialize the hetero_graph negative sampler

the etype in hetero graph should be name of edges
self.weights = {
            etype: g.in_degrees(etype=etype).float() ** 0.75
            for _, etype, _ in g.canonical_etypes
        }
the original code will give etype a Tuple format, it cannot apply to the next processing:
self.weights[etype].multinomial(len(src), replacement=True)
and I add the part which confusing me to generate eid_dict

* fix
Co-authored-by: default avatarQuan Gan <coin2028@hotmail.com>
Co-authored-by: default avatarShaow <coco11563@yeah.net>
parent 35f27c73
...@@ -240,7 +240,8 @@ source-destination array pairs. An example is given as follows: ...@@ -240,7 +240,8 @@ source-destination array pairs. An example is given as follows:
# caches the probability distribution # caches the probability distribution
self.weights = { self.weights = {
etype: g.in_degrees(etype=etype).float() ** 0.75 etype: g.in_degrees(etype=etype).float() ** 0.75
for etype in g.canonical_etypes} for _, etype, _ in g.canonical_etypes
}
self.k = k self.k = k
def __call__(self, g, eids_dict): def __call__(self, g, eids_dict):
...@@ -248,10 +249,18 @@ source-destination array pairs. An example is given as follows: ...@@ -248,10 +249,18 @@ source-destination array pairs. An example is given as follows:
for etype, eids in eids_dict.items(): for etype, eids in eids_dict.items():
src, _ = g.find_edges(eids, etype=etype) src, _ = g.find_edges(eids, etype=etype)
src = src.repeat_interleave(self.k) src = src.repeat_interleave(self.k)
dst = self.weights.multinomial(len(src), replacement=True) dst = self.weights[etype].multinomial(len(src), replacement=True)
result_dict[etype] = (src, dst) result_dict[etype] = (src, dst)
return result_dict return result_dict
Then you can give the dataloader a dictionary of edge types and edge IDs as well as the negative
sampler. For instance, the following iterates over all edges of the heterogeneous graph.
.. code:: python
train_eid_dict = {
g.edges(etype=etype, form='eid')
for etype in g.etypes}
dataloader = dgl.dataloading.EdgeDataLoader( dataloader = dgl.dataloading.EdgeDataLoader(
g, train_eid_dict, sampler, g, train_eid_dict, sampler,
negative_sampler=NegativeSampler(g, 5), negative_sampler=NegativeSampler(g, 5),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment