Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
e032de58
Unverified
Commit
e032de58
authored
Jan 11, 2023
by
Xin Yao
Committed by
GitHub
Jan 11, 2023
Browse files
[Bugfix] Avoid initializing CUDA context at importing (#5134)
* Avoid initializing CUDA at importing * renaming
parent
1a33ee99
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
13 deletions
+18
-13
python/dgl/backend/pytorch/sparse.py
python/dgl/backend/pytorch/sparse.py
+18
-13
No files found.
python/dgl/backend/pytorch/sparse.py
View file @
e032de58
...
...
@@ -141,12 +141,17 @@ class empty_context:
return
# This is to avoid warnings in cpu-only dgl. We don't enable autocast for CPU ops
autocast
=
th
.
cuda
.
amp
.
autocast
if
th
.
cuda
.
is_available
()
else
empty_context
# Disable CUDA autocast since we have casted args manually,
# and do it only in a nested autocast context.
def
_disable_autocast_if_enabled
():
if
th
.
is_autocast_enabled
():
return
th
.
cuda
.
amp
.
autocast
(
enabled
=
False
)
else
:
return
empty_context
()
def
_cast_if_autocast_enabled
(
*
args
):
if
not
th
.
is_autocast_enabled
()
or
not
th
.
cuda
.
is_available
()
:
if
not
th
.
is_autocast_enabled
():
return
args
else
:
return
th
.
cuda
.
amp
.
autocast_mode
.
_cast
(
...
...
@@ -1023,7 +1028,7 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
op
=
"mul"
rhs_data
=
1.0
/
rhs_data
args
=
_cast_if_autocast_enabled
(
gidx
,
op
,
reduce_op
,
lhs_data
,
rhs_data
)
with
autocast
(
enabled
=
False
):
with
_disable_
autocast
_if_
enabled
(
):
return
GSpMM
.
apply
(
*
args
)
...
...
@@ -1037,7 +1042,7 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
args
=
_cast_if_autocast_enabled
(
gidx
,
op
,
lhs_data
,
rhs_data
,
lhs_target
,
rhs_target
)
with
autocast
(
enabled
=
False
):
with
_disable_
autocast
_if_
enabled
(
):
return
GSDDMM
.
apply
(
*
args
)
...
...
@@ -1068,7 +1073,7 @@ def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
args
=
_cast_if_autocast_enabled
(
g
,
op
,
reduce_op
,
lhs_len
,
*
lhs_and_rhs_tuple
)
with
autocast
(
enabled
=
False
):
with
_disable_
autocast
_if_
enabled
(
):
return
GSpMM_hetero
.
apply
(
*
args
)
...
...
@@ -1101,31 +1106,31 @@ def gsddmm_hetero(
args
=
_cast_if_autocast_enabled
(
g
,
op
,
lhs_len
,
lhs_target
,
rhs_target
,
*
lhs_and_rhs_tuple
)
with
autocast
(
enabled
=
False
):
with
_disable_
autocast
_if_
enabled
(
):
return
GSDDMM_hetero
.
apply
(
*
args
)
def
edge_softmax
(
gidx
,
logits
,
eids
=
ALL
,
norm_by
=
"dst"
):
args
=
_cast_if_autocast_enabled
(
gidx
,
logits
,
eids
,
norm_by
)
with
autocast
(
enabled
=
False
):
with
_disable_
autocast
_if_
enabled
(
):
return
EdgeSoftmax
.
apply
(
*
args
)
def
edge_softmax_hetero
(
gidx
,
eids
=
ALL
,
norm_by
=
"dst"
,
*
logits
):
args
=
_cast_if_autocast_enabled
(
gidx
,
eids
,
norm_by
,
*
logits
)
with
autocast
(
enabled
=
False
):
with
_disable_
autocast
_if_
enabled
(
):
return
EdgeSoftmax_hetero
.
apply
(
*
args
)
def
segment_reduce
(
op
,
x
,
offsets
):
args
=
_cast_if_autocast_enabled
(
op
,
x
,
offsets
)
with
autocast
(
enabled
=
False
):
with
_disable_
autocast
_if_
enabled
(
):
return
SegmentReduce
.
apply
(
*
args
)
def
scatter_add
(
x
,
idx
,
m
):
args
=
_cast_if_autocast_enabled
(
x
,
idx
,
m
)
with
autocast
(
enabled
=
False
):
with
_disable_
autocast
_if_
enabled
(
):
return
ScatterAdd
.
apply
(
*
args
)
...
...
@@ -1175,7 +1180,7 @@ def segment_mm(A, B, seglen_A):
return
th
.
cat
(
C
)
else
:
args
=
_cast_if_autocast_enabled
(
A
,
B
,
seglen_A
)
with
autocast
(
enabled
=
False
):
with
_disable_
autocast
_if_
enabled
(
):
return
SEGMENTMM
.
apply
(
*
args
)
...
...
@@ -1186,5 +1191,5 @@ def gather_mm(A, B, idx_A=None, idx_B=None):
return
th
.
bmm
(
A
.
unsqueeze
(
1
),
B
).
squeeze
(
1
)
else
:
args
=
_cast_if_autocast_enabled
(
A
,
B
,
idx_A
,
idx_B
)
with
autocast
(
enabled
=
False
):
with
_disable_
autocast
_if_
enabled
(
):
return
GATHERMM
.
apply
(
*
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment