Unverified Commit 2d88db5a authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Bug] check dtype before convert to gk (#3414)

parent d9472873
...@@ -265,6 +265,7 @@ def metis_partition_assignment(g, k, balance_ntypes=None, balance_edges=False, m ...@@ -265,6 +265,7 @@ def metis_partition_assignment(g, k, balance_ntypes=None, balance_edges=False, m
A vector with each element that indicates the partition ID of a vertex. A vector with each element that indicates the partition ID of a vertex.
''' '''
assert mode in ("k-way", "recursive"), "'mode' can only be 'k-way' or 'recursive'" assert mode in ("k-way", "recursive"), "'mode' can only be 'k-way' or 'recursive'"
assert g.idtype == F.int64, "IdType of graph is required to be int64 for now."
# METIS works only on symmetric graphs. # METIS works only on symmetric graphs.
# The METIS runs on the symmetric graph to generate the node assignment to partitions. # The METIS runs on the symmetric graph to generate the node assignment to partitions.
start = time.time() start = time.time()
......
...@@ -23,8 +23,10 @@ namespace dgl { ...@@ -23,8 +23,10 @@ namespace dgl {
*/ */
gk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row) { gk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row) {
// TODO(zhengda) The conversion will be zero-copy in the future. // TODO(zhengda) The conversion will be zero-copy in the future.
const dgl_id_t *indptr = static_cast<dgl_id_t*>(mat.indptr->data); CHECK_EQ(mat.indptr->dtype.bits, sizeof(dgl_id_t) * CHAR_BIT);
const dgl_id_t *indices = static_cast<dgl_id_t*>(mat.indices->data); CHECK_EQ(mat.indices->dtype.bits, sizeof(dgl_id_t) * CHAR_BIT);
const dgl_id_t *indptr = static_cast<dgl_id_t *>(mat.indptr->data);
const dgl_id_t *indices = static_cast<dgl_id_t *>(mat.indices->data);
gk_csr_t *gk_csr = gk_csr_Create(); gk_csr_t *gk_csr = gk_csr_Create();
gk_csr->nrows = mat.num_rows; gk_csr->nrows = mat.num_rows;
......
...@@ -498,13 +498,13 @@ def test_laplacian_lambda_max(): ...@@ -498,13 +498,13 @@ def test_laplacian_lambda_max():
assert l_max < 2 + eps assert l_max < 2 + eps
''' '''
def create_large_graph(num_nodes): def create_large_graph(num_nodes, idtype=F.int64):
row = np.random.choice(num_nodes, num_nodes * 10) row = np.random.choice(num_nodes, num_nodes * 10)
col = np.random.choice(num_nodes, num_nodes * 10) col = np.random.choice(num_nodes, num_nodes * 10)
spm = spsp.coo_matrix((np.ones(len(row)), (row, col))) spm = spsp.coo_matrix((np.ones(len(row)), (row, col)))
spm.sum_duplicates() spm.sum_duplicates()
return dgl.from_scipy(spm) return dgl.from_scipy(spm, idtype=idtype)
def get_nodeflow(g, node_ids, num_layers): def get_nodeflow(g, node_ids, num_layers):
batch_size = len(node_ids) batch_size = len(node_ids)
...@@ -530,14 +530,22 @@ def test_partition_with_halo(): ...@@ -530,14 +530,22 @@ def test_partition_with_halo():
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet') @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(F._default_context_str == 'gpu', reason="METIS doesn't support GPU") @unittest.skipIf(F._default_context_str == 'gpu', reason="METIS doesn't support GPU")
def test_metis_partition(): @parametrize_dtype
def test_metis_partition(idtype):
# TODO(zhengda) Metis fails to partition a small graph. # TODO(zhengda) Metis fails to partition a small graph.
g = create_large_graph(1000) g = create_large_graph(1000, idtype=idtype)
check_metis_partition(g, 0) if idtype == F.int64:
check_metis_partition(g, 1) check_metis_partition(g, 0)
check_metis_partition(g, 2) check_metis_partition(g, 1)
check_metis_partition_with_constraint(g) check_metis_partition(g, 2)
check_metis_partition_with_constraint(g)
else:
assert_fail = False
try:
check_metis_partition(g, 1)
except:
assert_fail = True
assert assert_fail
def check_metis_partition_with_constraint(g): def check_metis_partition_with_constraint(g):
ntypes = np.zeros((g.number_of_nodes(),), dtype=np.int32) ntypes = np.zeros((g.number_of_nodes(),), dtype=np.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment