chenpangpang / transformers / Commits / a2a3afbc

Unverified commit a2a3afbc, authored Sep 14, 2022 by Sylvain Gugger, committed by GitHub on Sep 14, 2022.

PyTorch >= 1.7.0 and TensorFlow >= 2.4.0 (#19016)

Parent: 9f4acd05
Changes: 30 files in the commit; this page shows 20 changed files with 73 additions and 216 deletions (+73, -216).
Changed files shown on this page:

setup.py (+2, -2)
src/transformers/activations.py (+2, -12)
src/transformers/dependency_versions_table.py (+2, -2)
src/transformers/models/albert/modeling_albert.py (+4, -12)
src/transformers/models/bert/modeling_bert.py (+4, -12)
src/transformers/models/big_bird/modeling_big_bird.py (+4, -7)
src/transformers/models/convbert/modeling_convbert.py (+4, -12)
src/transformers/models/data2vec/modeling_data2vec_text.py (+4, -12)
src/transformers/models/decision_transformer/modeling_decision_transformer.py (+4, -21)
src/transformers/models/distilbert/modeling_distilbert.py (+4, -10)
src/transformers/models/electra/modeling_electra.py (+4, -12)
src/transformers/models/ernie/modeling_ernie.py (+4, -12)
src/transformers/models/flaubert/modeling_flaubert.py (+3, -5)
src/transformers/models/flava/modeling_flava.py (+3, -7)
src/transformers/models/fnet/modeling_fnet.py (+4, -7)
src/transformers/models/gpt2/modeling_gpt2.py (+3, -20)
src/transformers/models/imagegpt/modeling_imagegpt.py (+3, -20)
src/transformers/models/mctct/modeling_mctct.py (+5, -7)
src/transformers/models/nezha/modeling_nezha.py (+4, -12)
src/transformers/models/nystromformer/modeling_nystromformer.py (+6, -12)
setup.py

@@ -155,13 +155,13 @@ _deps = [
     "librosa",
     "starlette",
     "tensorflow-cpu>=2.3",
-    "tensorflow>=2.3",
+    "tensorflow>=2.4",
     "tensorflow-text",
     "tf2onnx",
     "timeout-decorator",
     "timm",
     "tokenizers>=0.11.1,!=0.11.3,<0.13",
-    "torch>=1.0,!=0.12.0",
+    "torch>=1.7,!=1.12.0",
     "torchaudio",
     "pyctcdecode>=0.3.0",
     "tqdm>=4.27",
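Note (not part of the diff): the new torch pin raises the floor to 1.7 while keeping an explicit exclusion of the 1.12.0 release. As an illustrative sketch, a pip-style specifier like this can be evaluated against candidate versions with the packaging library:

from packaging.specifiers import SpecifierSet
from packaging.version import Version

spec = SpecifierSet(">=1.7,!=1.12.0")   # the torch constraint after this commit
print(Version("1.6.0") in spec)    # False: below the new minimum
print(Version("1.12.0") in spec)   # False: explicitly excluded release
print(Version("1.12.1") in spec)   # True: satisfies the pin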
src/transformers/activations.py

@@ -44,7 +44,7 @@ class GELUActivation(nn.Module):
     def __init__(self, use_gelu_python: bool = False):
         super().__init__()
-        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.4") or use_gelu_python:
+        if use_gelu_python:
             self.act = self._gelu_python
         else:
             self.act = nn.functional.gelu

@@ -108,18 +108,8 @@ class SiLUActivation(nn.Module):
     later.
     """

-    def __init__(self):
-        super().__init__()
-        if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"):
-            self.act = self._silu_python
-        else:
-            self.act = nn.functional.silu
-
-    def _silu_python(self, input: Tensor) -> Tensor:
-        return input * torch.sigmoid(input)
-
     def forward(self, input: Tensor) -> Tensor:
-        return self.act(input)
+        return nn.functional.silu(input)


 class MishActivation(nn.Module):
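Note (not part of the diff): with torch >= 1.7 guaranteed, nn.functional.silu always exists, so the removed Python fallback (input * torch.sigmoid(input)) and its version check are no longer needed. A small standalone sanity check showing the two paths agree:

import torch
from torch import nn

x = torch.randn(4, 8)
reference = x * torch.sigmoid(x)  # the removed _silu_python fallback
assert torch.allclose(nn.functional.silu(x), reference, atol=1e-6)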
src/transformers/dependency_versions_table.py

@@ -61,13 +61,13 @@ deps = {
     "librosa": "librosa",
     "starlette": "starlette",
     "tensorflow-cpu": "tensorflow-cpu>=2.3",
-    "tensorflow": "tensorflow>=2.3",
+    "tensorflow": "tensorflow>=2.4",
     "tensorflow-text": "tensorflow-text",
     "tf2onnx": "tf2onnx",
     "timeout-decorator": "timeout-decorator",
     "timm": "timm",
     "tokenizers": "tokenizers>=0.11.1,!=0.11.3,<0.13",
-    "torch": "torch>=1.0,!=0.12.0",
+    "torch": "torch>=1.7,!=1.12.0",
     "torchaudio": "torchaudio",
     "pyctcdecode": "pyctcdecode>=0.3.0",
     "tqdm": "tqdm>=4.27",
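Note (not part of the diff): this table mirrors the pins in setup.py and is what the library's runtime version checks read from. A rough sketch of how the new torch pin can be queried and enforced after the commit (exact call sites inside transformers may differ):

from transformers.dependency_versions_table import deps
from transformers.utils.versions import require_version

print(deps["torch"])             # "torch>=1.7,!=1.12.0" after this commit
require_version(deps["torch"])   # raises if the installed torch violates the pin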
src/transformers/models/albert/modeling_albert.py

@@ -34,12 +34,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,

@@ -216,12 +211,9 @@ class AlbertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
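Note (not part of the diff): the removed is_torch_greater_than_1_6 guard existed because the persistent flag of register_buffer only appeared in PyTorch 1.6; with the new 1.7 floor it can be used unconditionally, and the same simplification repeats in the model files below. A standalone sketch (hypothetical ToyEmbeddings module, not from the commit) of what persistent=False does: the buffer lives on the module and moves with .to()/.cuda(), but is left out of the state dict, so older checkpoints that lack it still load cleanly.

import torch
from torch import nn

class ToyEmbeddings(nn.Module):
    def __init__(self, max_position_embeddings: int = 16):
        super().__init__()
        # Persistent buffer: saved in state_dict().
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))
        # Non-persistent buffer: usable like any buffer, but excluded from state_dict().
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

module = ToyEmbeddings()
print("position_ids" in module.state_dict())    # True
print("token_type_ids" in module.state_dict())  # False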
src/transformers/models/bert/modeling_bert.py

@@ -40,12 +40,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,

@@ -199,12 +194,9 @@ class BertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(
         self,
src/transformers/models/big_bird/modeling_big_bird.py

@@ -37,7 +37,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
+from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,

@@ -259,12 +259,9 @@ class BigBirdEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )
         # End copy
         self.rescale_embeddings = config.rescale_embeddings
src/transformers/models/convbert/modeling_convbert.py

@@ -35,12 +35,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_convbert import ConvBertConfig

@@ -198,12 +193,9 @@ class ConvBertEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(
         self,
src/transformers/models/data2vec/modeling_data2vec_text.py

@@ -34,12 +34,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,

@@ -87,12 +82,9 @@ class Data2VecTextForTextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )
         # End copy
         self.padding_idx = config.pad_token_id
src/transformers/models/decision_transformer/modeling_decision_transformer.py

@@ -22,15 +22,12 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.cuda.amp import autocast

 from ...activations import ACT2FN
+from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    Conv1D,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_or_equal_than_1_6,
-    prune_conv1d_layer,
-)
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
     add_start_docstrings,

@@ -38,15 +35,6 @@ from ...utils import (
     logging,
     replace_return_docstrings,
 )

-if is_torch_greater_or_equal_than_1_6:
-    is_amp_available = True
-    from torch.cuda.amp import autocast
-else:
-    is_amp_available = False
-
-from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from .configuration_decision_transformer import DecisionTransformerConfig

@@ -235,12 +223,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        if is_amp_available:
-            with autocast(enabled=False):
-                q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
-                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
-                attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
-        else:
+        with autocast(enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
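Note (not part of the diff): torch.cuda.amp.autocast has shipped since PyTorch 1.6, so the is_amp_available gate is dropped here and in the GPT-2 and ImageGPT files below, and the upcast branch runs unconditionally. A minimal sketch of the underlying pattern (hypothetical helper, not from the commit): disabling autocast locally so a numerically sensitive matmul runs in float32 even inside a mixed-precision region.

import torch
from torch.cuda.amp import autocast

def upcast_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Force full precision for this op, even when the caller is under autocast.
    with autocast(enabled=False):
        return torch.matmul(a.float(), b.float())

if torch.cuda.is_available():
    a = torch.randn(8, 16, device="cuda", dtype=torch.float16)
    b = torch.randn(16, 8, device="cuda", dtype=torch.float16)
    with autocast():  # mixed-precision region, as in the attention module's caller
        out = upcast_matmul(a, b)
    print(out.dtype)  # torch.float32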
src/transformers/models/distilbert/modeling_distilbert.py

@@ -39,12 +39,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,

@@ -106,10 +101,9 @@ class Embeddings(nn.Module):
         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
-            )
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )

     def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         """
src/transformers/models/electra/modeling_electra.py

@@ -36,12 +36,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,

@@ -169,12 +164,9 @@ class ElectraEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
     def forward(
src/transformers/models/ernie/modeling_ernie.py

@@ -38,12 +38,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,

@@ -96,12 +91,9 @@ class ErnieEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(
         self,
src/transformers/models/flaubert/modeling_flaubert.py

@@ -22,7 +22,6 @@ import torch
 from torch import nn

 from ...modeling_outputs import BaseModelOutput
-from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from ..xlm.modeling_xlm import (
     XLMForMultipleChoice,

@@ -139,10 +138,9 @@ class FlaubertModel(XLMModel):
         super().__init__(config)
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
-            )
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )

     @add_start_docstrings_to_model_forward(FLAUBERT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
src/transformers/models/flava/modeling_flava.py

@@ -29,7 +29,6 @@ from transformers.utils.doc import add_code_sample_docstrings
 from ...activations import ACT2FN
 from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import (
     ModelOutput,
     add_start_docstrings,

@@ -392,12 +391,9 @@ class FlavaTextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(
         self,
src/transformers/models/fnet/modeling_fnet.py

@@ -43,7 +43,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import apply_chunking_to_forward, is_torch_greater_than_1_6
+from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,

@@ -117,12 +117,9 @@ class FNetEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
         if input_ids is not None:
src/transformers/models/gpt2/modeling_gpt2.py

@@ -23,22 +23,9 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from ...pytorch_utils import (
-    Conv1D,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_or_equal_than_1_6,
-    prune_conv1d_layer,
-)
-
-if is_torch_greater_or_equal_than_1_6:
-    is_amp_available = True
-    from torch.cuda.amp import autocast
-else:
-    is_amp_available = False
-
 from ...activations import ACT2FN
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,

@@ -47,6 +34,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,

@@ -247,12 +235,7 @@ class GPT2Attention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        if is_amp_available:
-            with autocast(enabled=False):
-                q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
-                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
-                attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
-        else:
+        with autocast(enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
src/transformers/models/imagegpt/modeling_imagegpt.py

@@ -22,22 +22,9 @@ from typing import Any, Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.cuda.amp import autocast
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

-from ...pytorch_utils import (
-    Conv1D,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_or_equal_than_1_6,
-    prune_conv1d_layer,
-)
-
-if is_torch_greater_or_equal_than_1_6:
-    is_amp_available = True
-    from torch.cuda.amp import autocast
-else:
-    is_amp_available = False
-
 from ...activations import ACT2FN
 from ...modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,

@@ -45,6 +32,7 @@ from ...modeling_outputs import (
     SequenceClassifierOutputWithPast,
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_imagegpt import ImageGPTConfig

@@ -299,12 +287,7 @@ class ImageGPTAttention(nn.Module):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        if is_amp_available:
-            with autocast(enabled=False):
-                q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
-                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
-                attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
-        else:
+        with autocast(enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
src/transformers/models/mctct/modeling_mctct.py

@@ -33,7 +33,6 @@ from ...modeling_utils import (
     find_pruneable_heads_and_indices,
     prune_linear_layer,
 )
-from ...pytorch_utils import is_torch_greater_than_1_6
 from ...utils import logging
 from .configuration_mctct import MCTCTConfig

@@ -153,12 +152,11 @@ class MCTCTEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids",
+            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+            persistent=False,
+        )

     def forward(
         self, input_features=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
src/transformers/models/nezha/modeling_nezha.py

@@ -38,12 +38,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     ModelOutput,
     add_code_sample_docstrings,

@@ -187,12 +182,9 @@ class NezhaEmbeddings(nn.Module):
         # any TensorFlow checkpoint file
         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros((1, config.max_position_embeddings), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros((1, config.max_position_embeddings), dtype=torch.long), persistent=False
+        )

     def forward(
         self,
src/transformers/models/nystromformer/modeling_nystromformer.py

@@ -33,12 +33,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_nystromformer import NystromformerConfig

@@ -72,12 +67,11 @@ class NystromformerEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids",
+            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+            persistent=False,
+        )

     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
         if input_ids is not None: