"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5ce4814af1de6d2dc2cc67a46d3862ce62261e2b"
Unverified Commit ce378327 authored by Xin Yao, committed by GitHub

fix bf16 tests (#5089)

parent 37bd0925
@@ -9,10 +9,14 @@ consumption. This feature requires DGL 0.9+.

 Message-Passing with Half Precision
 -----------------------------------

-DGL allows message-passing on ``float16 (fp16)`` / ``bfloat16 (bf16)`` (requires CUDA >= 11.0)
+DGL allows message-passing on ``float16 (fp16)`` / ``bfloat16 (bf16)``
 features for both UDFs (User Defined Functions) and built-in functions
 (e.g., ``dgl.function.sum``, ``dgl.function.copy_u``).

+.. note::
+    Please check bfloat16 support via ``torch.cuda.is_bf16_supported()`` before using it.
+    Typically it requires CUDA >= 11.0 and GPU compute capability >= 8.0.
+
 The following example shows how to use DGL's message-passing APIs on half-precision
 features:
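The doc change above replaces a hard "requires CUDA >= 11.0" claim with a runtime check. A minimal sketch of that check in practice; the feature shape and the fp16 fallback are my own illustration, not part of the patch:

```python
import torch

# Probe bf16 support at runtime, as the new note advises. The call returns
# True only when both the CUDA build and the GPU (typically compute
# capability >= 8.0) can execute bfloat16 kernels.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16  # illustrative fallback

feat = torch.rand(100, 32, device="cuda")  # fp32 node features (illustrative)
feat_half = feat.to(dtype)  # cast once, then run message passing in half precision
```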
@@ -2,16 +2,15 @@ import random
 import unittest

 import backend as F
+import dgl
 import numpy as np
 import pytest
 import torch
+from dgl.ops import edge_softmax, gsddmm, gspmm, segment_reduce
 from test_utils import parametrize_idtype
 from test_utils.graph_cases import get_cases
-import dgl
-from dgl.ops import edge_softmax, gsddmm, gspmm, segment_reduce
-from dgl.utils import version

 random.seed(42)
 np.random.seed(42)
@@ -177,30 +176,32 @@ def test_spmm(idtype, g, shp, msg, reducer):
 @unittest.skipIf(
     dgl.backend.backend_name != "pytorch",
-    reason="Only support PyTorch for now."
+    reason="Only support PyTorch for now.",
 )
 @unittest.skipIf(
     F._default_context_str == "cpu",
-    reason="Don't support half precision on CPU."
+    reason="Don't support half precision on CPU.",
 )
 @parametrize_idtype
 @pytest.mark.parametrize(
     "dtype, rtol, atol",
-    [(torch.float16, 1e-3, 0.5), (torch.bfloat16, 4e-3, 2.)]
+    [(torch.float16, 1e-3, 0.5), (torch.bfloat16, 4e-3, 2.0)],
 )
 def test_half_spmm(idtype, dtype, rtol, atol):
-    if version.parse(torch.version.cuda) < version.parse("11.0") \
-        and dtype == torch.bfloat16:
-        pytest.skip("BF16 requires CUDA >= 11.0.")
+    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+        pytest.skip("BF16 is not supported.")
     # make sure the spmm result is < 512 to match the rtol/atol we set.
-    g = dgl.graph((torch.arange(900), torch.tensor([0] * 900)),
-                  idtype=idtype, device=F.ctx())
+    g = dgl.graph(
+        (torch.arange(900), torch.tensor([0] * 900)),
+        idtype=idtype,
+        device=F.ctx(),
+    )
     feat_fp32 = torch.rand((g.num_src_nodes(), 32)).to(0)
     feat_half = feat_fp32.to(dtype)
     # test SpMMCSR
-    g = g.formats(['csc'])
+    g = g.formats(["csc"])
     res_fp32 = dgl.ops.copy_u_sum(g, feat_fp32)[0]
     res_half = dgl.ops.copy_u_sum(g, feat_half)[0].float()
     assert torch.allclose(res_fp32, res_half, rtol=rtol, atol=atol)
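One way to read the tolerances in this test (my gloss, not stated in the patch): 900 uniform [0, 1) values sum to roughly 450, and bfloat16 stores only 7 explicit mantissa bits, so adjacent representable values in [256, 512) sit 2 apart. The bf16 `atol=2.0` therefore budgets about one ULP at the expected magnitude, which is why the comment insists the result stay below 512.

```python
import torch

# bf16 machine epsilon is 2**-7; the grid spacing in [256, 512) is
# 256 * eps == 2, so values between grid points round to a neighbor.
assert torch.finfo(torch.bfloat16).eps == 2 ** -7
print(torch.tensor(450.5).to(torch.bfloat16).item())  # 450.0
```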
@@ -364,20 +365,25 @@ def test_segment_reduce(reducer):
 @pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
 @pytest.mark.parametrize(
     "dtype, tol",
-    [(torch.float16, 1e-2), (torch.bfloat16, 1e-2),
-     (torch.float32, 3e-3), (torch.float64, 1e-4)],
+    [
+        (torch.float16, 1e-2),
+        (torch.bfloat16, 1e-2),
+        (torch.float32, 3e-3),
+        (torch.float64, 1e-4),
+    ],
 )
 def test_segment_mm(idtype, feat_size, dtype, tol):
-    if F._default_context_str == "cpu" and dtype in (torch.float16, torch.bfloat16):
-        pytest.skip(
-            "Only support float32 and float64 on CPU."
-        )
-    if F._default_context_str == "gpu" \
-        and version.parse(torch.version.cuda) < version.parse("11.0") \
-        and dtype == torch.bfloat16:
-        pytest.skip(
-            "BF16 requires CUDA >= 11.0."
-        )
+    if F._default_context_str == "cpu" and dtype in (
+        torch.float16,
+        torch.bfloat16,
+    ):
+        pytest.skip("Only support float32 and float64 on CPU.")
+    if (
+        F._default_context_str == "gpu"
+        and dtype == torch.bfloat16
+        and not torch.cuda.is_bf16_supported()
+    ):
+        pytest.skip("BF16 is not supported.")
     dev = F.ctx()
     # input
     a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
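For context, a sketch of the op under test, assuming the `dgl.ops.segment_mm(a, b, seglen_a)` form in which consecutive row segments of `a` are multiplied by the matching matrix in `b`; the equal segment lengths and shapes below are my simplification, not the test's exact setup:

```python
import dgl
import torch

a = torch.rand(100, 16, device="cuda")     # (N, D1) row features
b = torch.rand(10, 16, 17, device="cuda")  # (K, D1, D2) per-segment weights
seglen = torch.full((10,), 10)             # K segment lengths, summing to N
out = dgl.ops.segment_mm(a, b, seglen)     # (N, D2)

# Dense reference for equal-length segments: batch the rows and use bmm.
ref = torch.bmm(a.view(10, 10, 16), b).view(100, 17)
assert torch.allclose(out, ref, rtol=1e-3, atol=1e-3)
```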
@@ -419,22 +425,35 @@ def test_segment_mm(idtype, feat_size, dtype, tol):
 @pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
 @pytest.mark.parametrize(
     "dtype, tol",
-    [(torch.float16, 1e-2), (torch.bfloat16, 2e-2),
-     (torch.float32, 3e-3), (torch.float64, 1e-4)]
+    [
+        (torch.float16, 1e-2),
+        (torch.bfloat16, 2e-2),
+        (torch.float32, 3e-3),
+        (torch.float64, 1e-4),
+    ],
 )
 def test_gather_mm_idx_b(feat_size, dtype, tol):
-    if F._default_context_str == "cpu" and dtype in (torch.float16, torch.bfloat16):
+    if F._default_context_str == "cpu" and dtype in (
+        torch.float16,
+        torch.bfloat16,
+    ):
         pytest.skip("Only support float32 and float64 on CPU.")
-    if F._default_context_str == "gpu" \
-        and version.parse(torch.version.cuda) < version.parse("11.0") \
-        and dtype == torch.bfloat16:
-        pytest.skip("BF16 requires CUDA >= 11.0.")
+    if (
+        F._default_context_str == "gpu"
+        and dtype == torch.bfloat16
+        and not torch.cuda.is_bf16_supported()
+    ):
+        pytest.skip("BF16 is not supported.")
     dev = F.ctx()
     # input
     a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
     a.requires_grad_()
-    b = torch.tensor(np.random.rand(10, feat_size, feat_size + 1)).to(dev).to(dtype)
+    b = (
+        torch.tensor(np.random.rand(10, feat_size, feat_size + 1))
+        .to(dev)
+        .to(dtype)
+    )
     b.requires_grad_()
     idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev).long()
     dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
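Likewise for the gather path, assuming the `dgl.ops.gather_mm(a, b, idx_b=...)` form in which output row `i` is `a[i] @ b[idx_b[i]]`; the dense `bmm` reference is my own sanity check, not taken from the test:

```python
import dgl
import torch

a = torch.rand(100, 16, device="cuda")             # (N, D1)
b = torch.rand(10, 16, 17, device="cuda")          # (K, D1, D2) candidate matrices
idx = torch.randint(0, 10, (100,), device="cuda")  # matrix index per row
out = dgl.ops.gather_mm(a, b, idx_b=idx)           # (N, D2)

# Dense reference: gather each row's matrix, then row-wise matmul via bmm.
ref = torch.bmm(a.unsqueeze(1), b[idx]).squeeze(1)
assert torch.allclose(out, ref, rtol=1e-3, atol=1e-3)
```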