Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
ec8477ea
Commit
ec8477ea
authored
Jun 19, 2019
by
rusty1s
Browse files
cleaner scatter min/max backward implementation
parent
3922eca6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
2 additions
and
64 deletions
+2
-64
cpu/scatter.cpp
cpu/scatter.cpp
+0
-17
cuda/scatter.cpp
cuda/scatter.cpp
+0
-10
cuda/scatter_kernel.cu
cuda/scatter_kernel.cu
+0
-33
torch_scatter/max.py
torch_scatter/max.py
+1
-2
torch_scatter/min.py
torch_scatter/min.py
+1
-2
No files found.
cpu/scatter.cpp
View file @
ec8477ea
...
...
@@ -62,26 +62,9 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
});
}
// Backward pass for scatter min/max on CPU: routes the incoming gradient
// `grad` to `out` only at positions that won the forward reduction.
// `arg` holds, per destination slot, the winning position along `dim`
// (the argmin/argmax recorded by the forward pass); all other positions
// keep their (zero-initialized) gradient.
void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
                    at::Tensor out, int64_t dim) {
  const int64_t row_len = index.size(dim);
  AT_DISPATCH_ALL_TYPES(grad.scalar_type(), "index_backward", [&] {
    DIM_APPLY4(scalar_t, grad, int64_t, index, int64_t, arg, scalar_t, out,
               dim, {
                 for (int64_t e = 0; e < row_len; e++) {
                   const int64_t target = index_data[e * index_stride];
                   // Only the element recorded as the winner of its target
                   // slot receives the gradient flowing into that slot.
                   if (arg_data[target * arg_stride] == e) {
                     out_data[e * out_stride] = grad_data[target * grad_stride];
                   }
                 }
               });
  });
}
// Python bindings for the CPU scatter extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scatter_mul", &scatter_mul, "Scatter Mul (CPU)");
  m.def("scatter_div", &scatter_div, "Scatter Div (CPU)");
  m.def("scatter_max", &scatter_max, "Scatter Max (CPU)");
  m.def("scatter_min", &scatter_min, "Scatter Min (CPU)");
  m.def("index_backward", &index_backward, "Index Backward (CPU)");
}
cuda/scatter.cpp
View file @
ec8477ea
...
...
@@ -47,19 +47,9 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
scatter_min_cuda
(
src
,
index
,
out
,
arg
,
dim
);
}
// CUDA-facing entry point for the scatter min/max backward pass:
// verifies every tensor lives on a CUDA device, then forwards to the
// kernel launcher defined in scatter_kernel.cu.
void index_backward(at::Tensor grad, at::Tensor index, at::Tensor arg,
                    at::Tensor out, int64_t dim) {
  CHECK_CUDA(grad);
  CHECK_CUDA(index);
  CHECK_CUDA(arg);
  CHECK_CUDA(out);
  index_backward_cuda(grad, index, arg, out, dim);
}
// Python bindings for the CUDA scatter extension.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("scatter_mul", &scatter_mul, "Scatter Mul (CUDA)");
  m.def("scatter_div", &scatter_div, "Scatter Div (CUDA)");
  m.def("scatter_max", &scatter_max, "Scatter Max (CUDA)");
  m.def("scatter_min", &scatter_min, "Scatter Min (CUDA)");
  m.def("index_backward", &index_backward, "Index Backward (CUDA)");
}
cuda/scatter_kernel.cu
View file @
ec8477ea
...
...
@@ -159,36 +159,3 @@ void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
dim
);
});
}
// Kernel for the scatter min/max backward pass. For every element of
// `index`, routes grad into `out` only when `arg` (the argmin/argmax
// recorded in the forward pass) names that element's position along
// `dim` as the winner for its destination slot.
template <typename scalar_t, int64_t Dims>
__global__ void
index_backward_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> grad,
                      at::cuda::detail::TensorInfo<int64_t, int64_t> index,
                      at::cuda::detail::TensorInfo<int64_t, int64_t> arg,
                      at::cuda::detail::TensorInfo<scalar_t, int64_t> out,
                      int64_t dim, size_t numel) {
  // Grid-stride loop: each thread processes every `step`-th element.
  const size_t first = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t step = blockDim.x * gridDim.x;
  // NOTE(review): `i` is ptrdiff_t while `numel` is size_t — the
  // comparison mixes signedness; fine as long as numel fits in ptrdiff_t.
  for (ptrdiff_t i = first; i < numel; i += step) {
    int64_t gradOffset = 0, indexOffset = 0, argOffset = 0, outOffset = 0;
    IndexToScatterOffsets4<scalar_t, int64_t, scalar_t, Dims>::compute(
        i, dim, index, &indexOffset, out, &outOffset, arg, &argOffset, grad,
        &gradOffset);
    // Recover this element's coordinate along `dim` from its linear
    // offset and compare it with the recorded winning position.
    if (arg.data[argOffset] ==
        (outOffset / out.strides[dim]) % out.sizes[dim]) {
      out.data[outOffset] = grad.data[gradOffset];
    }
  }
}
// Host-side launcher: selects the device owning `grad`, dispatches on its
// scalar type, and runs index_backward_kernel over the elements of `index`.
void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
                         at::Tensor out, int64_t dim) {
  cudaSetDevice(grad.get_device());
  AT_DISPATCH_ALL_TYPES(grad.scalar_type(), "index_backward_kernel", [&] {
    KERNEL_RUN(index_backward_kernel, index.dim(), index.numel(),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(grad),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(index),
               at::cuda::detail::getTensorInfo<int64_t, int64_t>(arg),
               at::cuda::detail::getTensorInfo<scalar_t, int64_t>(out), dim);
  });
}
torch_scatter/max.py
View file @
ec8477ea
...
...
@@ -25,8 +25,7 @@ class ScatterMax(Function):
grad_src
=
None
if
ctx
.
needs_input_grad
[
1
]:
grad_src
=
grad_out
.
new_zeros
(
index
.
size
())
func
=
get_func
(
'index_backward'
,
grad_out
)
func
(
grad_out
,
index
,
arg
,
grad_src
,
ctx
.
dim
)
grad_src
.
scatter_
(
ctx
.
dim
,
arg
.
detach
(),
grad_out
)
return
None
,
grad_src
,
None
,
None
...
...
torch_scatter/min.py
View file @
ec8477ea
...
...
@@ -25,8 +25,7 @@ class ScatterMin(Function):
grad_src
=
None
if
ctx
.
needs_input_grad
[
1
]:
grad_src
=
grad_out
.
new_zeros
(
index
.
size
())
func
=
get_func
(
'index_backward'
,
grad_out
)
func
(
grad_out
,
index
,
arg
,
grad_src
,
ctx
.
dim
)
grad_src
.
scatter_
(
ctx
.
dim
,
arg
.
detach
(),
grad_out
)
return
None
,
grad_src
,
None
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment