Unverified Commit ac1536cf authored by Secbone's avatar Secbone Committed by GitHub
Browse files

[Feature] add udf support for `cross_reducer` (#2891)



* add: udf support for `cross_reducer`

* update: code lint

* update: reducer without stack

* docs: add docs for udf cross_reducer

* chore: fix code lint

* docs: update multi_update_all docstring
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>

* docs: update reduce_dict_data docstring
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 99831073
...@@ -4735,8 +4735,11 @@ class DGLHeteroGraph(object): ...@@ -4735,8 +4735,11 @@ class DGLHeteroGraph(object):
An optional apply function to further update the node features An optional apply function to further update the node features
after the message reduction. It must be a :ref:`apiudf`. after the message reduction. It must be a :ref:`apiudf`.
cross_reducer : str cross_reducer : str or callable function
Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``. Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``
or a callable function. If a callable function is provided, the input argument must be
a single list of tensors containing aggregation results from each edge type, and the
output of function must be a single tensor.
apply_node_func : callable, optional apply_node_func : callable, optional
An optional apply function after the messages are reduced both An optional apply function after the messages are reduced both
type-wisely and across different types. type-wisely and across different types.
...@@ -5907,8 +5910,11 @@ def reduce_dict_data(frames, reducer, order=None): ...@@ -5907,8 +5910,11 @@ def reduce_dict_data(frames, reducer, order=None):
---------- ----------
frames : list[dict[str, Tensor]] frames : list[dict[str, Tensor]]
Input tensor dictionaries Input tensor dictionaries
reducer : str reducer : str or callable function
One of "sum", "max", "min", "mean", "stack" One of "sum", "max", "min", "mean", "stack" or a callable function.
If a callable function is provided, the input arguments must be a single list
of tensors containing aggregation results from each edge type, and the
output of function must be a single tensor.
order : list[Int], optional order : list[Int], optional
Merge order hint. Useful for "stack" reducer. Merge order hint. Useful for "stack" reducer.
If provided, each integer indicates the relative order If provided, each integer indicates the relative order
...@@ -5925,7 +5931,9 @@ def reduce_dict_data(frames, reducer, order=None): ...@@ -5925,7 +5931,9 @@ def reduce_dict_data(frames, reducer, order=None):
# Directly return the only one input. Stack reducer requires # Directly return the only one input. Stack reducer requires
# modifying tensor shape. # modifying tensor shape.
return frames[0] return frames[0]
if reducer == 'stack': if callable(reducer):
merger = reducer
elif reducer == 'stack':
# Stack order does not matter. However, it must be consistent! # Stack order does not matter. However, it must be consistent!
if order: if order:
assert len(order) == len(frames) assert len(order) == len(frames)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment