Commit f1a8f926 authored by Tong He's avatar Tong He Committed by Minjie Wang
Browse files

[Feature] Add environment variable to switch on/off MXNet set_np_shape (#1207)

* add env var

* Trigger CI

* simplification

* add doc
parent a00636a0
...@@ -21,6 +21,11 @@ MXNet uses uint32 as the default data type for integer tensors, which only suppo ...@@ -21,6 +21,11 @@ MXNet uses uint32 as the default data type for integer tensors, which only suppo
size smaller than 2^32. To enable large graph training, *build* MXNet with ``USE_INT64_TENSOR_SIZE=1`` size smaller than 2^32. To enable large graph training, *build* MXNet with ``USE_INT64_TENSOR_SIZE=1``
flag. See `this FAQ <https://mxnet.apache.org/api/faq/large_tensor_support>`_ for more information. flag. See `this FAQ <https://mxnet.apache.org/api/faq/large_tensor_support>`_ for more information.
MXNet 1.5 and later has an option to enable Numpy shape mode for ``NDArray`` objects, some DGL models
need this mode to be enabled to run correctly. However, this mode may not compatible with pretrained
model parameters with this mode disabled, e.g. pretrained models from GluonCV and GluonNLP.
By setting ``DGL_MXNET_SET_NP_SHAPE``, users can switch this mode on or off.
Tensorflow backend Tensorflow backend
------------------ ------------------
......
...@@ -2,6 +2,7 @@ from __future__ import absolute_import ...@@ -2,6 +2,7 @@ from __future__ import absolute_import
from distutils.version import LooseVersion from distutils.version import LooseVersion
import os
import numpy as np import numpy as np
import mxnet as mx import mxnet as mx
import mxnet.ndarray as nd import mxnet.ndarray as nd
...@@ -17,7 +18,7 @@ if MX_VERSION.version[0] == 1 and MX_VERSION.version[1] < 5: ...@@ -17,7 +18,7 @@ if MX_VERSION.version[0] == 1 and MX_VERSION.version[1] < 5:
# After MXNet 1.5, empty tensors aren't supprted by default. # After MXNet 1.5, empty tensors aren't supprted by default.
# After we turn on the numpy compatible flag, MXNet supports empty NDArray. # After we turn on the numpy compatible flag, MXNet supports empty NDArray.
mx.set_np_shape(True) mx.set_np_shape(bool(os.environ.get('DGL_MXNET_SET_NP_SHAPE', True)))
def data_type_dict(): def data_type_dict():
return {'float16' : np.float16, return {'float16' : np.float16,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment