Unverified Commit ae333d04 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

torch.cuda.is_available() is redundant as apex handles that internally (#9350)

parent 8217d4e3
......@@ -110,13 +110,12 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
def BartLayerNorm(normalized_shape: torch.Size, eps: float = 1e-5, elementwise_affine: bool = True):
if torch.cuda.is_available():
try:
from apex.normalization import FusedLayerNorm
try:
from apex.normalization import FusedLayerNorm
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
......
......@@ -265,14 +265,12 @@ FSMT_INPUTS_DOCSTRING = r"""
have_fused_layer_norm = False
if torch.cuda.is_available():
try:
from apex.normalization import FusedLayerNorm
have_fused_layer_norm = True
except ImportError:
pass
try:
from apex.normalization import FusedLayerNorm
have_fused_layer_norm = True
except ImportError:
pass
LayerNorm = FusedLayerNorm if have_fused_layer_norm else torch.nn.LayerNorm
......
......@@ -511,13 +511,12 @@ class ProphetNetDecoderLMOutput(ModelOutput):
def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
if torch.cuda.is_available():
try:
from apex.normalization import FusedLayerNorm
try:
from apex.normalization import FusedLayerNorm
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError:
pass
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment