Unverified Commit e032de58 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Bugfix] Avoid initializing CUDA context at importing (#5134)

* Avoid initializing CUDA at importing

* renaming
parent 1a33ee99
......@@ -141,12 +141,17 @@ class empty_context:
return
# This is to avoid warnings in cpu-only dgl. We don't enable autocast for CPU ops
autocast = th.cuda.amp.autocast if th.cuda.is_available() else empty_context
# Disable CUDA autocast since we have casted args manually,
# and do it only in a nested autocast context.
def _disable_autocast_if_enabled():
if th.is_autocast_enabled():
return th.cuda.amp.autocast(enabled=False)
else:
return empty_context()
def _cast_if_autocast_enabled(*args):
if not th.is_autocast_enabled() or not th.cuda.is_available():
if not th.is_autocast_enabled():
return args
else:
return th.cuda.amp.autocast_mode._cast(
......@@ -1023,7 +1028,7 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
op = "mul"
rhs_data = 1.0 / rhs_data
args = _cast_if_autocast_enabled(gidx, op, reduce_op, lhs_data, rhs_data)
with autocast(enabled=False):
with _disable_autocast_if_enabled():
return GSpMM.apply(*args)
......@@ -1037,7 +1042,7 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
args = _cast_if_autocast_enabled(
gidx, op, lhs_data, rhs_data, lhs_target, rhs_target
)
with autocast(enabled=False):
with _disable_autocast_if_enabled():
return GSDDMM.apply(*args)
......@@ -1068,7 +1073,7 @@ def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
args = _cast_if_autocast_enabled(
g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple
)
with autocast(enabled=False):
with _disable_autocast_if_enabled():
return GSpMM_hetero.apply(*args)
......@@ -1101,31 +1106,31 @@ def gsddmm_hetero(
args = _cast_if_autocast_enabled(
g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple
)
with autocast(enabled=False):
with _disable_autocast_if_enabled():
return GSDDMM_hetero.apply(*args)
def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"):
args = _cast_if_autocast_enabled(gidx, logits, eids, norm_by)
with autocast(enabled=False):
with _disable_autocast_if_enabled():
return EdgeSoftmax.apply(*args)
def edge_softmax_hetero(gidx, eids=ALL, norm_by="dst", *logits):
args = _cast_if_autocast_enabled(gidx, eids, norm_by, *logits)
with autocast(enabled=False):
with _disable_autocast_if_enabled():
return EdgeSoftmax_hetero.apply(*args)
def segment_reduce(op, x, offsets):
args = _cast_if_autocast_enabled(op, x, offsets)
with autocast(enabled=False):
with _disable_autocast_if_enabled():
return SegmentReduce.apply(*args)
def scatter_add(x, idx, m):
args = _cast_if_autocast_enabled(x, idx, m)
with autocast(enabled=False):
with _disable_autocast_if_enabled():
return ScatterAdd.apply(*args)
......@@ -1175,7 +1180,7 @@ def segment_mm(A, B, seglen_A):
return th.cat(C)
else:
args = _cast_if_autocast_enabled(A, B, seglen_A)
with autocast(enabled=False):
with _disable_autocast_if_enabled():
return SEGMENTMM.apply(*args)
......@@ -1186,5 +1191,5 @@ def gather_mm(A, B, idx_A=None, idx_B=None):
return th.bmm(A.unsqueeze(1), B).squeeze(1)
else:
args = _cast_if_autocast_enabled(A, B, idx_A, idx_B)
with autocast(enabled=False):
with _disable_autocast_if_enabled():
return GATHERMM.apply(*args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment