Unverified commit ccf3bea4, authored by Guolin Ke, committed by GitHub

Support old GPUs like P100 (#7)

parent af4f9088
@@ -7,8 +7,17 @@ import numbers
 from torch.nn.parameter import Parameter
 from torch.nn import init
 from torch.nn import functional as F
-import unicore_fused_layernorm
-import unicore_fused_layernorm_backward_gamma_beta
+try:
+    import unicore_fused_layernorm
+    import unicore_fused_layernorm_backward_gamma_beta
+    HAS_LAYER_NORM = True
+except ImportError:
+    print("fused_layer_norm is not installed correctly")
+    HAS_LAYER_NORM = False
+
+if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7:
+    HAS_LAYER_NORM = False
 
 class FusedLayerNormFastFunction(torch.autograd.Function):
     @staticmethod
@@ -54,7 +63,7 @@ class LayerNorm(torch.nn.Module):
             init.zeros_(self.bias)
 
     def forward(self, input):
-        if not input.is_cuda:
+        if not input.is_cuda or not HAS_LAYER_NORM:
             return F.layer_norm(
                 input, self.normalized_shape, self.weight, self.bias, self.eps)
         return FusedLayerNormFastFunction.apply(
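Note on the guard above: the fused layer-norm kernels require CUDA compute capability 7.0 or newer (Volta), so a P100 (sm_60) now takes the stock PyTorch path. A minimal self-contained sketch of the resulting dispatch, assuming the same flag name as the patch; the tensor and shapes are illustrative only:

import torch
import torch.nn.functional as F

try:
    import unicore_fused_layernorm
    HAS_LAYER_NORM = True
except ImportError:
    HAS_LAYER_NORM = False
# Downgrade the flag on machines without a Volta-or-newer GPU (P100 is sm_60).
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7:
    HAS_LAYER_NORM = False

x = torch.randn(2, 8)
weight, bias = torch.ones(8), torch.zeros(8)
# Same branch as the patched forward(): CPU tensors, pre-Volta GPUs, and a
# missing extension all fall back to PyTorch's built-in layer norm.
if not x.is_cuda or not HAS_LAYER_NORM:
    y = F.layer_norm(x, (8,), weight, bias, 1e-5)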
@@ -3,9 +3,17 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-import unicore_fused_softmax_dropout
 import torch.nn.functional as F
 
+try:
+    import unicore_fused_softmax_dropout
+    HAS_SOFTMAX = True
+except ImportError:
+    print("fused_softmax is not installed correctly")
+    HAS_SOFTMAX = False
+
+if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7:
+    HAS_SOFTMAX = False
 
 class SoftmaxDropoutFast(torch.autograd.Function):
     @staticmethod
@@ -94,7 +102,7 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None,
         torch.Tensor: the result after softmax
     """
     input = input.contiguous()
-    if input.is_cuda:
+    if input.is_cuda and HAS_SOFTMAX:
         input_size = input.size()
         if mask is not None:
             _check_mask(mask, input)
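For reference, when HAS_SOFTMAX is False the call degenerates to plain PyTorch ops. A hedged sketch of that equivalent; the additive handling of mask and bias is an assumption based on the function's signature, not code shown in this diff:

import torch
import torch.nn.functional as F

def softmax_dropout_fallback(x, dropout_prob, is_training=True, mask=None, bias=None):
    # Assumed plain-PyTorch equivalent: add the optional mask and bias,
    # softmax over the last dimension, then apply dropout.
    if mask is not None:
        x = x + mask
    if bias is not None:
        x = x + bias
    return F.dropout(F.softmax(x, dim=-1), p=dropout_prob, training=is_training)

out = softmax_dropout_fallback(torch.randn(2, 4, 4), 0.1)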
@@ -34,6 +34,7 @@ class UnicoreAdam(UnicoreOptimizer):
             not getattr(args, "use_old_adam", False)
             and fused_adam_cls is not None
             and torch.cuda.is_available()
+            and torch.cuda.get_device_capability()[0] >= 7
         )
         if use_fused_adam:
             logger.info("using FusedAdam")
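The `>= 7` threshold is the CUDA compute capability major version. A quick way to see what the running device reports; the example values in the comment are known hardware facts, not output from this repository:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    # P100 -> (6, 0), V100 -> (7, 0), A100 -> (8, 0); FusedAdam is now
    # enabled only when major >= 7.
    print(f"compute capability sm_{major}{minor}")
else:
    print("no CUDA device; FusedAdam stays disabled")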
@@ -12,23 +12,26 @@
 import sys
 import warnings
 from functools import partial
 from typing import List, Callable, Any, Dict
 
+import torch
+import torch.nn.functional as F
+
 try:
     import unicore_fused_multi_tensor
     HAS_MULTI_TENSOR = True
 except:
-    print("please install latest fused_ops to get multi_tensor.")
+    print("fused_multi_tensor is not installed correctly")
     HAS_MULTI_TENSOR = False
 
 try:
     import unicore_fused_rounding
     HAS_FUSED_ROUNDING = True
 except:
-    print("please install latest fused_ops to get fused_rounding.")
+    print("fused_rounding is not installed correctly")
     HAS_FUSED_ROUNDING = False
 
-import torch
-import torch.nn.functional as F
+if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7:
+    HAS_MULTI_TENSOR = False
+    HAS_FUSED_ROUNDING = False
 
 logger = logging.getLogger(__name__)
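All four files now share one guard template: import the extension inside try/except, then downgrade the flag on unsupported hardware. A condensed, reusable form of that template, with `some_fused_extension` as a hypothetical placeholder name:

import torch

try:
    import some_fused_extension  # placeholder for a unicore_fused_* module
    HAS_FUSED = True
except ImportError:
    HAS_FUSED = False

# The fused kernels target Volta (sm_70) and newer; disable on anything
# older, including the P100, and on CPU-only machines.
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7:
    HAS_FUSED = False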