Commit 68364b49 authored by Hubert Lu's avatar Hubert Lu
Browse files

Conditionally define autocast_dtypes for different torch versions

parent 67ded2e2
...@@ -75,8 +75,11 @@ def _prep_inputs(batch_size, normalized_shape, dtype): ...@@ -75,8 +75,11 @@ def _prep_inputs(batch_size, normalized_shape, dtype):
native = fused.clone().to(dtype).requires_grad_(True) native = fused.clone().to(dtype).requires_grad_(True)
return native, fused return native, fused
TORCH_MAJOR, TORCH_MINOR = int(torch.__version__.split('.')[0]), int(torch.__version__.split('.')[1])
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,) if (TORCH_MAJOR <= 1 and TORCH_MINOR < 10):
autocast_dtypes = (torch.half,)
else:
autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
class TestAutocastFusedLayerNorm(unittest.TestCase): class TestAutocastFusedLayerNorm(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment