Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
5628a6f6
"docs/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "d5b7eaaeee7b7cb1038376c6048b3466a8c6ffe6"
Commit
5628a6f6
authored
Dec 16, 2017
by
rusty1s
Browse files
added mean impl
parent
d951ab4d
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
130 additions
and
56 deletions
+130
-56
test/test_mean.py
test/test_mean.py
+35
-0
torch_scatter/functions/__init__.py
torch_scatter/functions/__init__.py
+24
-6
torch_scatter/functions/scatter.py
torch_scatter/functions/scatter.py
+12
-12
torch_scatter/src/THTensorDimApply.h
torch_scatter/src/THTensorDimApply.h
+2
-2
torch_scatter/src/cpu.c
torch_scatter/src/cpu.c
+2
-0
torch_scatter/src/cpu.h
torch_scatter/src/cpu.h
+36
-28
torch_scatter/src/generic/cpu.c
torch_scatter/src/generic/cpu.c
+19
-8
No files found.
test/test_mean.py
0 → 100644
View file @
5628a6f6
import
pytest
import
torch
from
torch.autograd
import
Variable
from
torch_scatter
import
scatter_mean_
,
scatter_mean
from
.utils
import
tensor_strs
,
Tensor
# @pytest.mark.parametrize('str', tensor_strs)
# def test_scatter_add(str):
def
test_scatter_mean
():
input
=
[[
2
,
0
,
1
,
4
,
3
],
[
0
,
2
,
1
,
3
,
4
]]
index
=
[[
4
,
5
,
4
,
2
,
3
],
[
0
,
0
,
2
,
2
,
1
]]
input
=
torch
.
FloatTensor
(
input
)
index
=
torch
.
LongTensor
(
index
)
output
=
input
.
new
(
2
,
6
).
fill_
(
0
)
# expected_output = [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
scatter_mean_
(
output
,
index
,
input
,
dim
=
1
)
print
(
output
)
# assert output.tolist() == expected_output
# output = scatter_add(index, input, dim=1)
# assert output.tolist(), expected_output
# output = Variable(output).fill_(0)
# index = Variable(index)
# input = Variable(input, requires_grad=True)
# scatter_add_(output, index, input, dim=1)
# grad_output = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5]]
# grad_output = Tensor(str, grad_output)
# output.backward(grad_output)
# assert index.data.tolist() == input.grad.data.tolist()
torch_scatter/functions/__init__.py
View file @
5628a6f6
...
@@ -3,7 +3,8 @@ from .utils import gen_output
...
@@ -3,7 +3,8 @@ from .utils import gen_output
def
scatter_add_
(
output
,
index
,
input
,
dim
=
0
):
def
scatter_add_
(
output
,
index
,
input
,
dim
=
0
):
return
scatter
(
'add'
,
output
,
index
,
input
,
dim
)
scatter
(
'add'
,
dim
,
output
,
index
,
input
)
return
output
def
scatter_add
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
def
scatter_add
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
...
@@ -12,7 +13,8 @@ def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
...
@@ -12,7 +13,8 @@ def scatter_add(index, input, dim=0, max_index=None, fill_value=0):
def
scatter_sub_
(
output
,
index
,
input
,
dim
=
0
):
def
scatter_sub_
(
output
,
index
,
input
,
dim
=
0
):
return
scatter
(
'sub'
,
output
,
index
,
input
,
dim
)
scatter
(
'sub'
,
dim
,
output
,
index
,
input
)
return
output
def
scatter_sub
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
def
scatter_sub
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
0
):
...
@@ -21,7 +23,8 @@ def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
...
@@ -21,7 +23,8 @@ def scatter_sub(index, input, dim=0, max_index=None, fill_value=0):
def
scatter_mul_
(
output
,
index
,
input
,
dim
=
0
):
def
scatter_mul_
(
output
,
index
,
input
,
dim
=
0
):
return
scatter
(
'mul'
,
output
,
index
,
input
,
dim
)
scatter
(
'mul'
,
dim
,
output
,
index
,
input
)
return
output
def
scatter_mul
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
1
):
def
scatter_mul
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
1
):
...
@@ -30,15 +33,30 @@ def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
...
@@ -30,15 +33,30 @@ def scatter_mul(index, input, dim=0, max_index=None, fill_value=1):
def
scatter_div_
(
output
,
index
,
input
,
dim
=
0
):
def
scatter_div_
(
output
,
index
,
input
,
dim
=
0
):
return
scatter
(
'div'
,
output
,
index
,
input
,
dim
)
scatter
(
'div'
,
dim
,
output
,
index
,
input
)
return
output
def
scatter_div
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
1
):
def
scatter_div
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
1
):
output
=
gen_output
(
index
,
input
,
dim
,
max_index
,
fill_value
)
output
=
gen_output
(
index
,
input
,
dim
,
max_index
,
fill_value
)
return
scatter_div_
(
output
,
index
,
input
,
dim
)
scatter_div_
(
output
,
index
,
input
,
dim
)
def
scatter_mean_
(
output
,
index
,
input
,
dim
=
0
):
output_count
=
output
.
new
(
output
.
size
()).
fill_
(
0
)
scatter
(
'mean'
,
dim
,
output
,
index
,
input
,
output_count
)
output
/=
output_count
output
[
output
!=
output
]
=
0
return
output
def
scatter_mean
(
index
,
input
,
dim
=
0
,
max_index
=
None
,
fill_value
=
1
):
output
=
gen_output
(
index
,
input
,
dim
,
max_index
,
fill_value
)
return
scatter_mean_
(
output
,
index
,
input
,
dim
)
__all__
=
[
__all__
=
[
'scatter_add_'
,
'scatter_add'
,
'scatter_sub_'
,
'scatter_sub'
,
'scatter_add_'
,
'scatter_add'
,
'scatter_sub_'
,
'scatter_sub'
,
'scatter_mul_'
,
'scatter_mul'
,
'scatter_div_'
,
'scatter_div'
'scatter_mul_'
,
'scatter_mul'
,
'scatter_div_'
,
'scatter_div'
,
'scatter_mean_'
,
'scatter_mean'
]
]
torch_scatter/functions/scatter.py
View file @
5628a6f6
...
@@ -4,11 +4,10 @@ from torch.autograd import Function
...
@@ -4,11 +4,10 @@ from torch.autograd import Function
from
.._ext
import
ffi
from
.._ext
import
ffi
def
_scatter
(
name
,
output
,
index
,
input
,
dim
):
def
_scatter
(
name
,
dim
,
*
data
):
typename
=
type
(
input
).
__name__
.
replace
(
'Tensor'
,
''
)
typename
=
type
(
data
[
0
]
).
__name__
.
replace
(
'Tensor'
,
''
)
func
=
getattr
(
ffi
,
'scatter_{}_{}'
.
format
(
name
,
typename
))
func
=
getattr
(
ffi
,
'scatter_{}_{}'
.
format
(
name
,
typename
))
func
(
output
,
index
,
input
,
dim
)
func
(
dim
,
*
data
)
return
output
class
_Scatter
(
Function
):
class
_Scatter
(
Function
):
...
@@ -17,13 +16,14 @@ class _Scatter(Function):
...
@@ -17,13 +16,14 @@ class _Scatter(Function):
self
.
dim
=
dim
self
.
dim
=
dim
self
.
name
=
name
self
.
name
=
name
def
forward
(
self
,
output
,
index
,
input
):
def
forward
(
self
,
*
data
):
assert
not
self
.
needs_input_grad
[
1
],
'Can
\'
t differentiate the index'
assert
not
self
.
needs_input_grad
[
1
],
'Can
\'
t differentiate the index'
self
.
mark_dirty
(
output
)
self
.
mark_dirty
(
data
[
0
]
)
self
.
save_for_backward
(
index
)
self
.
save_for_backward
(
data
[
1
]
)
return
_scatter
(
self
.
name
,
output
,
index
,
input
,
self
.
dim
)
_scatter
(
self
.
name
,
self
.
dim
,
*
data
)
return
data
[
0
]
def
backward
(
self
,
grad
):
def
backward
(
self
,
grad
):
index
,
=
self
.
saved_variables
index
,
=
self
.
saved_variables
...
@@ -37,8 +37,8 @@ class _Scatter(Function):
...
@@ -37,8 +37,8 @@ class _Scatter(Function):
return
grad_output
,
None
,
grad_input
return
grad_output
,
None
,
grad_input
def
scatter
(
name
,
output
,
index
,
input
,
dim
):
def
scatter
(
name
,
dim
,
*
data
):
if
torch
.
is_tensor
(
input
):
if
torch
.
is_tensor
(
data
[
0
]
):
return
_scatter
(
name
,
output
,
index
,
input
,
dim
)
return
_scatter
(
name
,
dim
,
*
data
)
else
:
else
:
return
_Scatter
(
name
,
dim
)(
output
,
index
,
input
)
return
_Scatter
(
name
,
dim
)(
*
data
)
torch_scatter/src/THTensorDimApply.h
View file @
5628a6f6
...
@@ -44,10 +44,10 @@
...
@@ -44,10 +44,10 @@
THDescBuff T3buff = _THSizeDesc(TENSOR3->size, TENSOR3->nDimension); \
THDescBuff T3buff = _THSizeDesc(TENSOR3->size, TENSOR3->nDimension); \
THDescBuff T4buff = _THSizeDesc(TENSOR4->size, TENSOR3->nDimension); \
THDescBuff T4buff = _THSizeDesc(TENSOR4->size, TENSOR3->nDimension); \
THError("inconsistent tensor size, expected %s %s, %s %s, %s %s and %s %s to have the same " \
THError("inconsistent tensor size, expected %s %s, %s %s, %s %s and %s %s to have the same " \
"number of dimensions", #TENSOR1, T1buff.str, #TENSOR2, T2buff.str, #TENSOR3, T3buff.str, #TENSOR4, T4
.
buff.str); \
"number of dimensions", #TENSOR1, T1buff.str, #TENSOR2, T2buff.str, #TENSOR3, T3buff.str, #TENSOR4, T4buff.str); \
} \
} \
\
\
SIZE_CHECK(TENSOR1, TENSOR2, TENSOR3, DIMENSION) \
SIZE_CHECK(TENSOR1, TENSOR2, TENSOR3,
TENSOR4,
DIMENSION) \
\
\
TH_TENSOR_DIM_APPLY_counter = (int64_t*)THAlloc(sizeof(int64_t)*(TENSOR1->nDimension)); \
TH_TENSOR_DIM_APPLY_counter = (int64_t*)THAlloc(sizeof(int64_t)*(TENSOR1->nDimension)); \
for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \
for (TH_TENSOR_DIM_APPLY_i = 0; TH_TENSOR_DIM_APPLY_i < TENSOR1->nDimension; TH_TENSOR_DIM_APPLY_i++) \
...
...
torch_scatter/src/cpu.c
View file @
5628a6f6
#include <TH/TH.h>
#include <TH/TH.h>
#include "THTensorDimApply.h"
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real)
#define scatter_(NAME) TH_CONCAT_4(scatter_, NAME, _, Real)
inline
void
assertIndexInBoundaries
(
int
idx
,
int
size
,
int64_t
*
free
)
{
inline
void
assertIndexInBoundaries
(
int
idx
,
int
size
,
int64_t
*
free
)
{
...
...
torch_scatter/src/cpu.h
View file @
5628a6f6
void
scatter_add_Float
(
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
,
int
dim
);
void
scatter_add_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
);
void
scatter_add_Double
(
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
,
int
dim
);
void
scatter_add_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
);
void
scatter_add_Byte
(
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
,
int
dim
);
void
scatter_add_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
);
void
scatter_add_Char
(
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
,
int
dim
);
void
scatter_add_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
);
void
scatter_add_Short
(
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
int
dim
);
void
scatter_add_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
);
void
scatter_add_Int
(
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
int
dim
);
void
scatter_add_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
);
void
scatter_add_Long
(
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
int
dim
);
void
scatter_add_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
);
void
scatter_sub_Float
(
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
,
int
dim
);
void
scatter_sub_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
);
void
scatter_sub_Double
(
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
,
int
dim
);
void
scatter_sub_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
);
void
scatter_sub_Byte
(
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
,
int
dim
);
void
scatter_sub_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
);
void
scatter_sub_Char
(
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
,
int
dim
);
void
scatter_sub_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
);
void
scatter_sub_Short
(
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
int
dim
);
void
scatter_sub_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
);
void
scatter_sub_Int
(
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
int
dim
);
void
scatter_sub_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
);
void
scatter_sub_Long
(
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
int
dim
);
void
scatter_sub_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
);
void
scatter_mul_Float
(
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
,
int
dim
);
void
scatter_mul_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
);
void
scatter_mul_Double
(
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
,
int
dim
);
void
scatter_mul_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
);
void
scatter_mul_Byte
(
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
,
int
dim
);
void
scatter_mul_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
);
void
scatter_mul_Char
(
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
,
int
dim
);
void
scatter_mul_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
);
void
scatter_mul_Short
(
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
int
dim
);
void
scatter_mul_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
);
void
scatter_mul_Int
(
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
int
dim
);
void
scatter_mul_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
);
void
scatter_mul_Long
(
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
int
dim
);
void
scatter_mul_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
);
void
scatter_div_Float
(
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
,
int
dim
);
void
scatter_div_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
);
void
scatter_div_Double
(
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
,
int
dim
);
void
scatter_div_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
);
void
scatter_div_Byte
(
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
,
int
dim
);
void
scatter_div_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
);
void
scatter_div_Char
(
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
,
int
dim
);
void
scatter_div_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
);
void
scatter_div_Short
(
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
int
dim
);
void
scatter_div_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
);
void
scatter_div_Int
(
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
int
dim
);
void
scatter_div_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
);
void
scatter_div_Long
(
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
int
dim
);
void
scatter_div_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
);
void
scatter_mean_Float
(
int
dim
,
THFloatTensor
*
output
,
THLongTensor
*
index
,
THFloatTensor
*
input
,
THFloatTensor
*
output_count
);
void
scatter_mean_Double
(
int
dim
,
THDoubleTensor
*
output
,
THLongTensor
*
index
,
THDoubleTensor
*
input
,
THDoubleTensor
*
output_count
);
void
scatter_mean_Byte
(
int
dim
,
THByteTensor
*
output
,
THLongTensor
*
index
,
THByteTensor
*
input
,
THByteTensor
*
output_count
);
void
scatter_mean_Char
(
int
dim
,
THCharTensor
*
output
,
THLongTensor
*
index
,
THCharTensor
*
input
,
THCharTensor
*
output_count
);
void
scatter_mean_Short
(
int
dim
,
THShortTensor
*
output
,
THLongTensor
*
index
,
THShortTensor
*
input
,
THShortTensor
*
output_count
);
void
scatter_mean_Int
(
int
dim
,
THIntTensor
*
output
,
THLongTensor
*
index
,
THIntTensor
*
input
,
THIntTensor
*
output_count
);
void
scatter_mean_Long
(
int
dim
,
THLongTensor
*
output
,
THLongTensor
*
index
,
THLongTensor
*
input
,
THLongTensor
*
output_count
);
torch_scatter/src/generic/cpu.c
View file @
5628a6f6
...
@@ -2,9 +2,9 @@
...
@@ -2,9 +2,9 @@
#define TH_GENERIC_FILE "generic/cpu.c"
#define TH_GENERIC_FILE "generic/cpu.c"
#else
#else
void
scatter_
(
add
)(
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
int
dim
)
{
void
scatter_
(
add
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
int64_t
idx
;
int64_t
idx
;
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
real
,
input
,
int64_t
,
index
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
...
@@ -12,9 +12,9 @@ void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int d
...
@@ -12,9 +12,9 @@ void scatter_(add)(THTensor *output, THLongTensor *index, THTensor *input, int d
})
})
}
}
void
scatter_
(
sub
)(
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
int
dim
)
{
void
scatter_
(
sub
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
int64_t
idx
;
int64_t
idx
;
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
real
,
input
,
int64_t
,
index
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
...
@@ -22,9 +22,9 @@ void scatter_(sub)(THTensor *output, THLongTensor *index, THTensor *input, int d
...
@@ -22,9 +22,9 @@ void scatter_(sub)(THTensor *output, THLongTensor *index, THTensor *input, int d
})
})
}
}
void
scatter_
(
mul
)(
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
int
dim
)
{
void
scatter_
(
mul
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
int64_t
idx
;
int64_t
idx
;
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
real
,
input
,
int64_t
,
index
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
...
@@ -32,9 +32,9 @@ void scatter_(mul)(THTensor *output, THLongTensor *index, THTensor *input, int d
...
@@ -32,9 +32,9 @@ void scatter_(mul)(THTensor *output, THLongTensor *index, THTensor *input, int d
})
})
}
}
void
scatter_
(
div
)(
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
int
dim
)
{
void
scatter_
(
div
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
)
{
int64_t
idx
;
int64_t
idx
;
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
real
,
input
,
int64_t
,
index
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
TH_TENSOR_DIM_APPLY3
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
dim
,
TH_TENSOR_DIM_APPLY3_SIZE_EQ_EXCEPT_DIM
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
...
@@ -42,4 +42,15 @@ void scatter_(div)(THTensor *output, THLongTensor *index, THTensor *input, int d
...
@@ -42,4 +42,15 @@ void scatter_(div)(THTensor *output, THLongTensor *index, THTensor *input, int d
})
})
}
}
void
scatter_
(
mean
)(
int
dim
,
THTensor
*
output
,
THLongTensor
*
index
,
THTensor
*
input
,
THTensor
*
output_count
)
{
int64_t
idx
;
TH_TENSOR_DIM_APPLY4
(
real
,
output
,
int64_t
,
index
,
real
,
input
,
real
,
output_count
,
dim
,
TH_TENSOR_DIM_APPLY4_SIZE_EQ_EXCEPT_DIM
,
for
(
int64_t
i
=
0
;
i
<
THLongTensor_size
(
index
,
dim
);
i
++
)
{
idx
=
*
(
index_data
+
i
*
index_stride
);
assertIndexInBoundaries
(
idx
,
output_size
,
TH_TENSOR_DIM_APPLY_counter
);
output_data
[
idx
]
+=
*
(
input_data
+
i
*
input_stride
);
output_count_data
[
idx
]
++
;
})
}
#endif
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment