util.py 970 Bytes
Newer Older
aiss's avatar
aiss committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from deepspeed.git_version_info import torch_info


def required_torch_version():
    """Return True when the installed PyTorch is version 1.8 or newer.

    Only the first two dot-separated components of ``torch.__version__``
    (major and minor) are considered; patch level and any local build
    suffix are ignored.

    Returns:
        bool: True if ``(major, minor) >= (1, 8)``, otherwise False.
    """
    major, minor = (int(part) for part in torch.__version__.split('.')[:2])

    # Compare as a tuple: checking major and minor independently would
    # reject newer major releases with a small minor number (e.g. 2.0).
    return (major, minor) >= (1, 8)


def bf16_required_version_check():
    """Return True when the software stack meets the bf16 requirements.

    All of the following must hold:

    * PyTorch (from ``torch.__version__``) is 1.10 or newer,
    * the CUDA major version recorded in ``torch_info['cuda_version']``
      is 11 or newer,
    * NCCL (from ``torch.cuda.nccl.version()``) is 2.10 or newer,
    * ``torch.cuda.is_bf16_supported()`` reports support.

    Returns:
        bool: True if every requirement is satisfied. False otherwise,
        including when the NCCL version is not reported as a tuple.
    """
    torch_major, torch_minor = (int(part)
                                for part in torch.__version__.split('.')[:2])

    # Query NCCL once; a non-tuple result cannot be compared component-wise,
    # so it is treated as "requirement not met".
    nccl_version = torch.cuda.nccl.version()
    if not isinstance(nccl_version, tuple):
        return False
    nccl_major, nccl_minor = nccl_version[0], nccl_version[1]

    cuda_major = int(torch_info['cuda_version'].split('.')[0])

    # Keep the torch version test first: the short-circuit ensures
    # torch.cuda.is_bf16_supported() is only called once torch >= 1.10
    # has been confirmed.
    return bool((torch_major, torch_minor) >= (1, 10)
                and cuda_major >= 11
                and (nccl_major, nccl_minor) >= (2, 10)
                and torch.cuda.is_bf16_supported())