Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
0fa23d82
Commit
0fa23d82
authored
Apr 23, 2020
by
rusty1s
Browse files
fixed true divide
parent
1538c112
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
13 additions
and
6 deletions
+13
-6
csrc/scatter.cpp
csrc/scatter.cpp
+6
-2
csrc/segment_coo.cpp
csrc/segment_coo.cpp
+1
-1
csrc/segment_csr.cpp
csrc/segment_csr.cpp
+1
-1
test/utils.py
test/utils.py
+1
-1
torch_scatter/scatter.py
torch_scatter/scatter.py
+4
-1
No files found.
csrc/scatter.cpp
View file @
0fa23d82
...
@@ -93,7 +93,11 @@ public:
...
@@ -93,7 +93,11 @@ public:
auto
count
=
std
::
get
<
0
>
(
result
);
auto
count
=
std
::
get
<
0
>
(
result
);
count
.
clamp_
(
1
);
count
.
clamp_
(
1
);
count
=
broadcast
(
count
,
out
,
dim
);
count
=
broadcast
(
count
,
out
,
dim
);
out
.
div_
(
count
);
if
(
out
.
is_floating_point
())
out
.
true_divide_
(
count
);
else
out
.
floor_divide_
(
count
);
ctx
->
save_for_backward
({
index
,
count
});
ctx
->
save_for_backward
({
index
,
count
});
if
(
optional_out
.
has_value
())
if
(
optional_out
.
has_value
())
...
@@ -110,7 +114,7 @@ public:
...
@@ -110,7 +114,7 @@ public:
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toIntList
());
auto
src_shape
=
list2vec
(
ctx
->
saved_data
[
"src_shape"
].
toIntList
());
count
=
torch
::
gather
(
count
,
dim
,
index
,
false
);
count
=
torch
::
gather
(
count
,
dim
,
index
,
false
);
auto
grad_in
=
torch
::
gather
(
grad_out
,
dim
,
index
,
false
);
auto
grad_in
=
torch
::
gather
(
grad_out
,
dim
,
index
,
false
);
grad_in
.
div
_
(
count
);
grad_in
.
true_divide
_
(
count
);
return
{
grad_in
,
Variable
(),
Variable
(),
Variable
(),
Variable
()};
return
{
grad_in
,
Variable
(),
Variable
(),
Variable
(),
Variable
()};
}
}
};
};
...
...
csrc/segment_coo.cpp
View file @
0fa23d82
...
@@ -97,7 +97,7 @@ public:
...
@@ -97,7 +97,7 @@ public:
count
=
gather_coo_fw
(
count
,
index
,
torch
::
nullopt
);
count
=
gather_coo_fw
(
count
,
index
,
torch
::
nullopt
);
for
(
auto
i
=
0
;
i
<
grad_out
.
dim
()
-
index
.
dim
();
i
++
)
for
(
auto
i
=
0
;
i
<
grad_out
.
dim
()
-
index
.
dim
();
i
++
)
count
=
count
.
unsqueeze
(
-
1
);
count
=
count
.
unsqueeze
(
-
1
);
grad_in
.
div
_
(
count
);
grad_in
.
true_divide
_
(
count
);
return
{
grad_in
,
Variable
(),
Variable
(),
Variable
()};
return
{
grad_in
,
Variable
(),
Variable
(),
Variable
()};
}
}
};
};
...
...
csrc/segment_csr.cpp
View file @
0fa23d82
...
@@ -95,7 +95,7 @@ public:
...
@@ -95,7 +95,7 @@ public:
count
=
gather_csr_fw
(
count
,
indptr
,
torch
::
nullopt
);
count
=
gather_csr_fw
(
count
,
indptr
,
torch
::
nullopt
);
for
(
auto
i
=
0
;
i
<
grad_out
.
dim
()
-
indptr
.
dim
();
i
++
)
for
(
auto
i
=
0
;
i
<
grad_out
.
dim
()
-
indptr
.
dim
();
i
++
)
count
=
count
.
unsqueeze
(
-
1
);
count
=
count
.
unsqueeze
(
-
1
);
grad_in
.
div
_
(
count
);
grad_in
.
true_divide
_
(
count
);
}
}
return
{
grad_in
,
Variable
(),
Variable
()};
return
{
grad_in
,
Variable
(),
Variable
()};
}
}
...
...
test/utils.py
View file @
0fa23d82
...
@@ -11,4 +11,4 @@ if torch.cuda.is_available():
...
@@ -11,4 +11,4 @@ if torch.cuda.is_available():
def
tensor
(
x
,
dtype
,
device
):
def
tensor
(
x
,
dtype
,
device
):
return
None
if
x
is
None
else
torch
.
tensor
(
x
,
dtype
=
dtype
,
device
=
device
)
return
None
if
x
is
None
else
torch
.
tensor
(
x
,
device
=
device
)
.
to
(
dtype
)
torch_scatter/scatter.py
View file @
0fa23d82
...
@@ -49,7 +49,10 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
...
@@ -49,7 +49,10 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
count
=
scatter_sum
(
ones
,
index
,
index_dim
,
None
,
dim_size
)
count
=
scatter_sum
(
ones
,
index
,
index_dim
,
None
,
dim_size
)
count
.
clamp_
(
1
)
count
.
clamp_
(
1
)
count
=
broadcast
(
count
,
out
,
dim
)
count
=
broadcast
(
count
,
out
,
dim
)
out
.
div_
(
count
)
if
torch
.
is_floating_point
(
out
):
out
.
true_divide_
(
count
)
else
:
out
.
floor_divide_
(
count
)
return
out
return
out
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment