Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
cc6f8862
Commit
cc6f8862
authored
Jan 15, 2018
by
rusty1s
Browse files
cuda bugfixes
parent
ba26dfb1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
9 additions
and
8 deletions
+9
-8
test/forward.json
test/forward.json
+2
-2
test/test_backward.py
test/test_backward.py
+1
-1
test/test_forward.py
test/test_forward.py
+1
-2
torch_scatter/functions/sub.py
torch_scatter/functions/sub.py
+1
-1
torch_scatter/kernel/kernel.cu
torch_scatter/kernel/kernel.cu
+4
-2
No files found.
test/forward.json
View file @
cc6f8862
...
@@ -26,10 +26,10 @@
...
@@ -26,10 +26,10 @@
{
{
"name"
:
"sub"
,
"name"
:
"sub"
,
"index"
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
"index"
:
[[
0
,
0
],
[
1
,
1
],
[
1
,
1
],
[
0
,
0
]],
"input"
:
[[
5
,
2
],
[
2
,
5
],
[
4
,
3
],
[
1
,
3
]],
"input"
:
[[
5
,
2
],
[
2
,
2
],
[
4
,
2
],
[
1
,
3
]],
"dim"
:
0
,
"dim"
:
0
,
"fill_value"
:
9
,
"fill_value"
:
9
,
"expected"
:
[[
3
,
4
],
[
3
,
1
]]
"expected"
:
[[
3
,
4
],
[
3
,
5
]]
},
},
{
{
"name"
:
"mul"
,
"name"
:
"mul"
,
...
...
test/test_backward.py
View file @
cc6f8862
...
@@ -35,7 +35,7 @@ def test_backward_cpu(tensor, i):
...
@@ -35,7 +35,7 @@ def test_backward_cpu(tensor, i):
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
data
))))
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
data
))))
def
test_backward_gpu
(
tensor
,
i
):
# pragma: no cover
def
test_backward_gpu
(
tensor
,
i
):
# pragma: no cover
name
=
data
[
i
][
'name'
]
name
=
data
[
i
][
'name'
]
index
=
V
(
torch
.
LongTensor
(
data
[
i
][
'index'
])
.
cuda
()
)
index
=
V
(
torch
.
cuda
.
LongTensor
(
data
[
i
][
'index'
]))
input
=
V
(
Tensor
(
tensor
,
data
[
i
][
'input'
]).
cuda
(),
requires_grad
=
True
)
input
=
V
(
Tensor
(
tensor
,
data
[
i
][
'input'
]).
cuda
(),
requires_grad
=
True
)
dim
=
data
[
i
][
'dim'
]
dim
=
data
[
i
][
'dim'
]
fill_value
=
data
[
i
][
'fill_value'
]
fill_value
=
data
[
i
][
'fill_value'
]
...
...
test/test_forward.py
View file @
cc6f8862
...
@@ -44,7 +44,7 @@ def test_forward_cpu(tensor, i):
...
@@ -44,7 +44,7 @@ def test_forward_cpu(tensor, i):
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
data
))))
@
pytest
.
mark
.
parametrize
(
'tensor,i'
,
product
(
tensors
,
range
(
len
(
data
))))
def
test_forward_gpu
(
tensor
,
i
):
# pragma: no cover
def
test_forward_gpu
(
tensor
,
i
):
# pragma: no cover
name
=
data
[
i
][
'name'
]
name
=
data
[
i
][
'name'
]
index
=
torch
.
LongTensor
(
data
[
i
][
'index'
])
.
cuda
()
index
=
torch
.
cuda
.
LongTensor
(
data
[
i
][
'index'
])
input
=
Tensor
(
tensor
,
data
[
i
][
'input'
]).
cuda
()
input
=
Tensor
(
tensor
,
data
[
i
][
'input'
]).
cuda
()
dim
=
data
[
i
][
'dim'
]
dim
=
data
[
i
][
'dim'
]
fill_value
=
data
[
i
][
'fill_value'
]
fill_value
=
data
[
i
][
'fill_value'
]
...
@@ -57,7 +57,6 @@ def test_forward_gpu(tensor, i): # pragma: no cover
...
@@ -57,7 +57,6 @@ def test_forward_gpu(tensor, i): # pragma: no cover
if
'expected_arg'
in
data
[
i
]:
if
'expected_arg'
in
data
[
i
]:
expected_arg
=
torch
.
LongTensor
(
data
[
i
][
'expected_arg'
])
expected_arg
=
torch
.
LongTensor
(
data
[
i
][
'expected_arg'
])
assert
result
[
1
].
cpu
().
tolist
()
==
expected_arg
.
tolist
()
assert
result
[
1
].
cpu
().
tolist
()
==
expected_arg
.
tolist
()
func
=
getattr
(
torch_scatter
,
'scatter_{}'
.
format
(
name
))
func
=
getattr
(
torch_scatter
,
'scatter_{}'
.
format
(
name
))
result
=
func
(
index
,
input
,
dim
,
fill_value
=
fill_value
)
result
=
func
(
index
,
input
,
dim
,
fill_value
=
fill_value
)
if
'expected_arg'
not
in
data
[
i
]:
if
'expected_arg'
not
in
data
[
i
]:
...
...
torch_scatter/functions/sub.py
View file @
cc6f8862
...
@@ -51,7 +51,7 @@ def scatter_sub_(output, index, input, dim=0):
...
@@ -51,7 +51,7 @@ def scatter_sub_(output, index, input, dim=0):
-2 -4 -4 0 0 0
-2 -4 -4 0 0 0
[torch.FloatTensor of size 2x6]
[torch.FloatTensor of size 2x6]
"""
"""
return
output
.
scatter_add_
(
dim
,
index
,
-
1
*
input
)
return
output
.
scatter_add_
(
dim
,
index
,
-
input
)
def
scatter_sub
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
def
scatter_sub
(
index
,
input
,
dim
=
0
,
size
=
None
,
fill_value
=
0
):
...
...
torch_scatter/kernel/kernel.cu
View file @
cc6f8862
...
@@ -64,7 +64,9 @@ __global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
...
@@ -64,7 +64,9 @@ __global__ void argKernel(TensorInfo<Real> output, TensorInfo<int64_t> index, Te
KERNEL_LOOP
(
i
,
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
int
outputOffset
=
0
;
int
indexOffset
=
0
;
int
inputOffset
=
0
;
int
argOffset
=
0
;
int
outputOffset
=
0
;
int
indexOffset
=
0
;
int
inputOffset
=
0
;
int
argOffset
=
0
;
IndexToScatterOffsets4
<
Real
,
Real
,
int64_t
,
Dims
>::
compute
(
i
,
dim
,
index
,
&
indexOffset
,
input
,
&
inputOffset
,
output
,
&
outputOffset
,
arg
,
&
argOffset
);
IndexToScatterOffsets4
<
Real
,
Real
,
int64_t
,
Dims
>::
compute
(
i
,
dim
,
index
,
&
indexOffset
,
input
,
&
inputOffset
,
output
,
&
outputOffset
,
arg
,
&
argOffset
);
if
(
input
.
data
[
inputOffset
]
==
output
.
data
[
outputOffset
])
arg
.
data
[
argOffset
]
=
inputOffset
%
input
.
size
[
dim
];
if
(
input
.
data
[
inputOffset
]
==
output
.
data
[
outputOffset
])
{
arg
.
data
[
argOffset
]
=
(
inputOffset
/
input
.
stride
[
dim
])
%
input
.
size
[
dim
];
}
}
}
}
}
...
@@ -73,7 +75,7 @@ __global__ void indexBackwardKernel(TensorInfo<Real> output, TensorInfo<int64_t>
...
@@ -73,7 +75,7 @@ __global__ void indexBackwardKernel(TensorInfo<Real> output, TensorInfo<int64_t>
KERNEL_LOOP
(
i
,
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
int
outputOffset
=
0
;
int
indexOffset
=
0
;
int
gradOffset
=
0
;
int
argOffset
=
0
;
int
outputOffset
=
0
;
int
indexOffset
=
0
;
int
gradOffset
=
0
;
int
argOffset
=
0
;
IndexToScatterOffsets4
<
Real
,
Real
,
int64_t
,
Dims
>::
compute
(
i
,
dim
,
index
,
&
indexOffset
,
output
,
&
outputOffset
,
grad
,
&
gradOffset
,
arg
,
&
argOffset
);
IndexToScatterOffsets4
<
Real
,
Real
,
int64_t
,
Dims
>::
compute
(
i
,
dim
,
index
,
&
indexOffset
,
output
,
&
outputOffset
,
grad
,
&
gradOffset
,
arg
,
&
argOffset
);
if
(
arg
.
data
[
argOffset
]
==
outputOffset
%
output
.
size
[
dim
])
output
.
data
[
outputOffset
]
=
grad
.
data
[
gradOffset
];
if
(
arg
.
data
[
argOffset
]
==
(
outputOffset
/
output
.
stride
[
dim
])
%
output
.
size
[
dim
])
output
.
data
[
outputOffset
]
=
grad
.
data
[
gradOffset
];
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment