"src/vscode:/vscode.git/clone" did not exist on "360e6cb48d2cb175d8c515882a0e1a4e27d092b5"
Unverified Commit ac1536cf authored by Secbone's avatar Secbone Committed by GitHub
Browse files

[Feature] add udf support for `cross_reducer` (#2891)



* add: udf support for `cross_reducer`

* update: code lint

* update: reducer without stack

* docs: add docs for udf cross_reducer

* chore: fix code lint

* docs: update multi_update_all docstring
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>

* docs: update reduce_dict_data docstring
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
parent 99831073
......@@ -4735,8 +4735,11 @@ class DGLHeteroGraph(object):
An optional apply function to further update the node features
after the message reduction. It must be a :ref:`apiudf`.
cross_reducer : str
Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``.
cross_reducer : str or callable function
Cross type reducer. One of ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"``
or a callable function. If a callable function is provided, the input argument must be
a single list of tensors containing aggregation results from each edge type, and the
output of function must be a single tensor.
apply_node_func : callable, optional
An optional apply function after the messages are reduced both
type-wisely and across different types.
......@@ -5907,8 +5910,11 @@ def reduce_dict_data(frames, reducer, order=None):
----------
frames : list[dict[str, Tensor]]
Input tensor dictionaries
reducer : str
One of "sum", "max", "min", "mean", "stack"
reducer : str or callable function
One of "sum", "max", "min", "mean", "stack" or a callable function.
If a callable function is provided, the input arguments must be a single list
of tensors containing aggregation results from each edge type, and the
output of function must be a single tensor.
order : list[Int], optional
Merge order hint. Useful for "stack" reducer.
If provided, each integer indicates the relative order
......@@ -5925,7 +5931,9 @@ def reduce_dict_data(frames, reducer, order=None):
# Directly return the only one input. Stack reducer requires
# modifying tensor shape.
return frames[0]
if reducer == 'stack':
if callable(reducer):
merger = reducer
elif reducer == 'stack':
# Stack order does not matter. However, it must be consistent!
if order:
assert len(order) == len(frames)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment