OpenDAS / torch-cluster · Commits

Commit 7bb94638, authored Mar 09, 2020 by rusty1s
Parent: ba9f2ed2

    new rw cuda implementation

Showing 6 changed files with 74 additions and 81 deletions.
csrc/cpu/rw_cpu.cpp     +5   −7
csrc/cuda/grid_cuda.cu  +7   −8
csrc/cuda/rw.cpp        +0   −20
csrc/cuda/rw_cuda.cu    +55  −0
csrc/cuda/rw_cuda.h     +7   −0
csrc/cuda/rw_kernel.cu  +0   −46
csrc/cpu/rw_cpu.cpp

@@ -13,16 +13,12 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
   CHECK_INPUT(col.dim() == 1);
   CHECK_INPUT(start.dim() == 1);

-  auto num_nodes = rowptr.size(0) - 1;
-  auto deg = rowptr.narrow(0, 1, num_nodes) -
-             rowptr.narrow(0, 0, num_nodes);
   auto rand = torch::rand({start.size(0), walk_length},
                           start.options().dtype(torch::kFloat));
   auto out = torch::full({start.size(0), walk_length + 1}, -1, start.options());

   auto rowptr_data = rowptr.data_ptr<int64_t>();
-  auto deg_data = deg.data_ptr<int64_t>();
   auto col_data = col.data_ptr<int64_t>();
   auto start_data = start.data_ptr<int64_t>();
   auto rand_data = rand.data_ptr<float>();

@@ -33,10 +29,12 @@ torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
     auto offset = n * (walk_length + 1);
     out_data[offset] = cur;

+    int64_t row_start, row_end;
     for (auto l = 1; l <= walk_length; l++) {
-      cur = col_data[rowptr_data[cur] +
-                     int64_t(rand_data[n * walk_length + (l - 1)] *
-                             deg_data[cur])];
+      row_start = rowptr_data[cur], row_end = rowptr_data[cur + 1];
+      cur = col_data[row_start +
+                     int64_t(rand_data[n * walk_length + (l - 1)] *
+                             (row_end - row_start))];
       out_data[offset + l] = cur;
     }
   }
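The CPU change removes the intermediate `deg` tensor: in CSR layout the out-degree of node `v` is simply `rowptr[v + 1] - rowptr[v]`, so the loop can read two adjacent `rowptr` entries directly instead of materializing a degree array. A minimal standalone sketch of that sampling step; the `pick_neighbor` helper and the toy graph are illustrations, not code from the repository:

#include <cstdint>
#include <cstdio>
#include <vector>

// Uniformly pick a neighbor of `cur` in a CSR graph, given r in [0, 1).
// Mirrors the new loop body: the degree is derived from rowptr on the fly.
int64_t pick_neighbor(const std::vector<int64_t> &rowptr,
                      const std::vector<int64_t> &col, int64_t cur, float r) {
  int64_t row_start = rowptr[cur], row_end = rowptr[cur + 1];
  return col[row_start + int64_t(r * (row_end - row_start))]; // r < 1 stays in range
}

int main() {
  // Toy graph: 0 -> {1, 2}, 1 -> {0}, 2 -> {0, 1}
  std::vector<int64_t> rowptr = {0, 2, 3, 5};
  std::vector<int64_t> col = {1, 2, 0, 0, 1};
  printf("%lld\n", (long long)pick_neighbor(rowptr, col, 0, 0.7f)); // prints 2
}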
csrc/cuda/grid_cuda.cu

-#include "grid_cpu.h"
+#include "grid_cuda.h"

-#include <ATen/ATen.h>
-#include <ATen/cuda/detail/IndexUtils.cuh>
-#include <ATen/cuda/detail/TensorInfo.cuh>
+#include <ATen/cuda/CUDAContext.h>

-#include "compat.cuh"
+#include "utils.cuh"

 #define THREADS 1024
 #define BLOCKS(N) (N + THREADS - 1) / THREADS

@@ -12,7 +10,7 @@
 template <typename scalar_t>
 __global__ void grid_kernel(const scalar_t *pos, const scalar_t *size,
                             const scalar_t *start, const scalar_t *end,
-                            int64_t *out, int64_t N, int64_t D, int64_t numel) {
+                            int64_t *out, int64_t D, int64_t numel) {

   const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (thread_idx < numel) {

@@ -62,11 +60,12 @@ torch::Tensor grid_cpu(torch::Tensor pos, torch::Tensor size,
   auto out = torch::empty(pos.size(0), pos.options().dtype(torch::kLong));
+  auto stream = at::cuda::getCurrentCUDAStream();

   AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
-    grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS>>>(
+    grid_kernel<scalar_t><<<BLOCKS(out.numel()), THREADS, 0, stream>>>(
         pos.data_ptr<scalar_t>(), size.data_ptr<scalar_t>(),
         start.data_ptr<scalar_t>(), end.data_ptr<scalar_t>(),
-        out.data_ptr<int64_t>(), pos.size(0), pos.size(1), out.numel());
+        out.data_ptr<int64_t>(), pos.size(1), out.numel());
   });

   return out;
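Besides dropping the unused `N` argument, the substantive change here is the launch configuration: the kernel now runs on PyTorch's current CUDA stream, obtained via `at::cuda::getCurrentCUDAStream()`, rather than the default stream, so it orders correctly with surrounding PyTorch operations. A minimal sketch of that launch pattern in isolation; `scale_kernel` and `scale` are hypothetical stand-ins, not code from this commit:

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

// Hypothetical kernel used only to illustrate the launch pattern.
__global__ void scale_kernel(float *x, float alpha, int64_t numel) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < numel)
    x[i] *= alpha;
}

torch::Tensor scale(torch::Tensor x, float alpha) {
  // Fourth launch parameter: PyTorch's current stream for this device,
  // instead of the implicit default stream.
  auto stream = at::cuda::getCurrentCUDAStream();
  scale_kernel<<<BLOCKS(x.numel()), THREADS, 0, stream>>>(
      x.data_ptr<float>(), alpha, x.numel());
  return x;
}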
csrc/cuda/rw.cpp (deleted)

#include <torch/extension.h>

#define CHECK_CUDA(x)                                                          \
  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
#define IS_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " is not contiguous");

at::Tensor rw_cuda(at::Tensor row, at::Tensor col, at::Tensor start,
                   size_t walk_length, float p, float q, size_t num_nodes);

at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start,
              size_t walk_length, float p, float q, size_t num_nodes) {
  CHECK_CUDA(row);
  CHECK_CUDA(col);
  CHECK_CUDA(start);
  return rw_cuda(row, col, start, walk_length, p, q, num_nodes);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("rw", &rw, "Random Walk Sampling (CUDA)");
}
csrc/cuda/rw_cuda.cu (new file)

#include "rw_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "utils.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
                                           const int64_t *col,
                                           const int64_t *start,
                                           const float *rand, int64_t *out,
                                           int64_t walk_length, int64_t numel) {
  const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (thread_idx < numel) {
    out[thread_idx] = start[thread_idx];
    int64_t row_start, row_end, i, cur;

    for (int64_t l = 1; l <= walk_length; l++) {
      i = (l - 1) * numel + thread_idx;
      cur = out[i];
      row_start = rowptr[cur], row_end = rowptr[cur + 1];
      out[l * numel + thread_idx] =
          col[row_start + int64_t(rand[i] * (row_end - row_start))];
    }
  }
}

torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
                               torch::Tensor start, int64_t walk_length,
                               double p, double q) {
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(start);
  cudaSetDevice(rowptr.get_device());

  CHECK_INPUT(rowptr.dim() == 1);
  CHECK_INPUT(col.dim() == 1);
  CHECK_INPUT(start.dim() == 1);

  auto rand = torch::rand({start.size(0), walk_length},
                          start.options().dtype(torch::kFloat));
  auto out = torch::full({walk_length + 1, start.size(0)}, -1, start.options());

  auto stream = at::cuda::getCurrentCUDAStream();
  uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
      rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
      start.data_ptr<int64_t>(), rand.data_ptr<float>(),
      out.data_ptr<int64_t>(), walk_length, start.numel());

  return out.t().contiguous();
}
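Note the storage layout: `out` is allocated as `{walk_length + 1, start.size(0)}`, and step `l` writes `out[l * numel + thread_idx]`, so consecutive threads in a warp write adjacent addresses (coalesced); the trailing `out.t().contiguous()` then restores one row per walk. A host-side reference of the same recurrence can serve as a test oracle; this helper is a sketch for illustration, not part of the commit:

#include <cstdint>
#include <vector>

// CPU reference for uniform_random_walk_kernel: given pre-drawn uniforms in
// `rand` (indexed step-major, like the kernel), produce walks row-major.
// Assumes every visited node has at least one neighbor, as the kernel does.
std::vector<std::vector<int64_t>>
random_walk_reference(const std::vector<int64_t> &rowptr,
                      const std::vector<int64_t> &col,
                      const std::vector<int64_t> &start, int64_t walk_length,
                      const std::vector<float> &rand) {
  const int64_t numel = (int64_t)start.size();
  std::vector<std::vector<int64_t>> walks(numel);
  for (int64_t n = 0; n < numel; n++) {
    walks[n].push_back(start[n]);
    for (int64_t l = 1; l <= walk_length; l++) {
      int64_t cur = walks[n].back();
      int64_t row_start = rowptr[cur], row_end = rowptr[cur + 1];
      float r = rand[(l - 1) * numel + n]; // same indexing as the kernel
      walks[n].push_back(col[row_start + int64_t(r * (row_end - row_start))]);
    }
  }
  return walks;
}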
csrc/cuda/rw_cuda.h (new file)

#pragma once

#include <torch/extension.h>

torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
                               torch::Tensor start, int64_t walk_length,
                               double p, double q);
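This header exposes the CUDA entry point so a device-dispatching wrapper elsewhere in the build can forward to it. Such a wrapper is not part of this commit; the sketch below only illustrates the shape it might take, and the `cpu/rw_cpu.h` include and the tail of the `random_walk_cpu` signature are assumptions:

#include <torch/extension.h>

#include "cpu/rw_cpu.h" // assumed header for the CPU variant
#include "cuda/rw_cuda.h"

// Hypothetical dispatcher: route by device, mirroring both signatures.
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
                          torch::Tensor start, int64_t walk_length, double p,
                          double q) {
  if (rowptr.device().is_cuda())
    return random_walk_cuda(rowptr, col, start, walk_length, p, q);
  return random_walk_cpu(rowptr, col, start, walk_length, p, q);
}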
csrc/cuda/rw_kernel.cu (deleted)

#include <ATen/ATen.h>

#include "compat.cuh"
#include "utils.cuh"

#define THREADS 1024
#define BLOCKS(N) (N + THREADS - 1) / THREADS

__global__ void
uniform_rw_kernel(const int64_t *__restrict__ row,
                  const int64_t *__restrict__ col,
                  const int64_t *__restrict__ deg,
                  const int64_t *__restrict__ start,
                  const float *__restrict__ rand, int64_t *__restrict__ out,
                  const size_t walk_length, const size_t numel) {
  const size_t index = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (ptrdiff_t n = index; n < numel; n += stride) {
    out[n] = start[n];
    for (ptrdiff_t l = 1; l <= walk_length; l++) {
      auto i = (l - 1) * numel + n;
      auto cur = out[i];
      out[l * numel + n] = col[row[cur] + int64_t(rand[i] * deg[cur])];
    }
  }
}

at::Tensor rw_cuda(at::Tensor row, at::Tensor col, at::Tensor start,
                   size_t walk_length, float p, float q, size_t num_nodes) {
  cudaSetDevice(row.get_device());

  auto deg = degree(row, num_nodes);
  row = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);

  auto rand = at::rand({(int64_t)walk_length, start.size(0)},
                       start.options().dtype(at::kFloat));
  auto out = at::full({(int64_t)walk_length + 1, start.size(0)}, -1,
                      start.options());

  uniform_rw_kernel<<<BLOCKS(start.numel()), THREADS>>>(
      row.DATA_PTR<int64_t>(), col.DATA_PTR<int64_t>(),
      deg.DATA_PTR<int64_t>(), start.DATA_PTR<int64_t>(),
      rand.DATA_PTR<float>(), out.DATA_PTR<int64_t>(), walk_length,
      start.numel());

  return out.t().contiguous();
}
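The deleted kernel covered the walks with a grid-stride loop, which keeps a fixed-size grid correct for any `numel`; the replacement instead sizes the grid to the problem with `BLOCKS(start.numel())` and guards with a single bounds check. For reference, the grid-stride pattern in isolation; `fill_kernel` is a hypothetical stand-in:

#include <cstdint>

// Grid-stride loop: each thread handles indices i, i + stride, i + 2*stride, ...
// so correctness does not depend on launching one thread per element.
__global__ void fill_kernel(int64_t *out, int64_t value, int64_t numel) {
  const int64_t stride = (int64_t)blockDim.x * gridDim.x;
  for (int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += stride)
    out[i] = value;
}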