Unverified Commit ccf3bea4 authored by Guolin Ke, committed by GitHub

support old GPUs like P100 (#7)

parent af4f9088
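Every file touched by this commit applies the same guard: import the fused CUDA extension inside try/except, then force the feature flag off when no GPU is present or the device's compute capability is below 7.0 (Volta). The P100 from the title is compute capability 6.0, so it always takes the plain PyTorch path. A minimal sketch of the pattern, with `unicore_fused_example` as a placeholder extension name:

```python
import torch

try:
    import unicore_fused_example  # placeholder for any of the fused extensions
    HAS_FUSED = True
except ImportError:
    HAS_FUSED = False

# The fused kernels require compute capability >= 7.0 (Volta or newer);
# pre-Volta cards such as the P100 (6.0) fall back to unfused PyTorch ops.
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7:
    HAS_FUSED = False
```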
@@ -7,8 +7,17 @@ import numbers
 from torch.nn.parameter import Parameter
 from torch.nn import init
 from torch.nn import functional as F
-import unicore_fused_layernorm
-import unicore_fused_layernorm_backward_gamma_beta
+try:
+    import unicore_fused_layernorm
+    import unicore_fused_layernorm_backward_gamma_beta
+    HAS_LAYER_NORM = True
+except ImportError:
+    print("fused_layer_norm is not installed correctly")
+    HAS_LAYER_NORM = False
+
+if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7:
+    HAS_LAYER_NORM = False
 
 class FusedLayerNormFastFunction(torch.autograd.Function):
     @staticmethod
@@ -54,7 +63,7 @@ class LayerNorm(torch.nn.Module):
         init.zeros_(self.bias)
 
     def forward(self, input):
-        if not input.is_cuda:
+        if not input.is_cuda or not HAS_LAYER_NORM:
             return F.layer_norm(
                 input, self.normalized_shape, self.weight, self.bias, self.eps)
         return FusedLayerNormFastFunction.apply(
......
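A quick way to sanity-check the fallback branch above: on a CPU tensor (or whenever HAS_LAYER_NORM is false) the module must agree with F.layer_norm. The import path and constructor signature below are assumptions about the package layout:

```python
import torch
import torch.nn.functional as F
from unicore.modules import LayerNorm  # import path is an assumption

ln = LayerNorm(64)                     # assuming it mirrors torch.nn.LayerNorm
x = torch.randn(2, 16, 64)             # CPU tensor, so forward takes the unfused branch
ref = F.layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps)
assert torch.allclose(ln(x), ref)
```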
@@ -3,9 +3,17 @@
 # LICENSE file in the root directory of this source tree.
 import torch
-import unicore_fused_softmax_dropout
 import torch.nn.functional as F
+
+try:
+    import unicore_fused_softmax_dropout
+    HAS_SOFTMAX = True
+except ImportError:
+    print("fused_softmax is not installed correctly")
+    HAS_SOFTMAX = False
+
+if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7:
+    HAS_SOFTMAX = False
 
 class SoftmaxDropoutFast(torch.autograd.Function):
     @staticmethod
@@ -94,7 +102,7 @@ def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None,
         torch.Tensor: the result after softmax
     """
     input = input.contiguous()
-    if input.is_cuda:
+    if input.is_cuda and HAS_SOFTMAX:
         input_size = input.size()
         if mask is not None:
             _check_mask(mask, input)
......
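When HAS_SOFTMAX is false, softmax_dropout falls through to plain PyTorch ops. An unfused reference of what the fused kernel computes might look like the sketch below; the signature follows the visible parameters of the docstring above, and the file's actual fallback code may differ:

```python
import torch
import torch.nn.functional as F

def softmax_dropout_reference(input, dropout_prob, is_training=True, mask=None, bias=None):
    # Unfused reference: optional additive mask/bias, softmax over the last
    # dim, then dropout. The fused kernel performs all of this in one pass.
    if mask is not None:
        input = input + mask
    if bias is not None:
        input = input + bias
    return F.dropout(F.softmax(input, dim=-1), p=dropout_prob, training=is_training)
```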
@@ -34,6 +34,7 @@ class UnicoreAdam(UnicoreOptimizer):
             not getattr(args, "use_old_adam", False)
             and fused_adam_cls is not None
             and torch.cuda.is_available()
+            and torch.cuda.get_device_capability()[0] >= 7
         )
         if use_fused_adam:
             logger.info("using FusedAdam")
......
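To see which branch a given machine takes, query the reported compute capability directly:

```python
import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    # P100 reports 6.0, V100 7.0, A100 8.0; FusedAdam is selected only for >= 7
    print(f"compute capability: {major}.{minor}")
```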
@@ -12,23 +12,26 @@ import sys
 import warnings
 from functools import partial
 from typing import List, Callable, Any, Dict
+import torch
+import torch.nn.functional as F
+
 try:
     import unicore_fused_multi_tensor
     HAS_MULTI_TENSOR = True
-except:
-    print("please install latest fused_ops to get multi_tensor.")
+except ImportError:
+    print("fused_multi_tensor is not installed correctly")
     HAS_MULTI_TENSOR = False
+
 try:
     import unicore_fused_rounding
     HAS_FUSED_ROUNDING = True
-except:
-    print("please install latest fused_ops to get fused_rounding.")
+except ImportError:
+    print("fused_rounding is not installed correctly")
     HAS_FUSED_ROUNDING = False
-import torch
-import torch.nn.functional as F
+
+if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7:
+    HAS_MULTI_TENSOR = False
+    HAS_FUSED_ROUNDING = False
 
 logger = logging.getLogger(__name__)
......
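For context on what the multi-tensor path accelerates: reductions over many tensors at once, such as a global gradient norm. A conceptual unfused equivalent is one kernel launch per tensor, which is what makes the fused version worth guarding for; this is a sketch of the idea, not the file's actual fallback code:

```python
import torch

def global_l2_norm_unfused(tensors):
    # One norm per tensor, then a norm over the per-tensor results.
    # unicore_fused_multi_tensor batches these into a few kernel launches.
    return torch.norm(torch.stack([torch.norm(t) for t in tensors]))
```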