Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
93a58343
Unverified
Commit
93a58343
authored
Dec 29, 2023
by
Muhammed Fatih BALIN
Committed by
GitHub
Dec 29, 2023
Browse files
[GraphBolt][CUDA] Remove unnecessary check and synchronization (#6863)
parent
a2cb2ecd
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
18 additions
and
27 deletions
+18
-27
python/dgl/graphbolt/base.py
python/dgl/graphbolt/base.py
+8
-0
python/dgl/graphbolt/internal/sample_utils.py
python/dgl/graphbolt/internal/sample_utils.py
+0
-9
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+2
-2
tests/python/pytorch/graphbolt/test_base.py
tests/python/pytorch/graphbolt/test_base.py
+8
-0
tests/python/pytorch/graphbolt/test_graphbolt_utils.py
tests/python/pytorch/graphbolt/test_graphbolt_utils.py
+0
-16
No files found.
python/dgl/graphbolt/base.py
View file @
93a58343
...
@@ -164,6 +164,14 @@ class CSCFormatBase:
...
@@ -164,6 +164,14 @@ class CSCFormatBase:
indptr
:
torch
.
Tensor
=
None
indptr
:
torch
.
Tensor
=
None
indices
:
torch
.
Tensor
=
None
indices
:
torch
.
Tensor
=
None
def
__init__
(
self
,
indptr
:
torch
.
Tensor
,
indices
:
torch
.
Tensor
):
self
.
indptr
=
indptr
self
.
indices
=
indices
if
not
indptr
.
is_cuda
:
assert
self
.
indptr
[
-
1
]
==
len
(
self
.
indices
),
"The last element of indptr should be the same as the length of indices."
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
_csc_format_base_str
(
self
)
return
_csc_format_base_str
(
self
)
...
...
python/dgl/graphbolt/internal/sample_utils.py
View file @
93a58343
...
@@ -254,9 +254,6 @@ def unique_and_compact_csc_formats(
...
@@ -254,9 +254,6 @@ def unique_and_compact_csc_formats(
for
etype
,
csc_format
in
csc_formats
.
items
():
for
etype
,
csc_format
in
csc_formats
.
items
():
if
device
is
None
:
if
device
is
None
:
device
=
csc_format
.
indices
.
device
device
=
csc_format
.
indices
.
device
assert
csc_format
.
indptr
[
-
1
]
==
len
(
csc_format
.
indices
),
"The last element of indptr should be the same as the length of indices."
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
assert
len
(
unique_dst_nodes
.
get
(
dst_type
,
[]))
+
1
==
len
(
assert
len
(
unique_dst_nodes
.
get
(
dst_type
,
[]))
+
1
==
len
(
csc_format
.
indptr
csc_format
.
indptr
...
@@ -358,9 +355,6 @@ def compact_csc_format(
...
@@ -358,9 +355,6 @@ def compact_csc_format(
assert
isinstance
(
assert
isinstance
(
dst_nodes
,
torch
.
Tensor
dst_nodes
,
torch
.
Tensor
),
"Edge type not supported in homogeneous graph."
),
"Edge type not supported in homogeneous graph."
assert
csc_formats
.
indptr
[
-
1
]
==
len
(
csc_formats
.
indices
),
"The last element of indptr should be the same as the length of indices."
assert
len
(
dst_nodes
)
+
1
==
len
(
assert
len
(
dst_nodes
)
+
1
==
len
(
csc_formats
.
indptr
csc_formats
.
indptr
),
"The seed nodes should correspond to indptr."
),
"The seed nodes should correspond to indptr."
...
@@ -381,9 +375,6 @@ def compact_csc_format(
...
@@ -381,9 +375,6 @@ def compact_csc_format(
compacted_csc_formats
=
{}
compacted_csc_formats
=
{}
original_row_ids
=
copy
.
deepcopy
(
dst_nodes
)
original_row_ids
=
copy
.
deepcopy
(
dst_nodes
)
for
etype
,
csc_format
in
csc_formats
.
items
():
for
etype
,
csc_format
in
csc_formats
.
items
():
assert
csc_format
.
indptr
[
-
1
]
==
len
(
csc_format
.
indices
),
"The last element of indptr should be the same as the length of indices."
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
src_type
,
_
,
dst_type
=
etype_str_to_tuple
(
etype
)
assert
len
(
dst_nodes
.
get
(
dst_type
,
[]))
+
1
==
len
(
assert
len
(
dst_nodes
.
get
(
dst_type
,
[]))
+
1
==
len
(
csc_format
.
indptr
csc_format
.
indptr
...
...
python/dgl/graphbolt/minibatch.py
View file @
93a58343
...
@@ -202,7 +202,7 @@ class MiniBatch:
...
@@ -202,7 +202,7 @@ class MiniBatch:
v
.
indices
,
v
.
indices
,
torch
.
arange
(
torch
.
arange
(
0
,
0
,
v
.
ind
ptr
[
-
1
]
,
len
(
v
.
ind
ices
)
,
device
=
v
.
indptr
.
device
,
device
=
v
.
indptr
.
device
,
dtype
=
v
.
indptr
.
dtype
,
dtype
=
v
.
indptr
.
dtype
,
),
),
...
@@ -227,7 +227,7 @@ class MiniBatch:
...
@@ -227,7 +227,7 @@ class MiniBatch:
sampled_csc
.
indices
,
sampled_csc
.
indices
,
torch
.
arange
(
torch
.
arange
(
0
,
0
,
sampled_csc
.
ind
ptr
[
-
1
]
,
len
(
sampled_csc
.
ind
ices
)
,
device
=
sampled_csc
.
indptr
.
device
,
device
=
sampled_csc
.
indptr
.
device
,
dtype
=
sampled_csc
.
indptr
.
dtype
,
dtype
=
sampled_csc
.
indptr
.
dtype
,
),
),
...
...
tests/python/pytorch/graphbolt/test_base.py
View file @
93a58343
...
@@ -244,3 +244,11 @@ def test_csc_format_base_representation():
...
@@ -244,3 +244,11 @@ def test_csc_format_base_representation():
)"""
)"""
)
)
assert
str
(
csc_format_base
)
==
expected_result
,
print
(
csc_format_base
)
assert
str
(
csc_format_base
)
==
expected_result
,
print
(
csc_format_base
)
def
test_csc_format_base_incorrect_indptr
():
indptr
=
torch
.
tensor
([
0
,
2
,
4
,
6
,
7
,
11
])
indices
=
torch
.
tensor
([
2
,
3
,
1
,
4
,
5
,
2
,
5
,
1
,
4
,
4
])
with
pytest
.
raises
(
AssertionError
):
# The value of last element in indptr is not corresponding to indices.
csc_formats
=
gb
.
CSCFormatBase
(
indptr
=
indptr
,
indices
=
indices
)
tests/python/pytorch/graphbolt/test_graphbolt_utils.py
View file @
93a58343
...
@@ -350,14 +350,6 @@ def test_unique_and_compact_incorrect_indptr():
...
@@ -350,14 +350,6 @@ def test_unique_and_compact_incorrect_indptr():
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
gb
.
unique_and_compact_csc_formats
(
csc_formats
,
seeds
)
gb
.
unique_and_compact_csc_formats
(
csc_formats
,
seeds
)
seeds
=
torch
.
tensor
([
1
,
3
,
5
,
2
,
6
])
indptr
=
torch
.
tensor
([
0
,
2
,
4
,
6
,
7
,
11
])
indices
=
torch
.
tensor
([
2
,
3
,
1
,
4
,
5
,
2
,
5
,
1
,
4
,
4
])
csc_formats
=
gb
.
CSCFormatBase
(
indptr
=
indptr
,
indices
=
indices
)
# The value of last element in indptr is not corresponding to indices.
with
pytest
.
raises
(
AssertionError
):
gb
.
unique_and_compact_csc_formats
(
csc_formats
,
seeds
)
def
test_compact_csc_format_hetero
():
def
test_compact_csc_format_hetero
():
dst_nodes
=
{
dst_nodes
=
{
...
@@ -449,11 +441,3 @@ def test_compact_incorrect_indptr():
...
@@ -449,11 +441,3 @@ def test_compact_incorrect_indptr():
# The number of seeds is not corresponding to indptr.
# The number of seeds is not corresponding to indptr.
with
pytest
.
raises
(
AssertionError
):
with
pytest
.
raises
(
AssertionError
):
gb
.
compact_csc_format
(
csc_formats
,
seeds
)
gb
.
compact_csc_format
(
csc_formats
,
seeds
)
seeds
=
torch
.
tensor
([
1
,
3
,
5
,
2
,
6
])
indptr
=
torch
.
tensor
([
0
,
2
,
4
,
6
,
7
,
11
])
indices
=
torch
.
tensor
([
2
,
3
,
1
,
4
,
5
,
2
,
5
,
1
,
4
,
4
])
csc_formats
=
gb
.
CSCFormatBase
(
indptr
=
indptr
,
indices
=
indices
)
# The value of last element in indptr is not corresponding to indices.
with
pytest
.
raises
(
AssertionError
):
gb
.
compact_csc_format
(
csc_formats
,
seeds
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment