Unverified Commit eabcc58e authored by ndickson-nvidia, committed by GitHub

[Bug][Feature] Added cublasGemm<__half> specialization (#3988) (#4029)

* Added a specialization of the cublasGemm function for the `__half` type, to address https://github.com/dmlc/dgl/issues/3988
* Added USE_FP16 guard
* Added test cases to test_segment_mm to exercise the newly added FP16 specialization of cublasGemm
* Replaced the for loop in test_segment_mm with pytest.mark.parametrize, as recommended

Co-authored-by: Xin Yao <xiny@nvidia.com>
parent 85c2ff71
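For context, a minimal sketch of the call path this commit addresses (a hypothetical reproduction of dmlc/dgl#3988, assuming a CUDA build of DGL, an available GPU, and shapes mirroring the test below; not part of the commit itself):

```python
import torch
import dgl

# Hypothetical reproduction of dmlc/dgl#3988 (assumed setup: CUDA build of DGL, GPU available).
# Before this commit, float16 inputs fell through to the unspecialized cublasGemm template,
# which returns CUBLAS_STATUS_EXECUTION_FAILED; with the __half specialization the call
# routes to cublasHgemm instead.
dev = torch.device('cuda')
a = torch.rand(100, 16, dtype=torch.float16, device=dev)      # (N, D1) input features
b = torch.rand(10, 16, 17, dtype=torch.float16, device=dev)   # (num_segments, D1, D2) per-segment weights
seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0])  # rows of `a` per segment, sums to 100
c = dgl.ops.segment_mm(a, b, seglen_a)                        # (N, D2), computed segment by segment
print(c.shape, c.dtype)                                        # torch.Size([100, 17]) torch.float16
```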
@@ -26,6 +26,18 @@ cublasStatus_t cublasGemm(cublasHandle_t handle, cublasOperation_t transa,
   return CUBLAS_STATUS_EXECUTION_FAILED;
 }
 
+#ifdef USE_FP16
+template <>
+cublasStatus_t cublasGemm<__half>(cublasHandle_t handle, cublasOperation_t transa,
+                                  cublasOperation_t transb, int m, int n, int k,
+                                  const __half* alpha, const __half* A, int lda,
+                                  const __half* B, int ldb, const __half* beta,
+                                  __half* C, int ldc) {
+  return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda,
+                     B, ldb, beta, C, ldc);
+}
+#endif
+
 template <>
 cublasStatus_t cublasGemm<float>(cublasHandle_t handle, cublasOperation_t transa,
                                  cublasOperation_t transb, int m, int n, int k,
@@ -7,6 +7,7 @@ import pytest, unittest
 import networkx as nx
 import backend as F
 import numpy as np
+import torch
 
 random.seed(42)
 np.random.seed(42)
@@ -290,16 +291,16 @@ def test_segment_reduce(reducer):
 @unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
 @parametrize_idtype
 @pytest.mark.parametrize('feat_size', [1, 8, 16, 64, 256])
-def test_segment_mm(idtype, feat_size):
-    import torch
+@pytest.mark.parametrize('dtype,tol', [(torch.float16,1e-2),(torch.float32,3e-3),(torch.float64,1e-4)])
+def test_segment_mm(idtype, feat_size, dtype, tol):
     dev = F.ctx()
     # input
-    a = torch.tensor(np.random.rand(100, feat_size)).to(dev)
+    a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
     a.requires_grad_()
-    b = torch.tensor(np.random.rand(10, feat_size, feat_size + 1)).to(dev)
+    b = torch.tensor(np.random.rand(10, feat_size, feat_size + 1)).to(dev).to(dtype)
     b.requires_grad_()
     seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0])
-    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev)
+    dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
     # compute
     c = dgl.ops.segment_mm(a, b, seglen_a)
     c.backward(dc)
@@ -311,16 +312,16 @@ def test_segment_mm(idtype, feat_size):
     for i, l in enumerate(seglen_a):
         c_t.append(a[off:off+l] @ b[i])
         off += l
-    c_t = torch.cat(c_t)
+    c_t = torch.cat(c_t).to(dtype)
     a.grad.zero_()
     b.grad.zero_()
     c_t.backward(dc)
     da_t = a.grad
     db_t = b.grad
-    assert torch.allclose(c, c_t, atol=1e-4, rtol=1e-4)
-    assert torch.allclose(da, da_t, atol=1e-4, rtol=1e-4)
-    assert torch.allclose(db, db_t, atol=1e-4, rtol=1e-4)
+    assert torch.allclose(c, c_t, atol=tol, rtol=tol)
+    assert torch.allclose(da, da_t, atol=tol, rtol=tol)
+    assert torch.allclose(db, db_t, atol=tol, rtol=tol)
 
 @unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
 @parametrize_idtype