Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torch-scatter
Commits
66320686
Commit
66320686
authored
Dec 20, 2017
by
rusty1s
Browse files
mul, div, mean
parent
3e06f342
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
3 deletions
+50
-3
torch_scatter/kernel/generic/kernel.cu
torch_scatter/kernel/generic/kernel.cu
+22
-3
torch_scatter/kernel/kernel.cu
torch_scatter/kernel/kernel.cu
+28
-0
No files found.
torch_scatter/kernel/generic/kernel.cu
View file @
66320686
...
@@ -4,17 +4,36 @@
...
@@ -4,17 +4,36 @@
void
scatter_
(
mul
)(
THCState
*
state
,
int
dim
,
THCTensor
*
output
,
THCudaLongTensor
*
index
,
THCTensor
*
input
)
{
void
scatter_
(
mul
)(
THCState
*
state
,
int
dim
,
THCTensor
*
output
,
THCudaLongTensor
*
index
,
THCTensor
*
input
)
{
thc_
(
check
)(
state
,
output
,
index
,
input
);
thc_
(
check
)(
state
,
output
,
index
,
input
);
printf
(
"mul"
);
const
int
n
=
THCudaLongTensor_nElement
(
state
,
index
);
TensorInfo
<
real
>
outputInfo
=
thc_
(
getTensorInfo
)(
state
,
output
);
TensorInfo
<
int64_t
>
indexInfo
=
thc_getTensorInfo_Long
(
state
,
index
);
TensorInfo
<
real
>
inputInfo
=
thc_
(
getTensorInfo
)(
state
,
input
);
KERNEL_RUN
(
mulKernel
,
indexInfo
.
dims
,
n
,
outputInfo
,
indexInfo
,
inputInfo
,
dim
)
}
}
void
scatter_
(
div
)(
THCState
*
state
,
int
dim
,
THCTensor
*
output
,
THCudaLongTensor
*
index
,
THCTensor
*
input
)
{
void
scatter_
(
div
)(
THCState
*
state
,
int
dim
,
THCTensor
*
output
,
THCudaLongTensor
*
index
,
THCTensor
*
input
)
{
thc_
(
check
)(
state
,
output
,
index
,
input
);
thc_
(
check
)(
state
,
output
,
index
,
input
);
printf
(
"div"
);
const
int
n
=
THCudaLongTensor_nElement
(
state
,
index
);
TensorInfo
<
real
>
outputInfo
=
thc_
(
getTensorInfo
)(
state
,
output
);
TensorInfo
<
int64_t
>
indexInfo
=
thc_getTensorInfo_Long
(
state
,
index
);
TensorInfo
<
real
>
inputInfo
=
thc_
(
getTensorInfo
)(
state
,
input
);
KERNEL_RUN
(
divKernel
,
indexInfo
.
dims
,
n
,
outputInfo
,
indexInfo
,
inputInfo
,
dim
)
}
}
void
scatter_
(
mean
)(
THCState
*
state
,
int
dim
,
THCTensor
*
output
,
THCudaLongTensor
*
index
,
THCTensor
*
input
,
THCTensor
*
count
)
{
void
scatter_
(
mean
)(
THCState
*
state
,
int
dim
,
THCTensor
*
output
,
THCudaLongTensor
*
index
,
THCTensor
*
input
,
THCTensor
*
count
)
{
thc_
(
check
)(
state
,
output
,
index
,
input
);
thc_
(
check
)(
state
,
output
,
index
,
input
);
printf
(
"mean"
);
const
int
n
=
THCudaLongTensor_nElement
(
state
,
index
);
TensorInfo
<
real
>
outputInfo
=
thc_
(
getTensorInfo
)(
state
,
output
);
TensorInfo
<
int64_t
>
indexInfo
=
thc_getTensorInfo_Long
(
state
,
index
);
TensorInfo
<
real
>
inputInfo
=
thc_
(
getTensorInfo
)(
state
,
input
);
TensorInfo
<
real
>
countInfo
=
thc_
(
getTensorInfo
)(
state
,
count
);
KERNEL_RUN
(
meanKernel
,
indexInfo
.
dims
,
n
,
outputInfo
,
indexInfo
,
inputInfo
,
countInfo
,
dim
)
}
}
void
scatter_
(
max
)(
THCState
*
state
,
int
dim
,
THCTensor
*
output
,
THCudaLongTensor
*
index
,
THCTensor
*
input
,
THCudaLongTensor
*
arg
)
{
void
scatter_
(
max
)(
THCState
*
state
,
int
dim
,
THCTensor
*
output
,
THCudaLongTensor
*
index
,
THCTensor
*
input
,
THCudaLongTensor
*
arg
)
{
...
...
torch_scatter/kernel/kernel.cu
View file @
66320686
...
@@ -13,6 +13,34 @@
...
@@ -13,6 +13,34 @@
#include "generic/common.cu"
#include "generic/common.cu"
#include "THCGenerateAllTypes.h"
#include "THCGenerateAllTypes.h"
template
<
typename
Real
,
int
Dims
>
__global__
void
mulKernel
(
TensorInfo
<
Real
>
output
,
TensorInfo
<
int64_t
>
index
,
TensorInfo
<
Real
>
input
,
const
int
dim
,
const
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
int
outputOffset
=
0
;
int
indexOffset
=
0
;
int
inputOffset
=
0
;;
IndexToScatterOffsets3
<
Real
,
Real
,
Dims
>::
compute
(
i
,
dim
,
index
,
&
indexOffset
,
input
,
&
inputOffset
,
output
,
&
outputOffset
);
atomicMul
(
&
output
.
data
[
outputOffset
],
input
.
data
[
inputOffset
]);
}
}
template
<
typename
Real
,
int
Dims
>
__global__
void
divKernel
(
TensorInfo
<
Real
>
output
,
TensorInfo
<
int64_t
>
index
,
TensorInfo
<
Real
>
input
,
const
int
dim
,
const
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
int
outputOffset
=
0
;
int
indexOffset
=
0
;
int
inputOffset
=
0
;;
IndexToScatterOffsets3
<
Real
,
Real
,
Dims
>::
compute
(
i
,
dim
,
index
,
&
indexOffset
,
input
,
&
inputOffset
,
output
,
&
outputOffset
);
atomicDiv
(
&
output
.
data
[
outputOffset
],
input
.
data
[
inputOffset
]);
}
}
template
<
typename
Real
,
int
Dims
>
__global__
void
meanKernel
(
TensorInfo
<
Real
>
output
,
TensorInfo
<
int64_t
>
index
,
TensorInfo
<
Real
>
input
,
TensorInfo
<
Real
>
count
,
const
int
dim
,
const
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
int
outputOffset
=
0
;
int
indexOffset
=
0
;
int
inputOffset
=
0
;
int
countOffset
=
0
;
IndexToScatterOffsets4
<
Real
,
Real
,
Real
,
Dims
>::
compute
(
i
,
dim
,
index
,
&
indexOffset
,
input
,
&
inputOffset
,
output
,
&
outputOffset
,
count
,
&
countOffset
);
atomicAdd
(
&
output
.
data
[
outputOffset
],
input
.
data
[
inputOffset
]);
atomicAdd
(
&
count
.
data
[
countOffset
],
1
);
}
}
template
<
typename
Real
,
int
Dims
>
template
<
typename
Real
,
int
Dims
>
__global__
void
maxKernel
(
TensorInfo
<
Real
>
output
,
TensorInfo
<
int64_t
>
index
,
TensorInfo
<
Real
>
input
,
const
int
dim
,
const
int
n
)
{
__global__
void
maxKernel
(
TensorInfo
<
Real
>
output
,
TensorInfo
<
int64_t
>
index
,
TensorInfo
<
Real
>
input
,
const
int
dim
,
const
int
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
KERNEL_LOOP
(
i
,
n
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment