"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "257ffeb99747a2bb9caed54b1d51020aca95b0ae"
Commit d824c8be authored by rusty1s's avatar rusty1s
Browse files

segment documentation

parent 9b365d31
...@@ -60,7 +60,10 @@ ...@@ -60,7 +60,10 @@
<path style="stroke:none;" d=""/> <path style="stroke:none;" d=""/>
</symbol> </symbol>
<symbol overflow="visible" id="glyph1-1"> <symbol overflow="visible" id="glyph1-1">
<path style="stroke:none;" d="M 1.828125 -1.03125 L 3.078125 -1.03125 C 3.140625 -1.03125 3.234375 -1.03125 3.234375 -1.125 C 3.234375 -1.203125 3.140625 -1.203125 3.078125 -1.203125 L 1.828125 -1.203125 L 1.828125 -2.46875 C 1.828125 -2.53125 1.828125 -2.609375 1.75 -2.609375 C 1.65625 -2.609375 1.65625 -2.53125 1.65625 -2.46875 L 1.65625 -1.203125 L 0.40625 -1.203125 C 0.34375 -1.203125 0.25 -1.203125 0.25 -1.125 C 0.25 -1.03125 0.34375 -1.03125 0.40625 -1.03125 L 1.65625 -1.03125 L 1.65625 0.21875 C 1.65625 0.28125 1.65625 0.375 1.75 0.375 C 1.828125 0.375 1.828125 0.28125 1.828125 0.21875 Z M 1.828125 -1.03125 "/> <path style="stroke:none;" d="M 1.5 -0.34375 C 1.515625 -0.15625 1.625 0.03125 1.84375 0.03125 C 1.9375 0.03125 2.203125 -0.03125 2.203125 -0.40625 L 2.203125 -0.65625 L 2.09375 -0.65625 L 2.09375 -0.40625 C 2.09375 -0.140625 1.984375 -0.109375 1.9375 -0.109375 C 1.796875 -0.109375 1.765625 -0.3125 1.765625 -0.34375 L 1.765625 -1.234375 C 1.765625 -1.421875 1.765625 -1.59375 1.609375 -1.765625 C 1.4375 -1.9375 1.203125 -2.015625 1 -2.015625 C 0.625 -2.015625 0.3125 -1.796875 0.3125 -1.5 C 0.3125 -1.375 0.40625 -1.296875 0.53125 -1.296875 C 0.65625 -1.296875 0.734375 -1.375 0.734375 -1.5 C 0.734375 -1.546875 0.703125 -1.703125 0.5 -1.703125 C 0.625 -1.859375 0.84375 -1.90625 0.984375 -1.90625 C 1.203125 -1.90625 1.46875 -1.734375 1.46875 -1.34375 L 1.46875 -1.171875 C 1.234375 -1.15625 0.921875 -1.140625 0.640625 -1.015625 C 0.296875 -0.859375 0.1875 -0.625 0.1875 -0.421875 C 0.1875 -0.0625 0.625 0.046875 0.90625 0.046875 C 1.203125 0.046875 1.40625 -0.125 1.5 -0.34375 Z M 1.46875 -1.078125 L 1.46875 -0.625 C 1.46875 -0.203125 1.140625 -0.046875 0.9375 -0.046875 C 0.71875 -0.046875 0.53125 -0.203125 0.53125 -0.4375 C 0.53125 -0.671875 0.71875 -1.046875 1.46875 -1.078125 Z M 1.46875 -1.078125 "/>
</symbol>
<symbol overflow="visible" id="glyph1-2">
<path style="stroke:none;" d="M 1.703125 -0.25 L 1.703125 0.046875 L 2.359375 0 L 2.359375 -0.140625 C 2.046875 -0.140625 2.015625 -0.171875 2.015625 -0.390625 L 2.015625 -3.109375 L 1.375 -3.0625 L 1.375 -2.921875 C 1.6875 -2.921875 1.71875 -2.890625 1.71875 -2.671875 L 1.71875 -1.703125 C 1.59375 -1.859375 1.390625 -1.984375 1.15625 -1.984375 C 0.625 -1.984375 0.15625 -1.546875 0.15625 -0.96875 C 0.15625 -0.390625 0.59375 0.046875 1.109375 0.046875 C 1.390625 0.046875 1.59375 -0.109375 1.703125 -0.25 Z M 1.703125 -1.453125 L 1.703125 -0.53125 C 1.703125 -0.453125 1.703125 -0.4375 1.65625 -0.359375 C 1.515625 -0.140625 1.3125 -0.046875 1.125 -0.046875 C 0.921875 -0.046875 0.765625 -0.171875 0.65625 -0.34375 C 0.53125 -0.515625 0.53125 -0.78125 0.53125 -0.953125 C 0.53125 -1.125 0.53125 -1.390625 0.65625 -1.59375 C 0.765625 -1.734375 0.921875 -1.890625 1.171875 -1.890625 C 1.328125 -1.890625 1.515625 -1.8125 1.65625 -1.609375 C 1.703125 -1.53125 1.703125 -1.53125 1.703125 -1.453125 Z M 1.703125 -1.453125 "/>
</symbol> </symbol>
</g> </g>
<clipPath id="clip1"> <clipPath id="clip1">
...@@ -194,16 +197,15 @@ ...@@ -194,16 +197,15 @@
<path style="fill:none;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 99.212469 55.077688 L 99.212469 50.265188 " transform="matrix(1,0,0,-1,38.268,69.648)"/> <path style="fill:none;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 99.212469 55.077688 L 99.212469 50.265188 " transform="matrix(1,0,0,-1,38.268,69.648)"/>
<path style="fill:none;stroke-width:0.31879;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M -1.19608 1.592399 C -1.094518 0.994742 -0.0007675 0.100211 0.300014 -0.00135125 C -0.0007675 -0.0990075 -1.094518 -0.997445 -1.19608 -1.595101 " transform="matrix(0,1,1,0,137.48182,19.38358)"/> <path style="fill:none;stroke-width:0.31879;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M -1.19608 1.592399 C -1.094518 0.994742 -0.0007675 0.100211 0.300014 -0.00135125 C -0.0007675 -0.0990075 -1.094518 -0.997445 -1.19608 -1.595101 " transform="matrix(0,1,1,0,137.48182,19.38358)"/>
<g style="fill:rgb(0%,0%,0%);fill-opacity:1;"> <g style="fill:rgb(0%,0%,0%);fill-opacity:1;">
<use xlink:href="#glyph0-15" x="0" y="71.743"/> <use xlink:href="#glyph0-15" x="0" y="72.712"/>
<use xlink:href="#glyph0-7" x="4.9813" y="71.743"/> <use xlink:href="#glyph0-7" x="4.9813" y="72.712"/>
<use xlink:href="#glyph0-8" x="10.516521" y="71.743"/> <use xlink:href="#glyph0-8" x="10.516521" y="72.712"/>
<use xlink:href="#glyph0-6" x="14.390976" y="71.743"/>
<use xlink:href="#glyph0-7" x="19.926196" y="71.743"/>
<use xlink:href="#glyph0-8" x="25.461417" y="71.743"/>
</g> </g>
<path style="fill:none;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 34.32575 17.007375 C 34.32575 19.483938 31.649969 21.49175 28.345281 21.49175 C 25.0445 21.49175 22.368719 19.483938 22.368719 17.007375 C 22.368719 14.530813 25.0445 12.526906 28.345281 12.526906 C 31.649969 12.526906 34.32575 14.530813 34.32575 17.007375 Z M 34.32575 17.007375 " transform="matrix(1,0,0,-1,38.268,69.648)"/> <path style="fill:none;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 34.32575 17.007375 C 34.32575 19.483938 31.649969 21.49175 28.345281 21.49175 C 25.0445 21.49175 22.368719 19.483938 22.368719 17.007375 C 22.368719 14.530813 25.0445 12.526906 28.345281 12.526906 C 31.649969 12.526906 34.32575 14.530813 34.32575 17.007375 Z M 34.32575 17.007375 " transform="matrix(1,0,0,-1,38.268,69.648)"/>
<g style="fill:rgb(0%,0%,0%);fill-opacity:1;"> <g style="fill:rgb(0%,0%,0%);fill-opacity:1;">
<use xlink:href="#glyph1-1" x="64.871" y="53.761"/> <use xlink:href="#glyph1-1" x="63.003" y="54.197"/>
<use xlink:href="#glyph1-2" x="65.244585" y="54.197"/>
<use xlink:href="#glyph1-2" x="67.735434" y="54.197"/>
</g> </g>
<g clip-path="url(#clip4)" clip-rule="nonzero"> <g clip-path="url(#clip4)" clip-rule="nonzero">
<path style="fill-rule:nonzero;fill:rgb(0%,67.83905%,93.728638%);fill-opacity:0.5;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 21.259344 -7.086375 L 35.435125 -7.086375 L 35.435125 7.0855 L 21.259344 7.0855 Z M 21.259344 -7.086375 " transform="matrix(1,0,0,-1,38.268,69.648)"/> <path style="fill-rule:nonzero;fill:rgb(0%,67.83905%,93.728638%);fill-opacity:0.5;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 21.259344 -7.086375 L 35.435125 -7.086375 L 35.435125 7.0855 L 21.259344 7.0855 Z M 21.259344 -7.086375 " transform="matrix(1,0,0,-1,38.268,69.648)"/>
...@@ -215,7 +217,9 @@ ...@@ -215,7 +217,9 @@
<path style="fill:none;stroke-width:0.31879;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M -1.196749 1.592231 C -1.095186 0.994575 -0.00143625 0.100044 0.299345 -0.00151875 C -0.00143625 -0.099175 -1.095186 -0.997613 -1.196749 -1.595269 " transform="matrix(0,1,1,0,66.6148,61.90378)"/> <path style="fill:none;stroke-width:0.31879;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M -1.196749 1.592231 C -1.095186 0.994575 -0.00143625 0.100044 0.299345 -0.00151875 C -0.00143625 -0.099175 -1.095186 -0.997613 -1.196749 -1.595269 " transform="matrix(0,1,1,0,66.6148,61.90378)"/>
<path style="fill:none;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 48.497625 17.007375 C 48.497625 19.483938 45.821844 21.49175 42.521062 21.49175 C 39.220281 21.49175 36.540594 19.483938 36.540594 17.007375 C 36.540594 14.530813 39.220281 12.526906 42.521062 12.526906 C 45.821844 12.526906 48.497625 14.530813 48.497625 17.007375 Z M 48.497625 17.007375 " transform="matrix(1,0,0,-1,38.268,69.648)"/> <path style="fill:none;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 48.497625 17.007375 C 48.497625 19.483938 45.821844 21.49175 42.521062 21.49175 C 39.220281 21.49175 36.540594 19.483938 36.540594 17.007375 C 36.540594 14.530813 39.220281 12.526906 42.521062 12.526906 C 45.821844 12.526906 48.497625 14.530813 48.497625 17.007375 Z M 48.497625 17.007375 " transform="matrix(1,0,0,-1,38.268,69.648)"/>
<g style="fill:rgb(0%,0%,0%);fill-opacity:1;"> <g style="fill:rgb(0%,0%,0%);fill-opacity:1;">
<use xlink:href="#glyph1-1" x="79.044" y="53.761"/> <use xlink:href="#glyph1-1" x="77.176" y="54.197"/>
<use xlink:href="#glyph1-2" x="79.417585" y="54.197"/>
<use xlink:href="#glyph1-2" x="81.908434" y="54.197"/>
</g> </g>
<g clip-path="url(#clip5)" clip-rule="nonzero"> <g clip-path="url(#clip5)" clip-rule="nonzero">
<path style=" stroke:none;fill-rule:nonzero;fill:rgb(100%,50%,0%);fill-opacity:0.5;" d="M 73.703125 76.734375 L 87.875 76.734375 L 87.875 62.5625 L 73.703125 62.5625 Z M 73.703125 76.734375 "/> <path style=" stroke:none;fill-rule:nonzero;fill:rgb(100%,50%,0%);fill-opacity:0.5;" d="M 73.703125 76.734375 L 87.875 76.734375 L 87.875 62.5625 L 73.703125 62.5625 Z M 73.703125 76.734375 "/>
...@@ -230,7 +234,9 @@ ...@@ -230,7 +234,9 @@
<path style="fill:none;stroke-width:0.31879;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M -1.196749 1.594612 C -1.095186 0.996956 -0.00143625 0.0985187 0.299345 0.0008625 C -0.00143625 -0.1007 -1.095186 -0.995231 -1.196749 -1.592888 " transform="matrix(0,1,1,0,80.7882,61.90378)"/> <path style="fill:none;stroke-width:0.31879;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M -1.196749 1.594612 C -1.095186 0.996956 -0.00143625 0.0985187 0.299345 0.0008625 C -0.00143625 -0.1007 -1.095186 -0.995231 -1.196749 -1.592888 " transform="matrix(0,1,1,0,80.7882,61.90378)"/>
<path style="fill:none;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 62.6695 17.007375 C 62.6695 19.483938 59.993719 21.49175 56.692937 21.49175 C 53.392156 21.49175 50.716375 19.483938 50.716375 17.007375 C 50.716375 14.530813 53.392156 12.526906 56.692937 12.526906 C 59.993719 12.526906 62.6695 14.530813 62.6695 17.007375 Z M 62.6695 17.007375 " transform="matrix(1,0,0,-1,38.268,69.648)"/> <path style="fill:none;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 62.6695 17.007375 C 62.6695 19.483938 59.993719 21.49175 56.692937 21.49175 C 53.392156 21.49175 50.716375 19.483938 50.716375 17.007375 C 50.716375 14.530813 53.392156 12.526906 56.692937 12.526906 C 59.993719 12.526906 62.6695 14.530813 62.6695 17.007375 Z M 62.6695 17.007375 " transform="matrix(1,0,0,-1,38.268,69.648)"/>
<g style="fill:rgb(0%,0%,0%);fill-opacity:1;"> <g style="fill:rgb(0%,0%,0%);fill-opacity:1;">
<use xlink:href="#glyph1-1" x="93.217" y="53.761"/> <use xlink:href="#glyph1-1" x="91.349" y="54.197"/>
<use xlink:href="#glyph1-2" x="93.590585" y="54.197"/>
<use xlink:href="#glyph1-2" x="96.081434" y="54.197"/>
</g> </g>
<g clip-path="url(#clip7)" clip-rule="nonzero"> <g clip-path="url(#clip7)" clip-rule="nonzero">
<path style="fill-rule:nonzero;fill:rgb(55.488586%,52.549744%,0%);fill-opacity:0.5;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 49.607 -7.086375 L 63.778875 -7.086375 L 63.778875 7.0855 L 49.607 7.0855 Z M 49.607 -7.086375 " transform="matrix(1,0,0,-1,38.268,69.648)"/> <path style="fill-rule:nonzero;fill:rgb(55.488586%,52.549744%,0%);fill-opacity:0.5;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 49.607 -7.086375 L 63.778875 -7.086375 L 63.778875 7.0855 L 49.607 7.0855 Z M 49.607 -7.086375 " transform="matrix(1,0,0,-1,38.268,69.648)"/>
...@@ -242,7 +248,9 @@ ...@@ -242,7 +248,9 @@
<path style="fill:none;stroke-width:0.31879;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M -1.196749 1.593067 C -1.095186 0.995411 -0.00143625 0.10088 0.299345 -0.0006825 C -0.00143625 -0.0983388 -1.095186 -0.996776 -1.196749 -1.594433 " transform="matrix(0,1,1,0,94.96162,61.90378)"/> <path style="fill:none;stroke-width:0.31879;stroke-linecap:round;stroke-linejoin:round;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M -1.196749 1.593067 C -1.095186 0.995411 -0.00143625 0.10088 0.299345 -0.0006825 C -0.00143625 -0.0983388 -1.095186 -0.996776 -1.196749 -1.594433 " transform="matrix(0,1,1,0,94.96162,61.90378)"/>
<path style="fill:none;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 76.845281 17.007375 C 76.845281 19.483938 74.1695 21.49175 70.868719 21.49175 C 67.564031 21.49175 64.88825 19.483938 64.88825 17.007375 C 64.88825 14.530813 67.564031 12.526906 70.868719 12.526906 C 74.1695 12.526906 76.845281 14.530813 76.845281 17.007375 Z M 76.845281 17.007375 " transform="matrix(1,0,0,-1,38.268,69.648)"/> <path style="fill:none;stroke-width:0.3985;stroke-linecap:butt;stroke-linejoin:miter;stroke:rgb(0%,0%,0%);stroke-opacity:1;stroke-miterlimit:10;" d="M 76.845281 17.007375 C 76.845281 19.483938 74.1695 21.49175 70.868719 21.49175 C 67.564031 21.49175 64.88825 19.483938 64.88825 17.007375 C 64.88825 14.530813 67.564031 12.526906 70.868719 12.526906 C 74.1695 12.526906 76.845281 14.530813 76.845281 17.007375 Z M 76.845281 17.007375 " transform="matrix(1,0,0,-1,38.268,69.648)"/>
<g style="fill:rgb(0%,0%,0%);fill-opacity:1;"> <g style="fill:rgb(0%,0%,0%);fill-opacity:1;">
<use xlink:href="#glyph1-1" x="107.391" y="53.761"/> <use xlink:href="#glyph1-1" x="105.523" y="54.197"/>
<use xlink:href="#glyph1-2" x="107.764585" y="54.197"/>
<use xlink:href="#glyph1-2" x="110.255434" y="54.197"/>
</g> </g>
<g clip-path="url(#clip8)" clip-rule="nonzero"> <g clip-path="url(#clip8)" clip-rule="nonzero">
<path style=" stroke:none;fill-rule:nonzero;fill:rgb(92.549133%,0%,54.899597%);fill-opacity:0.5;" d="M 102.046875 76.734375 L 116.222656 76.734375 L 116.222656 62.5625 L 102.046875 62.5625 Z M 102.046875 76.734375 "/> <path style=" stroke:none;fill-rule:nonzero;fill:rgb(92.549133%,0%,54.899597%);fill-opacity:0.5;" d="M 102.046875 76.734375 L 116.222656 76.734375 L 116.222656 62.5625 L 102.046875 62.5625 Z M 102.046875 76.734375 "/>
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
\draw[edge] (index\i) -- (input\i); \draw[edge] (index\i) -- (input\i);
} }
\node[title] at (-0.8, 0.0) {output}; \node[title] at (-0.8, 0.0) {out};
\foreach \i in {0,...,\numberOutputs} { \foreach \i in {0,...,\numberOutputs} {
\pgfmathparse{\outputs[\i]}\let\out\pgfmathresult \pgfmathparse{\outputs[\i]}\let\out\pgfmathresult
\pgfmathparse{\colors[\i]}\let\co\pgfmathresult \pgfmathparse{\colors[\i]}\let\co\pgfmathresult
......
Segment COO
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: segment_coo
Segment CSR
===========
.. automodule:: torch_scatter
:noindex:
.. autofunction:: segment_csr
...@@ -3,8 +3,9 @@ ...@@ -3,8 +3,9 @@
PyTorch Scatter Documentation PyTorch Scatter Documentation
============================= =============================
This package consists of a small extension library of highly optimized sparse update (scatter) operations for the use in `PyTorch <http://pytorch.org/>`_, which are missing in the main package. This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for the use in `PyTorch <http://pytorch.org/>`_, which are missing in the main package.
Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor. Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements.
All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations. All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
......
...@@ -15,7 +15,7 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0): ...@@ -15,7 +15,7 @@ def scatter_add(src, index, dim=-1, out=None, dim_size=None, fill_value=0):
Sums all values from the :attr:`src` tensor into :attr:`out` at the indices Sums all values from the :attr:`src` tensor into :attr:`out` at the indices
specified in the :attr:`index` tensor along a given axis :attr:`dim`. For specified in the :attr:`index` tensor along a given axis :attr:`dim`. For
each value in :attr:`src`, its output index is specified by its index in each value in :attr:`src`, its output index is specified by its index in
:attr:`input` for dimensions outside of :attr:`dim` and by the :attr:`src` for dimensions outside of :attr:`dim` and by the
corresponding value in :attr:`index` for dimension :attr:`dim`. If corresponding value in :attr:`index` for dimension :attr:`dim`. If
multiple indices reference the same location, their **contributions add**. multiple indices reference the same location, their **contributions add**.
......
import torch import torch
def min_value(dtype): def min_value(dtype): # pragma: no cover
try: try:
return torch.finfo(dtype).min return torch.finfo(dtype).min
except TypeError: except TypeError:
return torch.iinfo(dtype).min return torch.iinfo(dtype).min
def max_value(dtype): def max_value(dtype): # pragma: no cover
try: try:
return torch.finfo(dtype).max return torch.finfo(dtype).max
except TypeError: except TypeError:
......
...@@ -112,9 +112,164 @@ class SegmentCSR(torch.autograd.Function): ...@@ -112,9 +112,164 @@ class SegmentCSR(torch.autograd.Function):
return grad_src, None, None, None return grad_src, None, None, None
def segment_coo(src, index, out=None, dim_size=None, reduce='add'): def segment_coo(src, index, out=None, dim_size=None, reduce="add"):
r"""
|
.. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
master/docs/source/_figures/segment_coo.svg?sanitize=true
:align: center
:width: 400px
|
Reduces all values from the :attr:`src` tensor into :attr:`out` at the
indices specified in the :attr:`index` tensor along the last dimension of
:attr:`index`.
For each value in :attr:`src`, its output index is specified by its index
in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the
corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`.
The applied reduction is defined via the :attr:`reduce` argument.
Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and
:math:`m`-dimensional tensors with
size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
:math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be an
:math:`n`-dimensional tensor with size
:math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`.
Moreover, the values of :attr:`index` must be between :math:`0` and
:math:`y - 1` in ascending order.
The :attr:`index` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="add"`, the operation
computes
.. math::
\mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j
where :math:`\sum_j` is over :math:`j` such that
:math:`\mathrm{index}_j = i`.
In contrast to :meth:`scatter`, this method expects values in :attr:`index`
**to be sorted** along dimension :obj:`index.dim() - 1`.
Due to the use of sorted indices, :meth:`segment_coo` is usually faster
than the more general :meth:`scatter` operation.
For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
second tensor representing the :obj:`argmin` and :obj:`argmax`,
respectively.
.. note::
This operation is implemented via atomic operations on the GPU and is
therefore **non-deterministic** since the order of parallel operations
to the same value is undetermined.
For floating-point variables, this results in a source of variance in
the result.
Args:
src (Tensor): The source tensor.
index (LongTensor): The sorted indices of elements to segment.
The number of dimensions of :attr:`index` needs to be less than or
equal to :attr:`src`.
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
dim_size (int, optional): If :attr:`out` is not given, automatically
create output with size :attr:`dim_size` at dimension
:obj:`index.dim() - 1`.
If :attr:`dim_size` is not given, a minimal sized output tensor
according to :obj:`index.max() + 1` is returned.
(default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"add"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"add"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
.. code-block:: python
from torch_scatter import segment_coo
src = torch.randn(10, 6, 64)
index = torch.tensor([0, 0, 1, 1, 1, 2])
index = index.view(1, -1) # Broadcasting in the first and last dim.
out = segment_coo(src, index, reduce="add")
print(out.size())
.. code-block::
torch.Size([10, 3, 64])
"""
return SegmentCOO.apply(src, index, out, dim_size, reduce) return SegmentCOO.apply(src, index, out, dim_size, reduce)
def segment_csr(src, indptr, out=None, reduce='add'): def segment_csr(src, indptr, out=None, reduce="add"):
r"""
Reduces all values from the :attr:`src` tensor into :attr:`out` within the
ranges specified in the :attr:`indptr` tensor along the last dimension of
:attr:`indptr`.
For each value in :attr:`src`, its output index is specified by its index
in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by the
corresponding range index in :attr:`indptr` for dimension
:obj:`indptr.dim() - 1`.
The applied reduction is defined via the :attr:`reduce` argument.
Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and
:math:`m`-dimensional tensors with
size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
:math:`(x_0, ..., x_{m-1}, y)`, respectively, then :attr:`out` must be an
:math:`n`-dimensional tensor with size
:math:`(x_0, ..., x_{m-1}, y - 1, x_{m+1}, ..., x_{n-1})`.
Moreover, the values of :attr:`indptr` must be between :math:`0` and
:math:`x_m` in ascending order.
The :attr:`indptr` tensor supports broadcasting in case its dimensions do
not match with :attr:`src`.
For one-dimensional tensors with :obj:`reduce="add"`, the operation
computes
.. math::
\mathrm{out}_i =
\sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+i]}~\mathrm{src}_j.
Due to the use of index pointers, :meth:`segment_csr` is the fastest
method to apply for grouped reductions.
For reductions :obj:`"min"` and :obj:`"max"`, this operation returns a
second tensor representing the :obj:`argmin` and :obj:`argmax`,
respectively.
.. note::
In contrast to :meth:`scatter()` and :meth:`segment_coo`, this
operation is **fully-deterministic**.
Args:
src (Tensor): The source tensor.
indptr (LongTensor): The index pointers between elements to segment.
The number of dimensions of :attr:`index` needs to be less than or
equal to :attr:`src`.
out (Tensor, optional): The destination tensor. (default: :obj:`None`)
reduce (string, optional): The reduce operation (:obj:`"add"`,
:obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
(default: :obj:`"add"`)
:rtype: :class:`Tensor`, :class:`LongTensor` *(optional)*
.. code-block:: python
from torch_scatter import segment_csr
src = torch.randn(10, 6, 64)
indptr = torch.tensor([0, 2, 5, 6])
indptr = indptr.view(1, -1) # Broadcasting in the first and last dim.
out = segment_csr(src, indptr, reduce="add")
print(out.size())
.. code-block::
torch.Size([10, 3, 64])
"""
return SegmentCSR.apply(src, indptr, out, reduce) return SegmentCSR.apply(src, indptr, out, reduce)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment