Unverified Commit 8217d4e3 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[prophetnet] wrong import (#9349)

```
python -c "from apex.normalization import FusedProphetNetLayerNorm"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
ImportError: cannot import name 'FusedProphetNetLayerNorm' from 'apex.normalization' (/home/stas/anaconda3/envs/main-38/lib/python3.8/site-packages/apex/normalization/__init__.py)
```
It looks like this code has never been tested, so it silently fails inside try/except.

Discovered this by accident in https://github.com/huggingface/transformers/issues/9338#issuecomment-752217708
parent 912f6881
...@@ -513,9 +513,9 @@ class ProphetNetDecoderLMOutput(ModelOutput): ...@@ -513,9 +513,9 @@ class ProphetNetDecoderLMOutput(ModelOutput):
def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True): def ProphetNetLayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
if torch.cuda.is_available(): if torch.cuda.is_available():
try: try:
from apex.normalization import FusedProphetNetLayerNorm from apex.normalization import FusedLayerNorm
return FusedProphetNetLayerNorm(normalized_shape, eps, elementwise_affine) return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
except ImportError: except ImportError:
pass pass
return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment