Unverified Commit e032de58 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Bugfix] Avoid initializing CUDA context at importing (#5134)

* Avoid initializing CUDA at importing

* renaming
parent 1a33ee99
...@@ -141,12 +141,17 @@ class empty_context: ...@@ -141,12 +141,17 @@ class empty_context:
return return
# This is to avoid warnings in cpu-only dgl. We don't enable autocast for CPU ops # Disable CUDA autocast since we have casted args manually,
autocast = th.cuda.amp.autocast if th.cuda.is_available() else empty_context # and do it only in a nested autocast context.
def _disable_autocast_if_enabled():
if th.is_autocast_enabled():
return th.cuda.amp.autocast(enabled=False)
else:
return empty_context()
def _cast_if_autocast_enabled(*args): def _cast_if_autocast_enabled(*args):
if not th.is_autocast_enabled() or not th.cuda.is_available(): if not th.is_autocast_enabled():
return args return args
else: else:
return th.cuda.amp.autocast_mode._cast( return th.cuda.amp.autocast_mode._cast(
...@@ -1023,7 +1028,7 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data): ...@@ -1023,7 +1028,7 @@ def gspmm(gidx, op, reduce_op, lhs_data, rhs_data):
op = "mul" op = "mul"
rhs_data = 1.0 / rhs_data rhs_data = 1.0 / rhs_data
args = _cast_if_autocast_enabled(gidx, op, reduce_op, lhs_data, rhs_data) args = _cast_if_autocast_enabled(gidx, op, reduce_op, lhs_data, rhs_data)
with autocast(enabled=False): with _disable_autocast_if_enabled():
return GSpMM.apply(*args) return GSpMM.apply(*args)
...@@ -1037,7 +1042,7 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"): ...@@ -1037,7 +1042,7 @@ def gsddmm(gidx, op, lhs_data, rhs_data, lhs_target="u", rhs_target="v"):
args = _cast_if_autocast_enabled( args = _cast_if_autocast_enabled(
gidx, op, lhs_data, rhs_data, lhs_target, rhs_target gidx, op, lhs_data, rhs_data, lhs_target, rhs_target
) )
with autocast(enabled=False): with _disable_autocast_if_enabled():
return GSDDMM.apply(*args) return GSDDMM.apply(*args)
...@@ -1068,7 +1073,7 @@ def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple): ...@@ -1068,7 +1073,7 @@ def gspmm_hetero(g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple):
args = _cast_if_autocast_enabled( args = _cast_if_autocast_enabled(
g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple g, op, reduce_op, lhs_len, *lhs_and_rhs_tuple
) )
with autocast(enabled=False): with _disable_autocast_if_enabled():
return GSpMM_hetero.apply(*args) return GSpMM_hetero.apply(*args)
...@@ -1101,31 +1106,31 @@ def gsddmm_hetero( ...@@ -1101,31 +1106,31 @@ def gsddmm_hetero(
args = _cast_if_autocast_enabled( args = _cast_if_autocast_enabled(
g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple g, op, lhs_len, lhs_target, rhs_target, *lhs_and_rhs_tuple
) )
with autocast(enabled=False): with _disable_autocast_if_enabled():
return GSDDMM_hetero.apply(*args) return GSDDMM_hetero.apply(*args)
def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"): def edge_softmax(gidx, logits, eids=ALL, norm_by="dst"):
args = _cast_if_autocast_enabled(gidx, logits, eids, norm_by) args = _cast_if_autocast_enabled(gidx, logits, eids, norm_by)
with autocast(enabled=False): with _disable_autocast_if_enabled():
return EdgeSoftmax.apply(*args) return EdgeSoftmax.apply(*args)
def edge_softmax_hetero(gidx, eids=ALL, norm_by="dst", *logits): def edge_softmax_hetero(gidx, eids=ALL, norm_by="dst", *logits):
args = _cast_if_autocast_enabled(gidx, eids, norm_by, *logits) args = _cast_if_autocast_enabled(gidx, eids, norm_by, *logits)
with autocast(enabled=False): with _disable_autocast_if_enabled():
return EdgeSoftmax_hetero.apply(*args) return EdgeSoftmax_hetero.apply(*args)
def segment_reduce(op, x, offsets): def segment_reduce(op, x, offsets):
args = _cast_if_autocast_enabled(op, x, offsets) args = _cast_if_autocast_enabled(op, x, offsets)
with autocast(enabled=False): with _disable_autocast_if_enabled():
return SegmentReduce.apply(*args) return SegmentReduce.apply(*args)
def scatter_add(x, idx, m): def scatter_add(x, idx, m):
args = _cast_if_autocast_enabled(x, idx, m) args = _cast_if_autocast_enabled(x, idx, m)
with autocast(enabled=False): with _disable_autocast_if_enabled():
return ScatterAdd.apply(*args) return ScatterAdd.apply(*args)
...@@ -1175,7 +1180,7 @@ def segment_mm(A, B, seglen_A): ...@@ -1175,7 +1180,7 @@ def segment_mm(A, B, seglen_A):
return th.cat(C) return th.cat(C)
else: else:
args = _cast_if_autocast_enabled(A, B, seglen_A) args = _cast_if_autocast_enabled(A, B, seglen_A)
with autocast(enabled=False): with _disable_autocast_if_enabled():
return SEGMENTMM.apply(*args) return SEGMENTMM.apply(*args)
...@@ -1186,5 +1191,5 @@ def gather_mm(A, B, idx_A=None, idx_B=None): ...@@ -1186,5 +1191,5 @@ def gather_mm(A, B, idx_A=None, idx_B=None):
return th.bmm(A.unsqueeze(1), B).squeeze(1) return th.bmm(A.unsqueeze(1), B).squeeze(1)
else: else:
args = _cast_if_autocast_enabled(A, B, idx_A, idx_B) args = _cast_if_autocast_enabled(A, B, idx_A, idx_B)
with autocast(enabled=False): with _disable_autocast_if_enabled():
return GATHERMM.apply(*args) return GATHERMM.apply(*args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment