Unverified Commit 79b0a50a authored by Xin Yao, committed by GitHub

[Unittest][Fix] Several unit tests fixes for Ampere+ and PyTorch 1.12+ (#4213)

* Fix test_csrmm for tensor cores

* Unset the allow-TF32 flag

* Update test_unified_tensor

* Skip fp16 on CPU
parent 2efdaa5d
@@ -5,6 +5,10 @@ import dgl
 from test_utils import parametrize_idtype
 import backend as F
+if F.backend_name == 'pytorch':
+    import torch
+    torch.backends.cuda.matmul.allow_tf32 = False
+
 def _random_simple_graph(idtype, dtype, ctx, M, N, max_nnz, srctype, dsttype, etype):
     src = np.random.randint(0, M, (max_nnz,))
     dst = np.random.randint(0, N, (max_nnz,))
......
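Background on the hunk above: on Ampere and newer GPUs, PyTorch can route float32 matmuls through tensor cores in TF32 mode, which keeps only about 10 mantissa bits and so drifts from a true fp32 reference. A minimal standalone sketch of the effect (not part of this commit; assumes a CUDA-capable Ampere+ device):

```python
import torch

a = torch.randn(512, 512, device='cuda')
b = torch.randn(512, 512, device='cuda')

torch.backends.cuda.matmul.allow_tf32 = True   # TF32 path (tensor cores)
tf32_out = a @ b
torch.backends.cuda.matmul.allow_tf32 = False  # full-precision fp32 path
fp32_out = a @ b

# On Ampere+ the difference is typically on the order of 1e-3 relative error,
# enough to trip a tight unit-test tolerance; on pre-Ampere GPUs it is zero.
print((tf32_out - fp32_out).abs().max().item())
```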
@@ -293,6 +293,8 @@ def test_segment_reduce(reducer):
 @pytest.mark.parametrize('feat_size', [1, 8, 16, 64, 256])
 @pytest.mark.parametrize('dtype,tol', [(torch.float16,1e-2),(torch.float32,3e-3),(torch.float64,1e-4)])
 def test_segment_mm(idtype, feat_size, dtype, tol):
+    if F._default_context_str == 'cpu' and dtype == torch.float16:
+        pytest.skip("fp16 support for CPU linalg functions has been removed in PyTorch.")
     dev = F.ctx()
     # input
     a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
......
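The guard above is the usual pytest pattern for dtype-dependent skips: bail out before touching an op that a backend no longer supports. A self-contained sketch of the same pattern (the test name and shapes here are illustrative, not from the DGL suite):

```python
import numpy as np
import pytest
import torch

@pytest.mark.parametrize('dtype', [torch.float16, torch.float32])
def test_matmul_dtype(dtype):
    # Recent PyTorch releases removed fp16 kernels from CPU linalg,
    # so the half-precision case is only meaningful on CUDA devices.
    if not torch.cuda.is_available() and dtype == torch.float16:
        pytest.skip("fp16 linalg ops are unsupported on CPU")
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'
    a = torch.tensor(np.random.rand(4, 8)).to(dev).to(dtype)
    b = torch.tensor(np.random.rand(8, 2)).to(dev).to(dtype)
    assert (a @ b).shape == (4, 2)
```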
@@ -18,23 +18,24 @@ def start_unified_tensor_worker(dev_id, input, seq_idx, rand_idx, output_seq, ou
 def test_unified_tensor():
     test_row_size = 65536
     test_col_size = 128
     rand_test_size = 8192
+    device = th.device('cuda:0')
 
     input = th.rand((test_row_size, test_col_size))
-    input_unified = dgl.contrib.UnifiedTensor(input, device=th.device('cuda'))
+    input_unified = dgl.contrib.UnifiedTensor(input, device=device)
 
     seq_idx = th.arange(0, test_row_size)
+    # CPU indexing
     assert th.all(th.eq(input[seq_idx], input_unified[seq_idx]))
-
-    seq_idx = seq_idx.to(th.device('cuda'))
-    assert th.all(th.eq(input[seq_idx].to(th.device('cuda')), input_unified[seq_idx]))
+    # GPU indexing
+    assert th.all(th.eq(input[seq_idx].to(device), input_unified[seq_idx.to(device)]))
 
     rand_idx = th.randint(0, test_row_size, (rand_test_size,))
+    # CPU indexing
     assert th.all(th.eq(input[rand_idx], input_unified[rand_idx]))
-
-    rand_idx = rand_idx.to(th.device('cuda'))
-    assert th.all(th.eq(input[rand_idx].to(th.device('cuda')), input_unified[rand_idx]))
+    # GPU indexing
+    assert th.all(th.eq(input[rand_idx].to(device), input_unified[rand_idx.to(device)]))
 
 @unittest.skipIf(os.name == 'nt', reason='Do not support windows yet')
 @unittest.skipIf(F.ctx().type == 'cpu', reason='gpu only test')
......
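For context on the test above: dgl.contrib.UnifiedTensor keeps the data in host memory but makes it addressable from the GPU, so it can be indexed with either a CPU or a CUDA index tensor, and the device of the index determines where the gather runs. A minimal usage sketch along the lines of the updated test (requires a CUDA build of DGL; shapes and indices are illustrative):

```python
import torch as th
import dgl

x = th.rand(1024, 16)
ux = dgl.contrib.UnifiedTensor(x, device=th.device('cuda:0'))

idx = th.tensor([0, 5, 42])
cpu_rows = ux[idx]                # CPU index -> rows gathered on the host
gpu_rows = ux[idx.to('cuda:0')]   # CUDA index -> rows gathered on the device

assert th.equal(cpu_rows, x[idx])
assert th.equal(gpu_rows, x[idx].to('cuda:0'))
```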
@@ -22,6 +22,7 @@ export DGL_LIBRARY_PATH=${PWD}/build
 export PYTHONPATH=tests:${PWD}/python:$PYTHONPATH
 export DGL_DOWNLOAD_DIR=${PWD}
 export TF_FORCE_GPU_ALLOW_GROWTH=true
+unset TORCH_ALLOW_TF32_CUBLAS_OVERRIDE
 
 if [ $2 == "gpu" ]
 then
......
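The unset in the CI script matters because PyTorch honors the TORCH_ALLOW_TF32_CUBLAS_OVERRIDE environment variable (exported, for example, by some NVIDIA NGC container images), which forces TF32 on in cuBLAS regardless of the Python-level flag; clearing it lets the allow_tf32 = False setting in the tests take effect. A quick, hedged way to inspect the effective state from Python:

```python
import os
import torch

# If the override variable is exported, the Python flag may be ignored,
# so print both when debugging unexpected TF32 behavior.
print('override  =', os.environ.get('TORCH_ALLOW_TF32_CUBLAS_OVERRIDE'))
print('allow_tf32 =', torch.backends.cuda.matmul.allow_tf32)
```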