chenpangpang/transformers · commit a2a3afbc (unverified)
PyTorch >= 1.7.0 and TensorFlow >= 2.4.0 (#19016)
Authored Sep 14, 2022 by Sylvain Gugger, committed via GitHub on Sep 14, 2022
Parent: 9f4acd05
Changes: 30 files in the commit. This page shows 10 changed files with 38 additions and 109 deletions (+38, -109); the remaining files are on the other page of the diff.
Files on this page:

src/transformers/models/qdqbert/modeling_qdqbert.py (+4, -7)
src/transformers/models/realm/modeling_realm.py (+4, -12)
src/transformers/models/roberta/modeling_roberta.py (+4, -12)
src/transformers/models/vilt/modeling_vilt.py (+4, -12)
src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py (+4, -12)
src/transformers/models/yoso/modeling_yoso.py (+6, -12)
src/transformers/pytorch_utils.py (+1, -2)
src/transformers/trainer.py (+5, -26)
src/transformers/trainer_pt_utils.py (+1, -7)
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py (+5, -7)
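Every hunk on this page removes a runtime guard that only mattered on torch older than the new floor. As a quick local sanity check against the minimums from the commit title, something like the following can be run; this is a minimal sketch, and the assertion messages are illustrative rather than part of the diff:

import tensorflow as tf
import torch
from packaging import version

# Compare against the new minimums from the commit title; base_version strips
# local build suffixes such as "+cu113" before comparing.
torch_ok = version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.7.0")
tf_ok = version.parse(tf.__version__) >= version.parse("2.4.0")

assert torch_ok, f"PyTorch {torch.__version__} is older than the 1.7.0 minimum"
assert tf_ok, f"TensorFlow {tf.__version__} is older than the 2.4.0 minimum"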
src/transformers/models/qdqbert/modeling_qdqbert.py

@@ -39,7 +39,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, is_torch_greater_than_1_6, prune_linear_layer
+from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,

@@ -166,11 +166,8 @@ class QDQBertEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(
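The embeddings change above is the pattern repeated in the model files that follow: once torch >= 1.7.0 is guaranteed, `is_torch_greater_than_1_6` is always true, so the guard goes away and the `token_type_ids` buffer is registered unconditionally with `persistent=False`. A standalone sketch of the idiom; the module and sizes below are illustrative stand-ins, not the actual QDQBertEmbeddings class:

import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    # Illustrative stand-in for the *Embeddings classes touched by this commit;
    # max_position_embeddings mimics the config attribute used there.
    def __init__(self, max_position_embeddings: int = 512):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(max_position_embeddings).expand((1, -1)))
        # persistent=False keeps the buffer out of state_dict(), so checkpoints
        # saved without it still load; no torch-version check is needed here.
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )


emb = ToyEmbeddings()
assert "position_ids" in emb.state_dict()        # persistent buffer is serialized
assert "token_type_ids" not in emb.state_dict()  # non-persistent buffer is not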
src/transformers/models/realm/modeling_realm.py

@@ -31,12 +31,7 @@ from ...modeling_outputs import (
     ModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_realm import RealmConfig

@@ -185,11 +180,8 @@ class RealmEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(
src/transformers/models/roberta/modeling_roberta.py

@@ -35,12 +35,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,

@@ -87,11 +82,8 @@ class RobertaEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     # End copy
src/transformers/models/vilt/modeling_vilt.py

@@ -34,12 +34,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    find_pruneable_heads_and_indices,
-    is_torch_greater_or_equal_than_1_10,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import find_pruneable_heads_and_indices, is_torch_greater_or_equal_than_1_10, prune_linear_layer
 from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
 from .configuration_vilt import ViltConfig

@@ -255,11 +250,8 @@ class TextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py

@@ -34,12 +34,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,

@@ -80,11 +75,8 @@ class XLMRobertaXLEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long),
-                persistent=False,
-            )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )

     # End copy
src/transformers/models/yoso/modeling_yoso.py

@@ -34,12 +34,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import (
-    apply_chunking_to_forward,
-    find_pruneable_heads_and_indices,
-    is_torch_greater_than_1_6,
-    prune_linear_layer,
-)
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
 from .configuration_yoso import YosoConfig

@@ -261,7 +256,6 @@ class YosoEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + 2)
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+        self.register_buffer(
+            "token_type_ids",
+            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
...
src/transformers/pytorch_utils.py

@@ -26,8 +26,7 @@ ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
 logger = logging.get_logger(__name__)

 parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

-is_torch_greater_or_equal_than_1_6 = parsed_torch_version_base >= version.parse("1.6.0")
-is_torch_greater_than_1_6 = parsed_torch_version_base > version.parse("1.6.0")
 is_torch_less_than_1_8 = parsed_torch_version_base < version.parse("1.8.0")
 is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
 is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")
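For context, the surviving flags compare only the base version of the installed build, so development or local suffixes (for example 1.12.0.dev20220601+cu113) do not break the comparison. A small sketch of the same idiom outside the library; the print is only for illustration:

import torch
from packaging import version

# Mirror of the helper-flag idiom kept in pytorch_utils.py: parse the base
# version once, then derive simple booleans from it.
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

is_torch_less_than_1_8 = parsed_torch_version_base < version.parse("1.8.0")
is_torch_greater_or_equal_than_1_10 = parsed_torch_version_base >= version.parse("1.10")
is_torch_less_than_1_11 = parsed_torch_version_base < version.parse("1.11")

print(parsed_torch_version_base, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11)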
src/transformers/trainer.py

@@ -71,12 +71,7 @@ from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import (
-    ALL_LAYERNORM_LAYERS,
-    is_torch_greater_or_equal_than_1_6,
-    is_torch_greater_or_equal_than_1_10,
-    is_torch_less_than_1_11,
-)
+from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,

@@ -155,9 +150,7 @@ from .utils import (
 from .utils.generic import ContextManagers

-_is_torch_generator_available = False
-_is_native_cuda_amp_available = False
-_is_native_cpu_amp_available = False
+_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10

 DEFAULT_CALLBACKS = [DefaultFlowCallback]
 DEFAULT_PROGRESS_CALLBACK = ProgressCallback

@@ -170,13 +163,6 @@ if is_in_notebook():
 if is_apex_available():
     from apex import amp

-if is_torch_greater_or_equal_than_1_6:
-    _is_torch_generator_available = True
-    _is_native_cuda_amp_available = True
-
-if is_torch_greater_or_equal_than_1_10:
-    _is_native_cpu_amp_available = True
-
 if is_datasets_available():
     import datasets

@@ -565,12 +551,7 @@ class Trainer:
                 else:
                     raise ValueError("Tried to use cpu amp but native cpu amp is not available")
             else:
-                if _is_native_cuda_amp_available:
-                    args.half_precision_backend = "cuda_amp"
-                elif args.bf16:
-                    raise ValueError("Tried to use `bf16` but native amp is not available")
-                else:
-                    args.half_precision_backend = "apex"
+                args.half_precision_backend = "cuda_amp"

            logger.info(f"Using {args.half_precision_backend} half precision backend")

@@ -781,7 +762,7 @@ class Trainer:
             return None

         generator = None
-        if self.args.world_size <= 1 and _is_torch_generator_available:
+        if self.args.world_size <= 1:
             generator = torch.Generator()
             # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
             # `args.seed`) if data_seed isn't provided.

@@ -826,9 +807,7 @@ class Trainer:
         else:
             if self.args.world_size <= 1:
-                if _is_torch_generator_available:
-                    return RandomSampler(self.train_dataset, generator=generator)
-                return RandomSampler(self.train_dataset)
+                return RandomSampler(self.train_dataset, generator=generator)
             elif (
                 self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
                 and not self.args.dataloader_drop_last
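The flags removed here were only ever gated on torch >= 1.6, which is below the new 1.7.0 floor, so `_is_torch_generator_available` and `_is_native_cuda_amp_available` are always true and disappear; only the CPU-amp flag, still gated on torch >= 1.10, remains. A toy sketch of the seeded-sampler path the trainer now takes unconditionally; the dataset and seed are made up for the example and are not Trainer internals:

import torch
from torch.utils.data import RandomSampler, TensorDataset

# Hypothetical dataset standing in for self.train_dataset.
dataset = TensorDataset(torch.arange(10))

# Seed a generator explicitly, as the single-process branch above now always does.
generator = torch.Generator()
generator.manual_seed(42)

sampler = RandomSampler(dataset, generator=generator)
print(list(sampler))  # same permutation of 0..9 on every run with this seed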
src/transformers/trainer_pt_utils.py

@@ -31,7 +31,6 @@ from typing import Any, Dict, Iterator, List, Optional, Union
 import numpy as np
 import torch
 import torch.distributed as dist
-from packaging import version
 from torch import nn
 from torch.utils.data import Dataset, IterableDataset, RandomSampler, Sampler
 from torch.utils.data.distributed import DistributedSampler

@@ -831,12 +830,7 @@ def _get_learning_rate(self):
         else:
             raise
     else:
-        last_lr = (
-            # backward compatibility for pytorch schedulers
-            self.lr_scheduler.get_last_lr()[0]
-            if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.4")
-            else self.lr_scheduler.get_lr()[0]
-        )
+        last_lr = self.lr_scheduler.get_last_lr()[0]
     if torch.is_tensor(last_lr):
         last_lr = last_lr.item()
     return last_lr
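The removed conditional only fell back to `get_lr()` on torch older than 1.4, far below the new 1.7.0 floor, so `_get_learning_rate` can call `get_last_lr()` directly and the `packaging.version` import goes with it. A toy example of the call that remains; the optimizer and scheduler here are placeholders, not the Trainer's own objects:

import torch

# Minimal optimizer/scheduler pair just to exercise get_last_lr().
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

optimizer.step()
scheduler.step()

last_lr = scheduler.get_last_lr()[0]  # the only code path kept after this commit
print(last_lr)  # 0.05 with the toy settings above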
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py

@@ -47,7 +47,6 @@ from ...pytorch_utils import (
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
     prune_linear_layer,
-    is_torch_greater_than_1_6,
 )
 from ...utils import logging
 from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config

@@ -157,7 +156,6 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
         self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
-        if is_torch_greater_than_1_6:
-            self.register_buffer(
-                "token_type_ids",
-                torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
+        self.register_buffer(
+            "token_type_ids",
+            torch.zeros(self.position_ids.size(), dtype=torch.long, device=self.position_ids.device),
...