"tests/vscode:/vscode.git/clone" did not exist on "6529a5b5c13210b41bcd87c555c72696cd7083a5"
Unverified Commit 7adaefe2 authored by Roohollah Etemadi's avatar Roohollah Etemadi Committed by GitHub
Browse files

support bf16 (#25879)

* added bf16 support

* added cuda availability check

* applied make style, quality
parent af3de8d8
......@@ -305,26 +305,7 @@ def is_torch_bf16_gpu_available():
import torch
# since currently no utility function is available we build our own.
# some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
# with additional check for torch version
# to succeed: (torch is required to be >= 1.10 anyway)
# 1. the hardware needs to support bf16 (GPU arch >= Ampere, or CPU)
# 2. if using gpu, CUDA >= 11
# 3. torch.autocast exists
# XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
# really only correct for the 0th gpu (or currently set default device if different from 0)
if torch.cuda.is_available() and torch.version.cuda is not None:
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
return False
if int(torch.version.cuda.split(".")[0]) < 11:
return False
if not hasattr(torch.cuda.amp, "autocast"):
return False
else:
return False
return True
return torch.cuda.is_available() and torch.cuda.is_bf16_supported()
def is_torch_bf16_cpu_available():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment