transformers commit 76296569 (unverified), parent 6d023270
Authored Oct 26, 2022 by Younes Belkada; committed by GitHub on Oct 26, 2022

`accelerate` support for `RoBERTa` family (#19906)
Showing 8 changed files with 52 additions and 15 deletions (+52 / -15)
src/transformers/models/camembert/modeling_camembert.py        +6  -1
src/transformers/models/data2vec/modeling_data2vec_text.py     +6  -1
src/transformers/models/lilt/modeling_lilt.py                  +1  -0
src/transformers/models/longformer/modeling_longformer.py      +6  -1
src/transformers/models/luke/modeling_luke.py                  +12 -1
src/transformers/models/roberta/modeling_roberta.py            +6  -1
src/transformers/models/xlm_roberta/modeling_xlm_roberta.py    +6  -1
tests/test_modeling_common.py                                  +9  -9
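Taken together, these changes add the `_no_split_modules` attribute and a few device-safety checks so the RoBERTa family can be loaded through `accelerate`. As a minimal usage sketch (an assumption of this note, not part of the diff: `accelerate` is installed and the public `roberta-base` checkpoint is used), the feature this enables looks like:

    from transformers import AutoModelForMaskedLM

    # With accelerate available, the weights can now be dispatched automatically
    # across the available GPU(s), CPU RAM and, if needed, disk offload.
    model = AutoModelForMaskedLM.from_pretrained("roberta-base", device_map="auto")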
src/transformers/models/camembert/modeling_camembert.py

@@ -728,7 +728,11 @@ class CamembertLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
@@ -752,6 +756,7 @@ class CamembertModel(CamembertPreTrainedModel):
     """
 
     _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _no_split_modules = []
 
     # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert
     def __init__(self, config, add_pooling_layer=True):
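The `device.type == "meta"` branch that recurs in every LM head below is presumably needed because `accelerate` first instantiates the model with empty parameters on the meta device and only fills in the real weights afterwards; at that point, tying in the old direction would replace the already-loaded `self.bias` with a dataless meta tensor. A small sketch of that state (assumes `accelerate` is installed; the layer shape is arbitrary):

    import torch.nn as nn
    from accelerate import init_empty_weights

    # Parameters created under init_empty_weights live on the meta device and
    # hold no data, which is exactly the case the new _tie_weights branch checks for.
    with init_empty_weights():
        decoder = nn.Linear(768, 50265)

    print(decoder.bias.device.type)  # "meta"
    print(decoder.bias.is_meta)      # True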
src/transformers/models/data2vec/modeling_data2vec_text.py

@@ -584,6 +584,7 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
     config_class = Data2VecTextConfig
     base_model_prefix = "data2vec_text"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1147,7 +1148,11 @@ class Data2VecTextLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
src/transformers/models/lilt/modeling_lilt.py

@@ -588,6 +588,7 @@ class LiltPreTrainedModel(PreTrainedModel):
     config_class = LiltConfig
     base_model_prefix = "lilt"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
src/transformers/models/longformer/modeling_longformer.py

@@ -1412,7 +1412,11 @@ class LongformerLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 class LongformerPreTrainedModel(PreTrainedModel):
@@ -1425,6 +1429,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "longformer"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"position_ids"]
+    _no_split_modules = ["LongformerSelfAttention"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
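`_no_split_modules` is the attribute `accelerate` consumes (as `no_split_module_classes`) when it infers a device map, so that blocks such as `LongformerSelfAttention` are always placed whole on a single device. A hedged sketch of that interaction (assumes `accelerate` is installed; the checkpoint name and memory limits are illustrative choices, not part of this commit):

    from accelerate import infer_auto_device_map, init_empty_weights
    from transformers import AutoConfig, AutoModel

    config = AutoConfig.from_pretrained("allenai/longformer-base-4096")
    with init_empty_weights():
        model = AutoModel.from_config(config)

    # Modules listed in _no_split_modules are kept together when the map is built.
    device_map = infer_auto_device_map(
        model,
        max_memory={0: "300MB", "cpu": "2GB"},
        no_split_module_classes=model._no_split_modules,
    )
    print(device_map)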
src/transformers/models/luke/modeling_luke.py

@@ -902,6 +902,7 @@ class LukePreTrainedModel(PreTrainedModel):
     config_class = LukeConfig
     base_model_prefix = "luke"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]
 
     def _init_weights(self, module: nn.Module):
         """Initialize the weights"""
@@ -1264,7 +1265,11 @@ class LukeLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
@@ -1746,9 +1751,15 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
         hidden_size = outputs.last_hidden_state.size(-1)
 
         entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
+        if entity_start_positions.device != outputs.last_hidden_state.device:
+            entity_start_positions = entity_start_positions.to(outputs.last_hidden_state.device)
         start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions)
+
         entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
+        if entity_end_positions.device != outputs.last_hidden_state.device:
+            entity_end_positions = entity_end_positions.to(outputs.last_hidden_state.device)
         end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions)
+
         feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2)
 
         feature_vector = self.dropout(feature_vector)
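The added `.to(outputs.last_hidden_state.device)` calls exist because `torch.gather` requires its index tensor to live on the same device as the tensor it indexes, which is no longer guaranteed once the model is spread over several devices by a device map. A self-contained sketch of the same pattern (shapes are made up for illustration):

    import torch

    hidden = torch.randn(2, 5, 8)          # (batch, seq_len, hidden_size)
    positions = torch.tensor([[1], [3]])   # one entity start position per example

    # Expand the positions over the hidden dimension, then make sure the index
    # tensor sits on the same device as the tensor being gathered from.
    index = positions.unsqueeze(-1).expand(-1, -1, hidden.size(-1))
    if index.device != hidden.device:
        index = index.to(hidden.device)
    start_states = torch.gather(hidden, -2, index)   # shape (2, 1, 8)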
src/transformers/models/roberta/modeling_roberta.py

@@ -584,6 +584,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
     config_class = RobertaConfig
     base_model_prefix = "roberta"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
@@ -1146,7 +1147,11 @@ class RobertaLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
src/transformers/models/xlm_roberta/modeling_xlm_roberta.py

@@ -586,6 +586,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
     config_class = XLMRobertaConfig
     base_model_prefix = "roberta"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
@@ -1155,7 +1156,11 @@ class XLMRobertaLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
tests/test_modeling_common.py

@@ -2312,11 +2312,11 @@ class ModelTesterMixin:
             if model_class._no_split_modules is None:
                 continue
 
-            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
             model = model_class(config).eval()
             model = model.to(torch_device)
 
             torch.manual_seed(0)
-            base_output = model(**inputs_dict)
+            base_output = model(**inputs_dict_class)
 
             model_size = compute_module_sizes(model)[""]
             max_size = int(self.model_split_percents[0] * model_size)
@@ -2334,7 +2334,7 @@ class ModelTesterMixin:
                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
                 torch.manual_seed(0)
-                new_output = new_model(**inputs_dict)
+                new_output = new_model(**inputs_dict_class)
 
                 self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2347,12 +2347,12 @@ class ModelTesterMixin:
             if model_class._no_split_modules is None:
                 continue
 
-            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
             model = model_class(config).eval()
             model = model.to(torch_device)
 
             torch.manual_seed(0)
-            base_output = model(**inputs_dict)
+            base_output = model(**inputs_dict_class)
 
             model_size = compute_module_sizes(model)[""]
             # We test several splits of sizes to make sure it works.
@@ -2369,7 +2369,7 @@ class ModelTesterMixin:
                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
                 torch.manual_seed(0)
-                new_output = new_model(**inputs_dict)
+                new_output = new_model(**inputs_dict_class)
 
                 self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2382,12 +2382,12 @@ class ModelTesterMixin:
             if model_class._no_split_modules is None:
                 continue
 
-            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
             model = model_class(config).eval()
             model = model.to(torch_device)
 
             torch.manual_seed(0)
-            base_output = model(**inputs_dict)
+            base_output = model(**inputs_dict_class)
 
             model_size = compute_module_sizes(model)[""]
             # We test several splits of sizes to make sure it works.
@@ -2404,7 +2404,7 @@ class ModelTesterMixin:
                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
                 torch.manual_seed(0)
-                new_output = new_model(**inputs_dict)
+                new_output = new_model(**inputs_dict_class)
 
                 self.assertTrue(torch.allclose(base_output[0], new_output[0]))
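The only functional change in the tests is that the prepared inputs are bound to a new name, `inputs_dict_class`: rebinding `inputs_dict` inside the `for model_class in ...` loop would hand an already-prepared dictionary back to `_prepare_for_class` on later iterations. A tiny standalone illustration of the hazard (the `prepare` helper below is a stand-in for `_prepare_for_class`, not the real method):

    shared_inputs = {"input_ids": [0, 31414, 2]}

    def prepare(inputs, model_class):
        # Stand-in: the real helper adds class-specific keys such as labels.
        extra = {"labels": [0, 0, 0]} if "ForMaskedLM" in model_class else {}
        return {**inputs, **extra}

    for model_class in ("RobertaModel", "RobertaForMaskedLM"):
        prepared = prepare(shared_inputs, model_class)   # like inputs_dict_class
        # shared_inputs stays untouched, so every class starts from the raw inputs
        assert "labels" not in shared_inputs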