Commit 68364b49 authored by Hubert Lu

Conditionally define autocast_dtypes for different torch versions

parent 67ded2e2
@@ -75,8 +75,11 @@ def _prep_inputs(batch_size, normalized_shape, dtype):
     native = fused.clone().to(dtype).requires_grad_(True)
     return native, fused
-autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
+TORCH_MAJOR, TORCH_MINOR = int(torch.__version__.split('.')[0]), int(torch.__version__.split('.')[1])
+if (TORCH_MAJOR <= 1 and TORCH_MINOR < 10):
+    autocast_dtypes = (torch.half,)
+else:
+    autocast_dtypes = (torch.half, torch.bfloat16) if torch.cuda.is_bf16_supported() else (torch.half,)
 class TestAutocastFusedLayerNorm(unittest.TestCase):
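The version gate exists because bfloat16 autocast support, and (as assumed here) `torch.cuda.is_bf16_supported()` itself, only became available around PyTorch 1.10, so referencing it unconditionally would break the test module on older builds. Below is a minimal standalone sketch of the same gate expressed as a feature probe instead of a version compare; the `getattr` fallback and the extra `torch.cuda.is_available()` check are illustrative assumptions, not part of this commit.

```python
import torch

# Probe for torch.cuda.is_bf16_supported(), assumed to be missing on
# torch < 1.10; fall back to fp16-only autocast dtypes when it is absent
# or when no CUDA device is available.
_bf16_probe = getattr(torch.cuda, "is_bf16_supported", None)

if callable(_bf16_probe) and torch.cuda.is_available() and _bf16_probe():
    autocast_dtypes = (torch.half, torch.bfloat16)
else:
    autocast_dtypes = (torch.half,)

print(autocast_dtypes)
```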