"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f07a16e09bb5b1cf4fa2306bfa4ea791f24fa968"
Unverified Commit 2d88db5a authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Bug] check dtype before convert to gk (#3414)

parent d9472873
......@@ -265,6 +265,7 @@ def metis_partition_assignment(g, k, balance_ntypes=None, balance_edges=False, m
A vector with each element that indicates the partition ID of a vertex.
'''
assert mode in ("k-way", "recursive"), "'mode' can only be 'k-way' or 'recursive'"
assert g.idtype == F.int64, "IdType of graph is required to be int64 for now."
# METIS works only on symmetric graphs.
# The METIS runs on the symmetric graph to generate the node assignment to partitions.
start = time.time()
......
......@@ -23,8 +23,10 @@ namespace dgl {
*/
gk_csr_t *Convert2GKCsr(const aten::CSRMatrix mat, bool is_row) {
// TODO(zhengda) The conversion will be zero-copy in the future.
const dgl_id_t *indptr = static_cast<dgl_id_t*>(mat.indptr->data);
const dgl_id_t *indices = static_cast<dgl_id_t*>(mat.indices->data);
CHECK_EQ(mat.indptr->dtype.bits, sizeof(dgl_id_t) * CHAR_BIT);
CHECK_EQ(mat.indices->dtype.bits, sizeof(dgl_id_t) * CHAR_BIT);
const dgl_id_t *indptr = static_cast<dgl_id_t *>(mat.indptr->data);
const dgl_id_t *indices = static_cast<dgl_id_t *>(mat.indices->data);
gk_csr_t *gk_csr = gk_csr_Create();
gk_csr->nrows = mat.num_rows;
......
......@@ -498,13 +498,13 @@ def test_laplacian_lambda_max():
assert l_max < 2 + eps
'''
def create_large_graph(num_nodes):
def create_large_graph(num_nodes, idtype=F.int64):
row = np.random.choice(num_nodes, num_nodes * 10)
col = np.random.choice(num_nodes, num_nodes * 10)
spm = spsp.coo_matrix((np.ones(len(row)), (row, col)))
spm.sum_duplicates()
return dgl.from_scipy(spm)
return dgl.from_scipy(spm, idtype=idtype)
def get_nodeflow(g, node_ids, num_layers):
batch_size = len(node_ids)
......@@ -530,14 +530,22 @@ def test_partition_with_halo():
@unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
@unittest.skipIf(F._default_context_str == 'gpu', reason="METIS doesn't support GPU")
def test_metis_partition():
@parametrize_dtype
def test_metis_partition(idtype):
# TODO(zhengda) Metis fails to partition a small graph.
g = create_large_graph(1000)
check_metis_partition(g, 0)
check_metis_partition(g, 1)
check_metis_partition(g, 2)
check_metis_partition_with_constraint(g)
g = create_large_graph(1000, idtype=idtype)
if idtype == F.int64:
check_metis_partition(g, 0)
check_metis_partition(g, 1)
check_metis_partition(g, 2)
check_metis_partition_with_constraint(g)
else:
assert_fail = False
try:
check_metis_partition(g, 1)
except:
assert_fail = True
assert assert_fail
def check_metis_partition_with_constraint(g):
ntypes = np.zeros((g.number_of_nodes(),), dtype=np.int32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment