OpenDAS / torch-sparse

Commit 73a89efb, authored Jun 12, 2020 by rusty1s (parent 6a992437)

matmul overload
Showing 9 changed files with 19 additions and 537 deletions (+19, -537).
.gitignore                  +1   -0
csrc/cpu/padding_cpu.cpp    +0   -129
csrc/cpu/padding_cpu.h      +0   -14
csrc/cuda/padding_cuda.cu   +0   -292
csrc/cuda/padding_cuda.h    +0   -14
csrc/padding.cpp            +0   -80
torch_sparse/__init__.py    +1   -1
torch_sparse/matmul.py      +15  -5
torch_sparse/tensor.py      +2   -2
.gitignore  (view file @ 73a89efb)

 __pycache__/
 build/
 dist/
+alpha/
 .cache/
 .eggs/
 *.egg-info/
 ...
csrc/cpu/padding_cpu.cpp  (deleted, 100644 → 0, view file @ 6a992437)

#include "padding_cpu.h"
#include "utils.h"

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           std::vector<int64_t>, std::vector<int64_t>>
padded_index_cpu(torch::Tensor rowptr, torch::Tensor col,
                 torch::Tensor rowcount, torch::Tensor binptr) {
  CHECK_CPU(rowptr);
  CHECK_CPU(col);
  CHECK_CPU(rowcount);
  CHECK_CPU(binptr);
  CHECK_INPUT(rowptr.numel() == rowcount.numel() + 1);

  ptrdiff_t B = binptr.numel() - 1;
  ptrdiff_t N = rowcount.numel();

  auto rowptr_data = rowptr.data_ptr<int64_t>();
  auto col_data = col.data_ptr<int64_t>();
  auto rowcount_data = rowcount.data_ptr<int64_t>();
  auto binptr_data = binptr.data_ptr<int64_t>();

  auto bin = torch::empty(N, col.options());
  auto bin_data = bin.data_ptr<int64_t>();
  auto idx = torch::empty(N, col.options());
  auto idx_data = idx.data_ptr<int64_t>();

  std::vector<int64_t> node_sizes(B), edge_sizes(B), max_degs(B),
      node_offsets(B + 1), edge_offsets(B + 1);

  int64_t deg, bin_idx = -1;
  for (ptrdiff_t n = 0; n < N; n++) {
    deg = rowcount_data[n];

    for (ptrdiff_t b = 1; b <= B; b++) {
      if (deg < binptr_data[b]) {
        bin_idx = b - 1;
        break;
      }
    }
    if (bin_idx == -1) {
      bin_idx = B - 1;
    }

    bin_data[n] = bin_idx;
    idx_data[n] = node_sizes[bin_idx];
    node_sizes[bin_idx] += 1;
    max_degs[bin_idx] = std::max(max_degs[bin_idx], deg);
  }

  for (ptrdiff_t b = 0; b < B; b++) {
    edge_sizes[b] = node_sizes[b] * max_degs[b];
    node_offsets[b + 1] = node_offsets[b] + node_sizes[b];
    edge_offsets[b + 1] = edge_offsets[b] + edge_sizes[b];
  }

  auto node_perm = torch::empty(N, col.options());
  auto node_perm_data = node_perm.data_ptr<int64_t>();

  auto E = edge_offsets[B];
  auto row_perm = torch::empty(E, col.options());
  auto row_perm_data = row_perm.data_ptr<int64_t>();
  auto col_perm = torch::empty(E, col.options());
  auto col_perm_data = col_perm.data_ptr<int64_t>();
  auto edge_mask = torch::empty(E, col.options().dtype(torch::kBool));
  auto edge_mask_data = edge_mask.data_ptr<bool>();

  int64_t row_start = rowptr_data[0], row_end, edge_offset, offset;
  for (ptrdiff_t n = 0; n < N; n++) {
    bin_idx = bin_data[n];
    offset = idx_data[n];
    node_perm_data[node_offsets[bin_idx] + offset] = n;

    row_end = rowptr_data[n + 1];
    edge_offset = edge_offsets[bin_idx] + offset * max_degs[bin_idx];
    for (ptrdiff_t e = 0; e < row_end - row_start; e++) {
      row_perm_data[edge_offset + e] = n;
      col_perm_data[edge_offset + e] = col_data[row_start + e];
      edge_mask_data[edge_offset + e] = false;
    }
    for (ptrdiff_t e = row_end - row_start; e < max_degs[bin_data[n]]; e++) {
      row_perm_data[edge_offset + e] = -1;
      col_perm_data[edge_offset + e] = -1;
      edge_mask_data[edge_offset + e] = true;
    }
    row_start = row_end;
  }

  return std::make_tuple(node_perm, row_perm, col_perm, edge_mask, node_sizes,
                         edge_sizes);
}

torch::Tensor padded_index_select_cpu(torch::Tensor src, torch::Tensor index,
                                      torch::Tensor fill_value) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(index.dim() == 1);

  auto mask = index == -1;
  auto out = src.index_select(0, index.masked_fill(mask, 0));
  out.masked_fill_(mask.view({-1, 1}).expand_as(out), fill_value);
  return out;
}

torch::Tensor padded_index_scatter_cpu(torch::Tensor src, torch::Tensor index,
                                       int64_t N) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(index.dim() == 1);
  CHECK_INPUT(src.size(0) == index.size(0));

  auto mask = index == -1;
  index = index.masked_fill(mask, N);
  auto out = torch::zeros({N + 1, src.size(-1)}, src.options());
  out.scatter_add_(0, index.view({-1, 1}).expand_as(src), src);
  out = out.narrow(0, 0, N);
  return out;
}
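For readers who want the semantics without the C++ plumbing, the two smaller routines above translate almost line for line into plain PyTorch calls. The following is an illustrative sketch only (the function names are mine, and fill_value is taken as a Python scalar rather than a one-element tensor); it mirrors padded_index_select_cpu and padded_index_scatter_cpu:

import torch

def padded_index_select_reference(src, index, fill_value):
    # Rows where index == -1 are padding: gather row 0 as a placeholder, then
    # overwrite those rows with fill_value (mirrors padded_index_select_cpu).
    mask = index == -1
    out = src.index_select(0, index.masked_fill(mask, 0))
    out.masked_fill_(mask.view(-1, 1).expand_as(out), fill_value)
    return out

def padded_index_scatter_reference(src, index, N):
    # Padding rows are routed to an extra row N that is sliced away afterwards
    # (mirrors padded_index_scatter_cpu).
    mask = index == -1
    index = index.masked_fill(mask, N)
    out = torch.zeros(N + 1, src.size(-1), dtype=src.dtype, device=src.device)
    out.scatter_add_(0, index.view(-1, 1).expand_as(src), src)
    return out.narrow(0, 0, N)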
csrc/cpu/padding_cpu.h  (deleted, 100644 → 0, view file @ 6a992437)

#pragma once

#include <torch/extension.h>

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           std::vector<int64_t>, std::vector<int64_t>>
padded_index_cpu(torch::Tensor rowptr, torch::Tensor col,
                 torch::Tensor rowcount, torch::Tensor binptr);

torch::Tensor padded_index_select_cpu(torch::Tensor src, torch::Tensor index,
                                      torch::Tensor fill_value);

torch::Tensor padded_index_scatter_cpu(torch::Tensor src, torch::Tensor index,
                                       int64_t N);
csrc/cuda/padding_cuda.cu  (deleted, 100644 → 0, view file @ 6a992437)

#include "padding_cuda.h"

#include <ATen/cuda/CUDAContext.h>

#include "atomics.cuh"
#include "utils.cuh"

#define THREADS 1024
#define FULL_MASK 0xffffffff
#define BLOCKS(N) (N + THREADS - 1) / THREADS

__global__ void bin_kernel(const int64_t *__restrict__ rowcount,
                           const int64_t *__restrict__ binptr,
                           int64_t *__restrict__ bin,
                           int64_t *__restrict__ idx,
                           int *__restrict__ node_size,
                           int *__restrict__ max_deg, const size_t B,
                           const size_t N) {
  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < N; thread_idx += gridDim.x * blockDim.x) {

    int bin_idx = -1, deg = rowcount[thread_idx];
    for (ptrdiff_t b = 1; b <= B; b++) {
      if (deg < __ldg(binptr + b)) {
        bin_idx = b - 1;
        break;
      }
    }
    if (bin_idx == -1) {
      bin_idx = B - 1;
    }

    int old = atomicAdd(node_size + bin_idx, 1);
    atomicMax(max_deg + bin_idx, deg);

    bin[thread_idx] = bin_idx;
    idx[thread_idx] = old;
  }
}

__global__ void info_kernel(const int *__restrict__ node_size,
                            const int *__restrict__ max_deg,
                            int *__restrict__ edge_size,
                            int *__restrict__ node_offset,
                            int *__restrict__ edge_offset, const size_t B) {
  int thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  int bin_idx = thread_idx / 32;
  int lane_idx = thread_idx % 32;

  if (bin_idx <= B) {
    // Computes `node_offset` and `edge_offset`.
    int node_tmp = 0;
    int edge_tmp = 0;
    for (int i = lane_idx; i < bin_idx; i += 32) {
      node_tmp += node_size[i];
      edge_tmp += node_size[i] * max_deg[i];
    }

    for (int i = 32 / 2; i > 0; i /= 2) {
      node_tmp += __shfl_down_sync(FULL_MASK, node_tmp, i);
      edge_tmp += __shfl_down_sync(FULL_MASK, edge_tmp, i);
    }

    if (lane_idx == 0) {
      node_offset[bin_idx] = node_tmp;
      edge_offset[bin_idx] = edge_tmp;
    }
  } else if (bin_idx == B + 1) {
    // Computes `edge_size`.
    for (int i = lane_idx; i < B; i += 32) {
      edge_size[i] = node_size[i] * max_deg[i];
    }
  }
}

__global__ void node_perm_kernel(const int64_t *__restrict__ bin,
                                 const int64_t *__restrict__ idx,
                                 const int *__restrict__ node_offset,
                                 int64_t *__restrict__ out, const size_t N) {
  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < N; thread_idx += gridDim.x * blockDim.x) {
    out[__ldg(node_offset + bin[thread_idx]) + idx[thread_idx]] = thread_idx;
  }
}

template <int TB>
__global__ void padded_index_kernel(
    const int64_t *__restrict__ rowptr, const int64_t *__restrict__ col,
    const int64_t *__restrict__ rowcount, const int64_t *__restrict__ bin,
    const int64_t *__restrict__ idx, const int *__restrict__ max_deg,
    const int *__restrict__ edge_offset, int64_t *__restrict__ row_perm,
    int64_t *__restrict__ col_perm, bool *__restrict__ edge_mask,
    const size_t B, const size_t N) {
  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < TB * N; thread_idx += gridDim.x * blockDim.x) {

    int row_idx = thread_idx / TB;
    int lane_idx = thread_idx % TB;

    int64_t bin_idx = bin[row_idx];
    int len = __ldg(max_deg + bin_idx);
    int off = __ldg(edge_offset + bin_idx) + len * idx[row_idx];

    int64_t row_start = rowptr[row_idx], deg = rowcount[row_idx];

    int64_t row_tmp, col_tmp;
    for (int i = lane_idx; i < len; i += TB) {
      row_tmp = -1, col_tmp = -1;
      if (i < deg) {
        row_tmp = row_idx;
        col_tmp = col[row_start + i];
      }
      row_perm[off + i] = row_tmp;
      col_perm[off + i] = col_tmp;
      edge_mask[off + i] = row_tmp == -1;
    }
  }
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           std::vector<int64_t>, std::vector<int64_t>>
padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
                  torch::Tensor rowcount, torch::Tensor binptr) {
  CHECK_CUDA(rowptr);
  CHECK_CUDA(col);
  CHECK_CUDA(rowcount);
  CHECK_CUDA(binptr);
  CHECK_INPUT(rowptr.numel() == rowcount.numel() + 1);
  cudaSetDevice(rowcount.get_device());

  auto stream = at::cuda::getCurrentCUDAStream();
  size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  size_t B = binptr.numel() - 1;
  size_t N = rowcount.numel();

  auto bin = torch::empty(N, col.options());
  auto idx = torch::empty(N, col.options());

  auto d_info = torch::zeros(5 * B + 2, col.options().dtype(torch::kInt));
  auto d_node_size = d_info.narrow(0, 0, B);
  auto d_edge_size = d_info.narrow(0, B, B);
  auto d_max_deg = d_info.narrow(0, 2 * B, B);
  auto d_node_offset = d_info.narrow(0, 3 * B, B + 1);
  auto d_edge_offset = d_info.narrow(0, 4 * B + 1, B + 1);

  bin_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
      rowcount.data_ptr<int64_t>(), binptr.data_ptr<int64_t>(),
      bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
      d_node_size.data_ptr<int>(), d_max_deg.data_ptr<int>(), B, N);

  info_kernel<<<BLOCKS(32 * (B + 2)), THREADS, 0, stream>>>(
      d_node_size.data_ptr<int>(), d_max_deg.data_ptr<int>(),
      d_edge_size.data_ptr<int>(), d_node_offset.data_ptr<int>(),
      d_edge_offset.data_ptr<int>(), B);

  auto node_perm = torch::empty(N, col.options());
  node_perm_kernel<<<std::min(BLOCKS(N), mpc * 8), THREADS, 0, stream>>>(
      bin.data_ptr<int64_t>(), idx.data_ptr<int64_t>(),
      d_node_offset.data_ptr<int>(), node_perm.data_ptr<int64_t>(), N);

  auto h_info = torch::empty(
      d_info.numel(), d_info.options().device(torch::kCPU).pinned_memory(true));
  cudaMemcpy(h_info.data_ptr<int>(), d_info.data_ptr<int>(),
             d_info.numel() * sizeof(int), cudaMemcpyDeviceToHost);

  size_t E = h_info.data_ptr<int>()[5 * B + 1];
  auto row_perm = torch::empty(E, col.options());
  auto col_perm = torch::empty(E, col.options());
  auto edge_mask = torch::empty(E, col.options().dtype(torch::kBool));

  padded_index_kernel<8>
      <<<std::min(BLOCKS(N * 8), mpc * 8), THREADS, 0, stream>>>(
          rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
          rowcount.data_ptr<int64_t>(), bin.data_ptr<int64_t>(),
          idx.data_ptr<int64_t>(), d_max_deg.data_ptr<int>(),
          d_edge_offset.data_ptr<int>(), row_perm.data_ptr<int64_t>(),
          col_perm.data_ptr<int64_t>(), edge_mask.data_ptr<bool>(), B, N);

  h_info = h_info.to(torch::kLong);
  auto h_info_data = h_info.data_ptr<int64_t>();
  std::vector<int64_t> node_sizes(h_info_data, h_info_data + B);
  std::vector<int64_t> edge_sizes(h_info_data + B, h_info_data + 2 * B);

  return std::make_tuple(node_perm, row_perm, col_perm, edge_mask, node_sizes,
                         edge_sizes);
}

template <typename scalar_t>
__global__ void padded_index_select_kernel(const scalar_t *__restrict__ src,
                                           const int64_t *__restrict__ index,
                                           scalar_t *__restrict__ out,
                                           const scalar_t fill_value,
                                           const size_t E, const size_t F) {
  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < E * F; thread_idx += gridDim.x * blockDim.x) {

    int64_t row_idx = thread_idx / F;
    int64_t lane_idx = thread_idx % F;
    int64_t index_idx = __ldg(index + row_idx);

    scalar_t tmp = fill_value;
    if (index_idx != -1) {
      tmp = src[index_idx * F + lane_idx];
    }
    out[thread_idx] = tmp;
  }
}

torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
                                       torch::Tensor fill_value) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(index.dim() == 1);
  cudaSetDevice(src.get_device());

  auto stream = at::cuda::getCurrentCUDAStream();
  size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  size_t E = index.numel();
  size_t F = src.size(-1);

  auto out = torch::empty({(int)E, (int)F}, src.options());

  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "padded_index_select_kernel", [&] {
    scalar_t *fill;
    if (fill_value.is_cuda()) {
      fill = (scalar_t *)malloc(sizeof(scalar_t));
      cudaMemcpy(fill, fill_value.data_ptr<scalar_t>(), sizeof(scalar_t),
                 cudaMemcpyDeviceToHost);
    } else {
      fill = fill_value.data_ptr<scalar_t>();
    }

    padded_index_select_kernel<scalar_t>
        <<<std::min(BLOCKS(E * F), mpc * 8), THREADS, 0, stream>>>(
            src.data_ptr<scalar_t>(), index.data_ptr<int64_t>(),
            out.data_ptr<scalar_t>(), fill[0], E, F);
  });

  return out;
}

template <typename scalar_t>
__global__ void padded_index_scatter_kernel(const scalar_t *__restrict__ src,
                                            const int64_t *__restrict__ index,
                                            scalar_t *__restrict__ out,
                                            const size_t E, const size_t F) {
  for (ptrdiff_t thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
       thread_idx < E * F; thread_idx += gridDim.x * blockDim.x) {

    int64_t index_idx = __ldg(index + thread_idx / F);
    if (index_idx != -1) {
      atomAdd(out + index_idx * F + thread_idx % F, src[thread_idx]);
    }
  }
}

torch::Tensor padded_index_scatter_cuda(torch::Tensor src, torch::Tensor index,
                                        int64_t N) {
  CHECK_CUDA(src);
  CHECK_CUDA(index);
  CHECK_INPUT(src.dim() == 2);
  CHECK_INPUT(index.dim() == 1);
  CHECK_INPUT(src.size(0) == index.size(0));
  cudaSetDevice(src.get_device());

  auto stream = at::cuda::getCurrentCUDAStream();
  size_t mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

  size_t E = index.numel();
  size_t F = src.size(-1);

  auto out = torch::zeros({N, (int)F}, src.options());

  AT_DISPATCH_FLOATING_TYPES(
      src.scalar_type(), "padded_index_scatter_kernel", [&] {
        padded_index_scatter_kernel<scalar_t>
            <<<std::min(BLOCKS(E * F), mpc * 8), THREADS, 0, stream>>>(
                src.data_ptr<scalar_t>(), index.data_ptr<int64_t>(),
                out.data_ptr<scalar_t>(), E, F);
      });

  return out;
}
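The CPU and CUDA paths share the same degree-binning scheme, which is easier to read without the kernel plumbing. Below is a scalar Python sketch of the bin-selection logic in bin_kernel and padded_index_cpu (the helper name is mine, purely illustrative):

import torch

def assign_bins_reference(rowcount, binptr):
    # Node n is placed in bin b - 1 for the first boundary with
    # rowcount[n] < binptr[b]; degrees beyond the last boundary
    # fall back to the last bin, B - 1.
    B = binptr.numel() - 1
    bins = torch.empty_like(rowcount)
    for n in range(rowcount.numel()):
        deg = int(rowcount[n])
        bin_idx = B - 1
        for b in range(1, B + 1):
            if deg < int(binptr[b]):
                bin_idx = b - 1
                break
        bins[n] = bin_idx
    return bins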
csrc/cuda/padding_cuda.h  (deleted, 100644 → 0, view file @ 6a992437)

#pragma once

#include <torch/extension.h>

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           std::vector<int64_t>, std::vector<int64_t>>
padded_index_cuda(torch::Tensor rowptr, torch::Tensor col,
                  torch::Tensor rowcount, torch::Tensor binptr);

torch::Tensor padded_index_select_cuda(torch::Tensor src, torch::Tensor index,
                                       torch::Tensor fill_value);

torch::Tensor padded_index_scatter_cuda(torch::Tensor src, torch::Tensor index,
                                        int64_t N);
csrc/padding.cpp  (deleted, 100644 → 0, view file @ 6a992437)

#include <Python.h>
#include <torch/script.h>

#include "cpu/padding_cpu.h"

#ifdef WITH_CUDA
#include "cuda/padding_cuda.h"
#endif

#ifdef _WIN32
PyMODINIT_FUNC PyInit__padding(void) { return NULL; }
#endif

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
           std::vector<int64_t>, std::vector<int64_t>>
padded_index(torch::Tensor rowptr, torch::Tensor col, torch::Tensor rowcount,
             torch::Tensor binptr) {
  if (rowptr.device().is_cuda()) {
#ifdef WITH_CUDA
    return padded_index_cuda(rowptr, col, rowcount, binptr);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return padded_index_cpu(rowptr, col, rowcount, binptr);
  }
}

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

class PaddedIndexSelect : public torch::autograd::Function<PaddedIndexSelect> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, Variable fill_value) {
    ctx->saved_data["N"] = src.size(0);
    torch::Tensor out;
    if (src.device().is_cuda()) {
#ifdef WITH_CUDA
      out = padded_index_select_cuda(src, index, fill_value);
#else
      AT_ERROR("Not compiled with CUDA support");
#endif
    } else {
      out = padded_index_select_cpu(src, index, fill_value);
    }
    ctx->save_for_backward({index});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto N = ctx->saved_data["N"].toInt();
    torch::Tensor grad_in;
    if (grad_out.device().is_cuda()) {
#ifdef WITH_CUDA
      grad_in = padded_index_scatter_cuda(grad_out, index, N);
#else
      AT_ERROR("Not compiled with CUDA support");
#endif
    } else {
      grad_in = padded_index_scatter_cpu(grad_out, index, N);
    }
    return {grad_in, Variable(), Variable()};
  }
};

torch::Tensor padded_index_select(torch::Tensor src, torch::Tensor index,
                                  torch::Tensor fill_value) {
  return PaddedIndexSelect::apply(src, index, fill_value)[0];
}

static auto registry =
    torch::RegisterOperators()
        .op("torch_sparse::padded_index", &padded_index)
        .op("torch_sparse::padded_index_select", &padded_index_select);
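The PaddedIndexSelect class above is the usual custom-autograd pairing: the forward pass is the padded gather, and the backward pass scatter-adds gradients back into the source rows while dropping the padding. A small, hypothetical end-to-end check built from plain differentiable PyTorch ops (not the removed C++ op) shows the same gradient flow:

import torch

# Padding positions (index == -1) receive a fill value and send no gradient back.
src = torch.randn(4, 3, requires_grad=True)
index = torch.tensor([2, -1, 0, 0, -1])

mask = index == -1
out = src.index_select(0, index.masked_fill(mask, 0))
out = out.masked_fill(mask.view(-1, 1).expand_as(out), 0.0)

out.sum().backward()
print(src.grad)  # row 0 gathered twice, row 2 once, rows 1 and 3 receive zeros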
torch_sparse/__init__.py  (view file @ 73a89efb)

...
@@ -7,7 +7,7 @@ __version__ = '0.6.4'
 for library in [
     '_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis', '_rw',
-    '_saint', '_padding', '_sample'
+    '_saint', '_sample'
 ]:
     torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
         library, [osp.dirname(__file__)]).origin)
...
torch_sparse/matmul.py  (view file @ 73a89efb)

...
@@ -99,14 +99,24 @@ def spspmm(src: SparseTensor, other: SparseTensor,
         raise ValueError


-def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
-           reduce: str = "sum"):
-    if torch.is_tensor(other):
+@torch.jit._overload  # noqa: F811
+def matmul(src, other, reduce):
+    # type: (SparseTensor, torch.Tensor, str) -> torch.Tensor
+    pass
+
+
+@torch.jit._overload  # noqa: F811
+def matmul(src, other, reduce):
+    # type: (SparseTensor, SparseTensor, str) -> SparseTensor
+    pass
+
+
+def matmul(src, other, reduce="sum"):  # noqa: F811
+    if isinstance(other, torch.Tensor):
         return spmm(src, other, reduce)
     elif isinstance(other, SparseTensor):
         return spspmm(src, other, reduce)
     else:
         raise ValueError


 SparseTensor.spmm = lambda self, other, reduce="sum": spmm(self, other, reduce)
...
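The @torch.jit._overload stubs give TorchScript a concrete return type for each call pattern, while eager mode still dispatches on the runtime type of other. A small usage sketch, assuming SparseTensor and matmul are importable from the torch_sparse package as in this release (the example matrices are made up):

import torch
from torch_sparse import SparseTensor, matmul

# A 3x3 sparse matrix with three nonzero entries.
row = torch.tensor([0, 1, 2])
col = torch.tensor([1, 2, 0])
mat = SparseTensor(row=row, col=col, value=torch.ones(3), sparse_sizes=(3, 3))

dense = torch.randn(3, 4)
out_dense = matmul(mat, dense, reduce="sum")  # dispatches to spmm, returns torch.Tensor
out_sparse = matmul(mat, mat, reduce="sum")   # dispatches to spspmm, returns SparseTensor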
torch_sparse/tensor.py  (view file @ 73a89efb)

...
@@ -540,8 +540,8 @@ SparseTensor.__repr__ = __repr__
 # Scipy Conversions ###########################################################

 ScipySparseMatrix = Union[scipy.sparse.coo_matrix, scipy.sparse.csr_matrix,
                           scipy.sparse.csc_matrix]

 @torch.jit.ignore
...