Unverified Commit f9305e75 authored by Masaki Kozuki, committed by GitHub

Give Some Extensions Version Guard in Build&Runtime (#1358)

* guard

* update

* remove unnecessary version guard

* runtime version guard

* cosmetic

* skip tests appropriately
parent fed20d2a
import logging
import warnings

# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
...@@ -37,3 +38,15 @@ handler = logging.StreamHandler()
handler.setFormatter(RankInfoFormatter("%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s", "%y-%m-%d %H:%M:%S"))
_library_root_logger.addHandler(handler)
_library_root_logger.propagate = False


def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
    cudnn_available = torch.backends.cudnn.is_available()
    cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
    if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
        warnings.warn(
            f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, "
            f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
        )
        return False
    return True
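This runtime guard returns a boolean instead of raising, so each contrib module can decide how strict to be at import time. A minimal usage sketch, with `my_cudnn_extension` as a hypothetical stand-in name (not part of this commit):

# Hypothetical usage sketch, not part of the diff; `my_cudnn_extension` is a stand-in name.
from apex import check_cudnn_version_and_warn

# Hard guard: fail at import time when cuDNN < 8.4.0 (version 8400), as apex.contrib.bottleneck does below.
assert check_cudnn_version_and_warn("my_cudnn_extension", 8400)

# Warn-only guard: emit the warning but keep importing, as apex.contrib.conv_bias_relu does below.
check_cudnn_version_and_warn("my_cudnn_extension", 8400)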
import functools as func

import torch
import torch.distributed as dist
from torch import nn

from apex import check_cudnn_version_and_warn
import fast_bottleneck
import nccl_p2p_cuda as inc


assert check_cudnn_version_and_warn(__name__, 8400)


def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    weight_tensor_nchw = tensor
    nn.init.kaiming_uniform_(weight_tensor_nchw, a=a, mode=mode, nonlinearity=nonlinearity)
......
import pdb

import torch
from torch.autograd import gradcheck

from apex import check_cudnn_version_and_warn
import fused_conv_bias_relu

check_cudnn_version_and_warn(__name__, 8400)


class ConvBiasReLU_(torch.autograd.Function):
    @staticmethod
......
import copy
import math
import random
import unittest

import torch
import torch.nn.functional as F

HAS_CONV_BIAS_RELU = None
try:
    from apex.contrib.conv_bias_relu import ConvBiasReLU, ConvBias, ConvBiasMaskReLU
except ImportError as e:
    HAS_CONV_BIAS_RELU = False
else:
    HAS_CONV_BIAS_RELU = True


@unittest.skipIf(not HAS_CONV_BIAS_RELU, "`apex.contrib.conv_bias_relu` is not found.")
class FusedDenseTest(unittest.TestCase):
    def setUp(self, seed=0):
        torch.manual_seed(seed)
......
...@@ -25,14 +25,22 @@
#
###############################################################################

import math
import sys
import unittest

import torch
import numpy as np

import fmhalib as mha


def _get_device_properties(device=torch.device("cuda")):
    # type: (str or torch.device) -> Tuple[int, int]
    properties = torch.cuda.get_device_properties(device)
    return properties.major, properties.minor
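For context, an illustration of what this helper returns (not part of the diff): `torch.cuda.get_device_properties` exposes `major` and `minor` fields for the device's compute capability, so `(8, 0)` identifies sm80 hardware such as A100, which is what the skip condition below checks for.

# Illustration only: print the compute-capability tuple compared against (8, 0) below.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(torch.device("cuda"))
    print(f"sm{props.major}{props.minor}")  # e.g. "sm80" on A100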
def py_mha(qkv, amask, b, s, h, d):
    qkv = qkv.view(b, s, h, 3, d)
    q = qkv[:, :, :, 0, :].permute(0,2,1,3)
...@@ -48,6 +56,7 @@ def py_mha(qkv, amask, b, s, h, d):
    return ctx


@unittest.skipIf(not _get_device_properties() == (8, 0), "FMHA only supports sm80")
class TestFMHA(unittest.TestCase):
    def run_test(self, s: int, b: int, zero_tensors: bool):
......
...@@ -12,6 +12,7 @@ except ImportError:
from apex.contrib.focal_loss import focal_loss


@unittest.skipIf(not reference_available, "Reference implementation `torchvision.ops.focal_loss.sigmoid_focal_loss` is not available.")
class FocalLossTest(unittest.TestCase):
    N_SAMPLES = 12
......
...@@ -59,10 +59,15 @@ def append_nvcc_threads(nvcc_extra_args):

def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
    cudnn_available = torch.backends.cudnn.is_available()
    cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
    if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
        warnings.warn(
            f"Skip `{global_option}` as it requires cuDNN {required_cudnn_version} or later, "
            f"but {'cuDNN is not available' if not cudnn_available else cudnn_version}"
        )
        return False
    return True


if not torch.cuda.is_available():
...@@ -639,18 +644,20 @@ if "--transducer" in sys.argv:
        )
    )
# note (mkozuki): Now `--fast_bottleneck` option (i.e. apex/contrib/bottleneck) depends on `--peer_memory` and `--nccl_p2p`.
if "--fast_bottleneck" in sys.argv: if "--fast_bottleneck" in sys.argv:
sys.argv.remove("--fast_bottleneck") sys.argv.remove("--fast_bottleneck")
raise_if_cuda_home_none("--fast_bottleneck") raise_if_cuda_home_none("--fast_bottleneck")
subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"]) if check_cudnn_version_and_warn("--fast_bottleneck", 8400):
ext_modules.append( subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/cudnn-frontend/"])
CUDAExtension( ext_modules.append(
name="fast_bottleneck", CUDAExtension(
sources=["apex/contrib/csrc/bottleneck/bottleneck.cpp"], name="fast_bottleneck",
include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")], sources=["apex/contrib/csrc/bottleneck/bottleneck.cpp"],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag}, include_dirs=[os.path.join(this_dir, "apex/contrib/csrc/cudnn-frontend/include")],
extra_compile_args={"cxx": ["-O3"] + version_dependent_macros + generator_flag},
)
) )
)
if "--peer_memory" in sys.argv: if "--peer_memory" in sys.argv:
sys.argv.remove("--peer_memory") sys.argv.remove("--peer_memory")
......
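As context for the `--fast_bottleneck` guard above: apex contrib extensions are typically requested at build time via pip's `--global-option` flags, e.g. `pip install -v --disable-pip-version-check --no-cache-dir --global-option="--peer_memory" --global-option="--nccl_p2p" --global-option="--fast_bottleneck" ./` (the exact command depends on the apex checkout and pip version). With this change, such a build warns and skips `fast_bottleneck` when cuDNN is older than 8.4.0 instead of failing to compile.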