Unverified Commit 76296569 authored by Younes Belkada, committed by GitHub

`accelerate` support for `RoBERTa` family (#19906)

parent 6d023270
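As a rough usage sketch (not part of this diff): once `_no_split_modules` is defined on a model class, the RoBERTa-family models can be loaded through `accelerate`'s big-model inference with `device_map="auto"`, and the `meta`-device check added to `_tie_weights` keeps the LM-head bias tie intact when parameters are first created on the `meta` device. The checkpoint name, prompt, and hardware assumptions below are illustrative only.

# Minimal sketch, assuming `accelerate` is installed and at least one GPU is visible;
# "roberta-base" and the fill-in prompt are illustrative, not taken from this commit.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# `device_map="auto"` dispatches submodules across the available devices;
# modules listed in `_no_split_modules` are kept whole on a single device.
model = AutoModelForMaskedLM.from_pretrained("roberta-base", device_map="auto")

inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

with torch.no_grad():
    logits = model(**inputs).logits

# Decode the most likely token at the masked position.
mask_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
print(tokenizer.decode(logits[0, mask_index].argmax(-1)))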
@@ -728,7 +728,11 @@ class CamembertLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
@@ -752,6 +756,7 @@ class CamembertModel(CamembertPreTrainedModel):
     """
 
     _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _no_split_modules = []
 
     # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert
     def __init__(self, config, add_pooling_layer=True):
...
@@ -584,6 +584,7 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
     config_class = Data2VecTextConfig
     base_model_prefix = "data2vec_text"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     def _init_weights(self, module):
         """Initialize the weights"""
@@ -1147,7 +1148,11 @@ class Data2VecTextLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
...
@@ -588,6 +588,7 @@ class LiltPreTrainedModel(PreTrainedModel):
     config_class = LiltConfig
     base_model_prefix = "lilt"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
...
@@ -1412,7 +1412,11 @@ class LongformerLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 class LongformerPreTrainedModel(PreTrainedModel):
@@ -1425,6 +1429,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
     base_model_prefix = "longformer"
     supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"position_ids"]
+    _no_split_modules = ["LongformerSelfAttention"]
 
     def _init_weights(self, module):
         """Initialize the weights"""
...
@@ -902,6 +902,7 @@ class LukePreTrainedModel(PreTrainedModel):
     config_class = LukeConfig
     base_model_prefix = "luke"
     supports_gradient_checkpointing = True
+    _no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]
 
     def _init_weights(self, module: nn.Module):
         """Initialize the weights"""
@@ -1264,7 +1265,11 @@ class LukeLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
@@ -1746,9 +1751,15 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
         hidden_size = outputs.last_hidden_state.size(-1)
 
         entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
+        if entity_start_positions.device != outputs.last_hidden_state.device:
+            entity_start_positions = entity_start_positions.to(outputs.last_hidden_state.device)
         start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions)
 
         entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
+        if entity_end_positions.device != outputs.last_hidden_state.device:
+            entity_end_positions = entity_end_positions.to(outputs.last_hidden_state.device)
         end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions)
 
         feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2)
         feature_vector = self.dropout(feature_vector)
...
@@ -584,6 +584,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
     config_class = RobertaConfig
     base_model_prefix = "roberta"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
@@ -1146,7 +1147,11 @@ class RobertaLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
...
@@ -586,6 +586,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
     config_class = XLMRobertaConfig
     base_model_prefix = "roberta"
     supports_gradient_checkpointing = True
+    _no_split_modules = []
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
@@ -1155,7 +1156,11 @@ class XLMRobertaLMHead(nn.Module):
     def _tie_weights(self):
         # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
-        self.bias = self.decoder.bias
+        # For accelerate compatibility and to not break backward compatibility
+        if self.decoder.bias.device.type == "meta":
+            self.decoder.bias = self.bias
+        else:
+            self.bias = self.decoder.bias
 
 
 @add_start_docstrings(
...
@@ -2312,11 +2312,11 @@ class ModelTesterMixin:
             if model_class._no_split_modules is None:
                 continue
 
-            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
             model = model_class(config).eval()
             model = model.to(torch_device)
             torch.manual_seed(0)
-            base_output = model(**inputs_dict)
+            base_output = model(**inputs_dict_class)
 
             model_size = compute_module_sizes(model)[""]
             max_size = int(self.model_split_percents[0] * model_size)
@@ -2334,7 +2334,7 @@ class ModelTesterMixin:
                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
                 torch.manual_seed(0)
-                new_output = new_model(**inputs_dict)
+                new_output = new_model(**inputs_dict_class)
 
                 self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2347,12 +2347,12 @@ class ModelTesterMixin:
             if model_class._no_split_modules is None:
                 continue
 
-            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
             model = model_class(config).eval()
             model = model.to(torch_device)
             torch.manual_seed(0)
-            base_output = model(**inputs_dict)
+            base_output = model(**inputs_dict_class)
 
             model_size = compute_module_sizes(model)[""]
             # We test several splits of sizes to make sure it works.
@@ -2369,7 +2369,7 @@ class ModelTesterMixin:
                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
                 torch.manual_seed(0)
-                new_output = new_model(**inputs_dict)
+                new_output = new_model(**inputs_dict_class)
 
                 self.assertTrue(torch.allclose(base_output[0], new_output[0]))
@@ -2382,12 +2382,12 @@ class ModelTesterMixin:
             if model_class._no_split_modules is None:
                 continue
 
-            inputs_dict = self._prepare_for_class(inputs_dict, model_class)
+            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
             model = model_class(config).eval()
             model = model.to(torch_device)
             torch.manual_seed(0)
-            base_output = model(**inputs_dict)
+            base_output = model(**inputs_dict_class)
 
             model_size = compute_module_sizes(model)[""]
             # We test several splits of sizes to make sure it works.
@@ -2404,7 +2404,7 @@ class ModelTesterMixin:
                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
                 torch.manual_seed(0)
-                new_output = new_model(**inputs_dict)
+                new_output = new_model(**inputs_dict_class)
 
                 self.assertTrue(torch.allclose(base_output[0], new_output[0]))
...