Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
c72f36c0
Commit
c72f36c0
authored
Oct 22, 2021
by
rusty1s
Browse files
upgrade to PyTorch 1.10
parent
605566ec
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
4 additions
and
5 deletions
+4
-5
.github/workflows/testing.yml
.github/workflows/testing.yml
+1
-1
csrc/cuda/segment_coo_cuda.cu
csrc/cuda/segment_coo_cuda.cu
+1
-1
csrc/scatter.cpp
csrc/scatter.cpp
+1
-1
torch_scatter/scatter.py
torch_scatter/scatter.py
+1
-2
No files found.
.github/workflows/testing.yml
View file @
c72f36c0
...
...
@@ -11,7 +11,7 @@ jobs:
matrix
:
os
:
[
ubuntu-latest
,
windows-latest
]
python-version
:
[
3.6
]
torch-version
:
[
1.
8
.0
,
1.
9
.0
]
torch-version
:
[
1.
9
.0
,
1.
10
.0
]
steps
:
-
uses
:
actions/checkout@v2
...
...
csrc/cuda/segment_coo_cuda.cu
View file @
c72f36c0
...
...
@@ -274,7 +274,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
if
(
out
.
is_floating_point
())
out
.
true_divide_
(
count
);
else
out
.
floor_divide_
(
count
);
out
.
div_
(
count
,
"floor"
);
}
});
});
...
...
csrc/scatter.cpp
View file @
c72f36c0
...
...
@@ -132,7 +132,7 @@ public:
if
(
out
.
is_floating_point
())
out
.
true_divide_
(
count
);
else
out
.
floor_divide_
(
count
);
out
.
div_
(
count
,
"floor"
);
ctx
->
save_for_backward
({
index
,
count
});
if
(
optional_out
.
has_value
())
...
...
torch_scatter/scatter.py
View file @
c72f36c0
...
...
@@ -38,7 +38,6 @@ def scatter_mul(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
def
scatter_mean
(
src
:
torch
.
Tensor
,
index
:
torch
.
Tensor
,
dim
:
int
=
-
1
,
out
:
Optional
[
torch
.
Tensor
]
=
None
,
dim_size
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
out
=
scatter_sum
(
src
,
index
,
dim
,
out
,
dim_size
)
dim_size
=
out
.
size
(
dim
)
...
...
@@ -55,7 +54,7 @@ def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
if
out
.
is_floating_point
():
out
.
true_divide_
(
count
)
else
:
out
.
floor_divide_
(
count
)
out
.
div_
(
count
,
rounding_mode
=
'floor'
)
return
out
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment