Unverified Commit 02b176c4 authored by LSinev, committed by GitHub

Fix torch version comparisons (#18460)

Comparisons like
version.parse(torch.__version__) > version.parse("1.6")
are True for torch==1.6.0+cu101 or torch==1.6.0+cpu.

Comparisons against version.parse(version.parse(torch.__version__).base_version) are preferred instead (helpers for these checks are available in pytorch_utils.py).
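To illustrate the behaviour described above (a minimal sketch, assuming only that the packaging library is installed):

from packaging import version

v = version.parse("1.6.0+cu101")
# The local segment "+cu101" makes the parsed version compare as greater than its base release,
# so a naive check against "1.6" passes even though the base release is exactly 1.6.0.
print(v > version.parse("1.6"))  # True
# Comparing on base_version strips the local segment and gives the intended result.
print(version.parse(v.base_version) > version.parse("1.6"))  # False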
parent be41eaf5
@@ -30,7 +30,7 @@ from transformers import (
 if is_apex_available():
     from apex import amp

-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast
......
@@ -33,7 +33,7 @@ if is_apex_available():
     from apex import amp

-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast
......
@@ -26,7 +26,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
 if is_apex_available():
     from apex import amp

-if version.parse(torch.__version__) >= version.parse("1.6"):
+if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.6"):
     _is_native_amp_available = True
     from torch.cuda.amp import autocast
......
@@ -44,7 +44,7 @@ class GELUActivation(nn.Module):
     def __init__(self, use_gelu_python: bool = False):
         super().__init__()
-        if version.parse(torch.__version__) < version.parse("1.4") or use_gelu_python:
+        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.4") or use_gelu_python:
             self.act = self._gelu_python
         else:
             self.act = nn.functional.gelu
@@ -110,7 +110,7 @@ class SiLUActivation(nn.Module):
     def __init__(self):
         super().__init__()
-        if version.parse(torch.__version__) < version.parse("1.7"):
+        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
             self.act = self._silu_python
         else:
             self.act = nn.functional.silu
@@ -130,7 +130,7 @@ class MishActivation(nn.Module):
     def __init__(self):
         super().__init__()
-        if version.parse(torch.__version__) < version.parse("1.9"):
+        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.9"):
             self.act = self._mish_python
         else:
             self.act = nn.functional.mish
......
@@ -273,6 +273,8 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
     import torch
     from torch.onnx import export

+    from .pytorch_utils import is_torch_less_than_1_11
+
     print(f"Using framework PyTorch: {torch.__version__}")

     with torch.no_grad():
@@ -281,7 +283,7 @@ def convert_pytorch(nlp: Pipeline, opset: int, output: Path, use_external_format
         # PyTorch deprecated the `enable_onnx_checker` and `use_external_data_format` arguments in v1.11,
         # so we check the torch version for backwards compatibility
-        if parse(torch.__version__) <= parse("1.10.99"):
+        if is_torch_less_than_1_11:
             export(
                 nlp.model,
                 model_args,
......
@@ -20,7 +20,6 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Union

 import torch
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -35,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -212,7 +216,7 @@ class AlbertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
......
@@ -24,7 +24,6 @@ from typing import List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -41,7 +40,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -195,7 +199,7 @@ class BertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
......
@@ -23,7 +23,6 @@ from typing import Optional, Tuple, Union

 import numpy as np
 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -38,7 +37,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -260,7 +259,7 @@ class BigBirdEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
......
@@ -22,7 +22,6 @@ from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -36,7 +35,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_convbert import ConvBertConfig
@@ -194,7 +198,7 @@ class ConvBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
......
@@ -19,7 +19,6 @@ from typing import List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -35,7 +34,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -83,7 +87,7 @@ class Data2VecTextForTextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
......
@@ -21,12 +21,16 @@ from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn

 from ...activations import ACT2FN
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
+from ...pytorch_utils import (
+    Conv1D,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_6,
+    prune_conv1d_layer,
+)
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -36,7 +40,7 @@ from ...utils import (
 )

-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
     is_amp_available = True
     from torch.cuda.amp import autocast
 else:
......
@@ -23,7 +23,6 @@ from typing import Dict, List, Optional, Set, Tuple, Union

 import numpy as np
 import torch
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -40,7 +39,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -102,7 +106,7 @@ class Embeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
             )
......
@@ -21,7 +21,6 @@ from typing import List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -37,7 +36,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -165,7 +169,7 @@ class ElectraEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
......
@@ -19,10 +19,10 @@ import random
 from typing import Dict, Optional, Tuple, Union

 import torch
-from packaging import version
 from torch import nn

 from ...modeling_outputs import BaseModelOutput
+from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from ..xlm.modeling_xlm import (
     XLMForMultipleChoice,
@@ -139,7 +139,7 @@ class FlaubertModel(XLMModel):
         super().__init__(config)
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
             )
......
@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn

 from transformers.utils.doc import add_code_sample_docstrings
@@ -30,6 +29,7 @@ from transformers.utils.doc import add_code_sample_docstrings
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -392,7 +392,7 @@ class FlavaTextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
......
@@ -21,7 +21,6 @@ from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -44,7 +43,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward
+from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -118,7 +117,7 @@ class FNetEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long),
......
@@ -22,12 +22,18 @@ from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ...pytorch_utils import (
+    Conv1D,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_6,
+    prune_conv1d_layer,
+)
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
     is_amp_available = True
     from torch.cuda.amp import autocast
 else:
@@ -41,7 +47,6 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
......
@@ -21,12 +21,18 @@ from typing import Any, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

+from ...pytorch_utils import (
+    Conv1D,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_or_equal_than_1_6,
+    prune_conv1d_layer,
+)
-if version.parse(torch.__version__) >= version.parse("1.6"):
+if is_torch_greater_or_equal_than_1_6:
     is_amp_available = True
     from torch.cuda.amp import autocast
 else:
@@ -39,7 +45,6 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_imagegpt import ImageGPTConfig
......
@@ -21,7 +21,6 @@ from typing import Optional

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn

 from ...activations import ACT2FN
@@ -34,6 +33,7 @@ from ...modeling_utils import (
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )
+from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import logging
 from .configuration_mctct import MCTCTConfig
@@ -153,7 +153,7 @@ class MCTCTEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
......
@@ -23,7 +23,6 @@ from typing import List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
-from packaging import version
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

@@ -39,7 +38,12 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import (
+    apply_chunking_to_forward,
+    find_pruneable_heads_and_indices,
+    is_torch_greater_than_1_6,
+    prune_linear_layer,
+)
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -183,7 +187,7 @@ class NezhaEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        if version.parse(torch.__version__) > version.parse("1.6.0"):
+        if is_torch_greater_than_1_6:
             self.register_buffer(
                 "token_type_ids",
                 torch.zeros((1, config.max_position_embeddings), dtype=torch.long),
......