Unverified Commit 62af41c2 authored by peizhou001, committed by GitHub

[Bug] Enable turning libxsmm on/off at runtime (#4455)



* Enable turning libxsmm on/off at runtime by adding a global config and related APIs
Co-authored-by: Ubuntu <ubuntu@ip-172-31-19-194.ap-northeast-1.compute.internal>
parent 06608f84
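For context, here is a minimal usage sketch of the runtime switch this change introduces, assuming a CPU build of DGL compiled with USE_LIBXSMM (the graph and feature shapes below are arbitrary):

import dgl
import torch

g = dgl.graph(([0, 0, 1], [1, 2, 2]))   # 3 nodes, 3 edges
x = torch.ones(3, 4)                     # node features
y = torch.ones(3, 4)                     # edge features

assert dgl.is_libxsmm_enabled()          # libxsmm is on by default
dgl.ops.u_mul_e_sum(g, x, y)             # CPU SpMM through the libxsmm kernel

dgl.use_libxsmm(False)                   # fall back to the plain CPU SpMM path
dgl.ops.u_mul_e_sum(g, x, y)
dgl.use_libxsmm(True)                    # restore the default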
...@@ -208,7 +208,7 @@ set at each iteration. ``prop_edges_YYY`` applies traversal algorithm ``YYY`` an
Utilities
-----------------------------------------------
Other utilities for controlling randomness, saving and loading graphs, setting and getting runtime configurations, functions that apply
the same function to every element in a container, etc.
.. autosummary::
...@@ -218,3 +218,5 @@ the same function to every element in a container, etc.
save_graphs
load_graphs
apply_each
use_libxsmm
is_libxsmm_enabled
/*!
 * Copyright (c) 2019 by Contributors
 * \file runtime/config.h
 * \brief DGL runtime config
 */
#ifndef DGL_RUNTIME_CONFIG_H_
#define DGL_RUNTIME_CONFIG_H_

namespace dgl {
namespace runtime {

class Config {
 public:
  // Return the process-wide config singleton.
  static Config* Global() {
    static Config config;
    return &config;
  }

  // Enable or disable using libxsmm for SpMM.
  void EnableLibxsmm(bool);
  // Return whether libxsmm is currently enabled.
  bool IsLibxsmmAvailable() const;

 private:
  Config() = default;
  // Libxsmm is enabled by default.
  bool libxsmm_ = true;
};

}  // namespace runtime
}  // namespace dgl

#endif  // DGL_RUNTIME_CONFIG_H_
...@@ -50,6 +50,7 @@ from .data.utils import save_graphs, load_graphs
from . import optim
from .frame import LazyFeature
from .utils import apply_each
from .global_config import is_libxsmm_enabled, use_libxsmm
from ._deprecate.graph import DGLGraph as DGLGraphStale
from ._deprecate.nodeflow import *
"""Module for global configuration operators."""
from ._ffi.function import _init_api
__all__ = ["is_libxsmm_enabled", "use_libxsmm"]
def use_libxsmm(flag):
r"""Set whether DGL uses libxsmm at runtime.
Detailed information about libxsmm can be found here:
https://github.com/libxsmm/libxsmm
Parameters
----------
flag : boolean
If True, use libxsmm, otherwise not.
See Also
--------
is_libxsmm_enabled
"""
_CAPI_DGLConfigSetLibxsmm(flag)
def is_libxsmm_enabled():
r"""Get whether the use_libxsmm flag is turned on.
Returns
----------
use_libxsmm_flag[boolean]
True if the use_libxsmm flag is turned on.
See Also
----------
use_libxsmm
"""
return _CAPI_DGLConfigGetLibxsmm()
_init_api("dgl.global_config")
...@@ -9,6 +9,7 @@
#include <dgl/array.h>
#include <dgl/bcast.h>
#include <dgl/runtime/parallel_for.h>
#include <dgl/runtime/config.h>
#include <math.h>
#include <algorithm>
#include <limits>
...@@ -142,7 +143,9 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
#ifdef USE_AVX
#ifdef USE_LIBXSMM
  // Skip libxsmm when broadcasting is needed, the data type is double,
  // or the user disabled libxsmm at runtime via dgl.use_libxsmm(False).
  const bool no_libxsmm =
      bcast.use_bcast ||
      std::is_same<DType, double>::value ||
      !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
  if (!no_libxsmm) {
    SpMMSumCsrLibxsmm<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
  } else {
...@@ -269,7 +272,9 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
#ifdef USE_LIBXSMM
  const bool no_libxsmm =
      bcast.use_bcast ||
      std::is_same<DType, double>::value ||
      !dgl::runtime::Config::Global()->IsLibxsmmAvailable();
  if (!no_libxsmm) {
    SpMMCmpCsrLibxsmm<IdType, DType, Op, Cmp>(bcast, csr, ufeat, efeat, out, argu, arge);
  } else {
......
...@@ -264,7 +264,8 @@ inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel(
      (sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32, opredop_flags);
  }
  if (kernel == nullptr) {
    LOG(FATAL) << "Failed to generate libxsmm kernel for the SpMM operation. "
                  "To disable libxsmm, use dgl.use_libxsmm(False).";
  }
  return kernel;
}
......
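In Python, the workaround that the new error message points to looks roughly like this (a sketch, not part of this diff; it assumes the kernel-generation failure occurs on a CPU build with USE_LIBXSMM and reuses the op exercised by the unit test below):

import dgl
import torch

g = dgl.graph(([0, 0, 1], [1, 2, 2]))
x = torch.ones(3, 2)   # node features
y = torch.ones(3, 2)   # edge features

# If libxsmm kernel generation fails on this platform, turn the flag off
# before calling the CPU SpMM op so the non-libxsmm path is taken instead.
dgl.use_libxsmm(False)
out = dgl.ops.u_mul_e_sum(g, x, y)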
/*!
 * Copyright (c) 2019 by Contributors
 * \file runtime/config.cc
 * \brief DGL runtime config
 */
#include <dgl/runtime/config.h>
#include <dgl/runtime/registry.h>

namespace dgl {
namespace runtime {

void Config::EnableLibxsmm(bool b) {
  libxsmm_ = b;
}

bool Config::IsLibxsmmAvailable() const {
  return libxsmm_;
}

// Exposed to Python as dgl.use_libxsmm(flag).
DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigSetLibxsmm")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  bool use_libxsmm = args[0];
  Config::Global()->EnableLibxsmm(use_libxsmm);
});

// Exposed to Python as dgl.is_libxsmm_enabled().
DGL_REGISTER_GLOBAL("global_config._CAPI_DGLConfigGetLibxsmm")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
  *rv = Config::Global()->IsLibxsmmAvailable();
});

}  // namespace runtime
}  // namespace dgl
...@@ -385,3 +385,20 @@ def _test_gather_mm_idx_a(idtype, feat_size):
    assert torch.allclose(c, c_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(da, da_t, atol=1e-4, rtol=1e-4)
    assert torch.allclose(db, db_t, atol=1e-4, rtol=1e-4)

@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
@unittest.skipIf(F._default_context_str == 'gpu', reason="Libxsmm only works on CPU.")
def test_use_libxsmm_switch():
    import torch
    g = dgl.graph(([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]))
    x = torch.ones(3, 2, requires_grad=True)
    y = torch.arange(1, 13).float().view(6, 2).requires_grad_()
    # Libxsmm is enabled by default.
    assert dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)
    # Turn libxsmm off and make sure the fallback SpMM path still works.
    dgl.use_libxsmm(False)
    assert not dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)
    # Turn libxsmm back on.
    dgl.use_libxsmm(True)
    assert dgl.is_libxsmm_enabled()
    dgl.ops.u_mul_e_sum(g, x, y)