Unverified commit a2a3afbc, authored by Sylvain Gugger and committed by GitHub

PyTorch >= 1.7.0 and TensorFlow >= 2.4.0 (#19016)

parent 9f4acd05
@@ -155,13 +155,13 @@ _deps = [
     "librosa",
     "starlette",
     "tensorflow-cpu>=2.3",
-    "tensorflow>=2.3",
+    "tensorflow>=2.4",
     "tensorflow-text",
     "tf2onnx",
     "timeout-decorator",
     "timm",
     "tokenizers>=0.11.1,!=0.11.3,<0.13",
-    "torch>=1.0,!=1.12.0",
+    "torch>=1.7,!=1.12.0",
     "torchaudio",
     "pyctcdecode>=0.3.0",
     "tqdm>=4.27",
...
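For readers unfamiliar with pip specifier syntax, the new pin `torch>=1.7,!=1.12.0` accepts any release from 1.7.0 onward except 1.12.0 exactly. A minimal sketch (using the `packaging` library, not part of this diff) checking a few versions against the new specifier:

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet(">=1.7,!=1.12.0")  # the new torch pin from setup.py
    assert "1.13.0" in spec      # new enough
    assert "1.12.1" in spec      # the patched 1.12 release is allowed
    assert "1.12.0" not in spec  # explicitly excluded
    assert "1.6.0" not in spec   # below the new floor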
@@ -44,7 +44,7 @@ class GELUActivation(nn.Module):

     def __init__(self, use_gelu_python: bool = False):
         super().__init__()
-        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.4") or use_gelu_python:
+        if use_gelu_python:
             self.act = self._gelu_python
         else:
             self.act = nn.functional.gelu
@@ -108,18 +108,8 @@ class SiLUActivation(nn.Module):
     later.
     """

-    def __init__(self):
-        super().__init__()
-        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
-            self.act = self._silu_python
-        else:
-            self.act = nn.functional.silu
-
-    def _silu_python(self, input: Tensor) -> Tensor:
-        return input * torch.sigmoid(input)
-
     def forward(self, input: Tensor) -> Tensor:
-        return self.act(input)
+        return nn.functional.silu(input)


 class MishActivation(nn.Module):
...
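The `SiLUActivation` fallback existed only for PyTorch < 1.7, where `torch.nn.functional.silu` was missing; with the new minimum version the built-in can be called directly. A small sanity check (a sketch, not part of the diff) showing the built-in matches the removed `_silu_python` formula:

    import torch
    from torch import nn

    x = torch.randn(4)
    # The removed fallback computed input * sigmoid(input) by hand;
    # nn.functional.silu (available since PyTorch 1.7) is the built-in equivalent.
    assert torch.allclose(nn.functional.silu(x), x * torch.sigmoid(x))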
@@ -61,13 +61,13 @@ deps = {
     "librosa": "librosa",
     "starlette": "starlette",
     "tensorflow-cpu": "tensorflow-cpu>=2.3",
-    "tensorflow": "tensorflow>=2.3",
+    "tensorflow": "tensorflow>=2.4",
     "tensorflow-text": "tensorflow-text",
     "tf2onnx": "tf2onnx",
     "timeout-decorator": "timeout-decorator",
     "timm": "timm",
     "tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
-    "torch": "torch>=1.0,!=1.12.0",
+    "torch": "torch>=1.7,!=1.12.0",
     "torchaudio": "torchaudio",
     "pyctcdecode": "pyctcdecode>=0.3.0",
     "tqdm": "tqdm>=4.27",
...
@@ -34,12 +34,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -216,12 +211,9 @@ class AlbertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
...
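All of the `is_torch_greater_than_1_6` guards around `register_buffer(..., persistent=False)` follow from the same fact: the `persistent` argument was added in PyTorch 1.6, so with a 1.7 floor the buffers can always be registered as non-persistent. A minimal sketch (a hypothetical toy module, not the ALBERT code) of what a non-persistent buffer does:

    import torch
    from torch import nn

    class TinyEmbeddings(nn.Module):
        def __init__(self, max_len: int = 8):
            super().__init__()
            self.register_buffer("position_ids", torch.arange(max_len).expand((1, -1)))
            # persistent=False: the buffer belongs to the module but is left out of
            # state_dict(), so older checkpoints saved without it still load cleanly.
            self.register_buffer(
                "token_type_ids", torch.zeros((1, max_len), dtype=torch.long), persistent=False
            )

    m = TinyEmbeddings()
    assert "position_ids" in m.state_dict()
    assert "token_type_ids" not in m.state_dict()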
@@ -40,12 +40,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -199,12 +194,9 @@ class BertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(
         self,
...
@@ -37,7 +37,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
+from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -259,12 +259,9 @@ class BigBirdEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )
         # End copy
         self.rescale_embeddings = config.rescale_embeddings
...
@@ -35,12 +35,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_convbert import ConvBertConfig
@@ -198,12 +193,9 @@ class ConvBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(
         self,
...
@@ -34,12 +34,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -87,12 +82,9 @@ class Data2VecTextForTextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )
         # End copy
         self.padding_idx = config.pad_token_id
...
@@ -22,15 +22,12 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.cuda.amp import autocast

 from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    Conv1D,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_or_equal_than_1_6,
-    prune_conv1d_layer,
-)
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -38,15 +35,6 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )

-if is_torch_greater_or_equal_than_1_6:
-    is_amp_available = True
-    from torch.cuda.amp import autocast
-else:
-    is_amp_available = False
-
-from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
-
 from .configuration_decision_transformer import DecisionTransformerConfig
@@ -235,12 +223,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        if is_amp_available:
-            with autocast(enabled=False):
-                q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
-                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
-                attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
-        else:
+        with autocast(enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
...
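Similarly, the `is_amp_available` flag only guarded the import of `torch.cuda.amp.autocast`, which has existed since PyTorch 1.6, so the upcast branch can now run unconditionally. A rough sketch (hypothetical shapes, not the attention module itself) of the pattern kept by the new code:

    import torch
    from torch.cuda.amp import autocast

    bias = torch.zeros(2, 4, 4)
    q = torch.randn(2, 4, 8)
    k = torch.randn(2, 8, 4)
    # Temporarily disable autocast so the matmul runs in float32 even when the
    # surrounding code executes under mixed precision (the "upcast" noted in the diff above).
    with autocast(enabled=False):
        attn_weights = torch.baddbmm(bias, q.float(), k.float(), beta=0, alpha=1.0 / (8**0.5))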
@@ -39,12 +39,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -106,10 +101,9 @@ class Embeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
-            )
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )

     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         """
...
@@ -36,12 +36,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -169,12 +164,9 @@ class ElectraEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
...
@@ -38,12 +38,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -96,12 +91,9 @@ class ErnieEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(
         self,
...
@@ -22,7 +22,6 @@ import torch
 from torch import nn

 from ...modeling_outputs import BaseModelOutput
-from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from ..xlm.modeling_xlm import (
     XLMForMultipleChoice,
@@ -139,10 +138,9 @@ class FlaubertModel(XLMModel):
         super().__init__(config)
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
-            )
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )

     @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
...
@@ -29,7 +29,6 @@ from transformers.utils.doc import add_code_sample_docstrings
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import (
     ModelOutput,
     add_start_docstrings,
@@ -392,12 +391,9 @@ class FlavaTextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(
         self,
...
@@ -43,7 +43,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
+from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
@@ -117,12 +117,9 @@ class FNetEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
         if input_ids is not None:
...
@@ -23,22 +23,9 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from ...pytorch_utils import (
-    Conv1D,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_or_equal_than_1_6,
-    prune_conv1d_layer,
-)
-
-if is_torch_greater_or_equal_than_1_6:
-    is_amp_available = True
-    from torch.cuda.amp import autocast
-else:
-    is_amp_available = False
-
 from ...activations import ACT2FN
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -47,6 +34,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -247,12 +235,7 @@ class GPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        if is_amp_available:
-            with autocast(enabled=False):
-                q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
-                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
-                attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
-        else:
+        with autocast(enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
...
@@ -22,22 +22,9 @@ from typing import Any, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from ...pytorch_utils import (
-    Conv1D,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_or_equal_than_1_6,
-    prune_conv1d_layer,
-)
-
-if is_torch_greater_or_equal_than_1_6:
-    is_amp_available = True
-    from torch.cuda.amp import autocast
-else:
-    is_amp_available = False
-
 from ...activations import ACT2FN
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -45,6 +32,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_imagegpt import ImageGPTConfig
@@ -299,12 +287,7 @@ class ImageGPTAttention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        if is_amp_available:
-            with autocast(enabled=False):
-                q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
-                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
-                attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
-        else:
+        with autocast(enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
...
@@ -33,7 +33,6 @@ from ...modeling_utils import (
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )
-from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import logging
 from .configuration_mctct import MCTCTConfig
@@ -153,12 +152,11 @@ class MCTCTEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids",
+            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+            persistent=False,
+        )

     def forward(
         self, input_features=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
...
@@ -38,12 +38,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,
@@ -187,12 +182,9 @@ class NezhaEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros((1, config.max_position_embeddings), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros((1, config.max_position_embeddings), dtype=torch.long), persistent=False
+        )

     def forward(
         self,
...
@@ -33,12 +33,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_nystromformer import NystromformerConfig
@@ -72,12 +67,11 @@ class NystromformerEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids",
+            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+            persistent=False,
+        )

     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
         if input_ids is not None:
...