Unverified commit 766e3b74, authored by Tim Moon and committed by GitHub

[PyTorch] Use FP16 tols for distributed tests with TF32 compute (#1831)



* Use FP16 tols for tests with TF32
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Use uniform init instead of constant init
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Revert constant init test, but reduce value
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 3a298e6b
@@ -47,11 +47,6 @@ if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
     )
 
-# Disable TF32
-torch.backends.cuda.matmul.allow_tf32 = False
-torch.backends.cudnn.allow_tf32 = False
 
 # Quantization recipe setup
 def quantization_recipe() -> Recipe:
     if QUANTIZATION == "fp8":
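
The hunk above stops forcing TF32 off, so FP32 tests now run with TF32 matmuls on GPUs that support them. A minimal sketch of what those two flags control, assuming a CUDA device is available (tensor names and sizes are illustrative, not from the test suite):

```python
import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

# Full-precision FP32 matmul as the reference.
torch.backends.cuda.matmul.allow_tf32 = False
exact = a @ b

# TF32 matmul: FP32 range (8 exponent bits) but only a 10-bit mantissa.
torch.backends.cuda.matmul.allow_tf32 = True
approx = a @ b

rel_err = ((approx - exact).abs() / exact.abs().clamp_min(1e-6)).max()
print(f"max relative error under TF32: {rel_err:.1e}")  # on the order of 1e-3
```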
@@ -166,7 +161,7 @@ def _gather(tensor, dim=0):
 def _constant(tensor):
-    return nn.init.constant_(tensor, 0.5)
+    return nn.init.constant_(tensor, 0.05)
 
 
 def dist_print(msg, src=None, end="\n", error=False):
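
Shrinking the constant from 0.5 to 0.05 presumably keeps activations from growing geometrically: with a constant-initialized linear layer, every output element is a sum of identical products, so magnitudes scale with the layer width and compound across layers. A hedged sketch of that blow-up, using a hypothetical width of 128 (the real test dimensions are not shown in this diff):

```python
import torch
import torch.nn as nn

width = 128  # hypothetical; the actual test sizes are not in this diff
x = torch.full((1, width), 1.0)

for value in (0.5, 0.05):
    linear = nn.Linear(width, width, bias=False)
    nn.init.constant_(linear.weight, value)
    y = x
    for _ in range(4):  # a few stacked layers
        y = linear(y)
    print(f"init={value}: |activation| after 4 layers = {y.abs().max():.1e}")
# Each layer multiplies magnitudes by roughly value * width, so 0.5 explodes
# to ~1.7e7 while 0.05 stays orders of magnitude smaller (~1.7e3).
```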
@@ -189,7 +184,8 @@ def _get_tolerances(dtype):
     if dtype == torch.bfloat16:
         return {"rtol": 1.6e-2, "atol": 1e-5}
     if dtype == torch.float32:
-        return {"rtol": 1.2e-4, "atol": 1e-4}
+        # TF32 has same mantissa bits as FP16
+        return {"rtol": 1e-3, "atol": 1e-5}
     raise ValueError(f"Unsupported dtype ({dtype})")
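
The comment in the hunk is the whole rationale: TF32 keeps FP32's 8 exponent bits but truncates the mantissa to 10 bits, exactly FP16's mantissa width, so an FP32 result computed through TF32 matmuls can only be trusted to FP16-level relative error. A quick back-of-the-envelope check (a sketch, not part of the test suite):

```python
# Unit roundoff is 2^-(mantissa_bits + 1); TF32 and FP16 share a 10-bit mantissa.
for name, mantissa_bits in [("fp32", 23), ("tf32", 10), ("fp16", 10)]:
    print(f"{name}: unit roundoff ~ {2.0 ** -(mantissa_bits + 1):.1e}")
# fp32: ~6.0e-08, tf32: ~4.9e-04, fp16: ~4.9e-04
# Hence rtol=1e-3 (a couple of TF32 ulps) replaces the old FP32-level 1.2e-4.
```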
@@ -56,7 +56,7 @@ def test_distributed(quantization):
     if quantization == "fp8" and not fp8_available:
         pytest.skip(reason_for_no_fp8)
     if quantization == "fp8_cs" and not fp8_available:
-        pytest.skip(fp8_available)
+        pytest.skip(reason_for_no_fp8)
     if quantization == "mxfp8" and not mxfp8_available:
         pytest.skip(reason_for_no_mxfp8)
     if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
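
The bug fixed here was passing the boolean fp8_available to pytest.skip, which expects the reason string; the fix reuses reason_for_no_fp8. A sketch of the gating pattern, assuming the FP8GlobalStateManager.is_fp8_available() helper that Transformer Engine's other PyTorch tests use to produce the (flag, reason) pair:

```python
import pytest
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

# Returns (bool, str): whether FP8 is supported and, if not, why.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

def test_fp8_feature():
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)  # a readable reason, not a bare bool
    ...  # FP8-dependent assertions go here
```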