Unverified Commit 76296569 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

`accelerate` support for `RoBERTa` family (#19906)

parent 6d023270
......@@ -728,6 +728,10 @@ class CamembertLMHead(nn.Module):
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
......@@ -752,6 +756,7 @@ class CamembertModel(CamembertPreTrainedModel):
"""
_keys_to_ignore_on_load_missing = [r"position_ids"]
_no_split_modules = []
# Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Camembert
def __init__(self, config, add_pooling_layer=True):
......
......@@ -584,6 +584,7 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
config_class = Data2VecTextConfig
base_model_prefix = "data2vec_text"
supports_gradient_checkpointing = True
_no_split_modules = []
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -1147,6 +1148,10 @@ class Data2VecTextLMHead(nn.Module):
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
......
......@@ -588,6 +588,7 @@ class LiltPreTrainedModel(PreTrainedModel):
config_class = LiltConfig
base_model_prefix = "lilt"
supports_gradient_checkpointing = True
_no_split_modules = []
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
......
......@@ -1412,6 +1412,10 @@ class LongformerLMHead(nn.Module):
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
......@@ -1425,6 +1429,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
base_model_prefix = "longformer"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"position_ids"]
_no_split_modules = ["LongformerSelfAttention"]
def _init_weights(self, module):
"""Initialize the weights"""
......
......@@ -902,6 +902,7 @@ class LukePreTrainedModel(PreTrainedModel):
config_class = LukeConfig
base_model_prefix = "luke"
supports_gradient_checkpointing = True
_no_split_modules = ["LukeAttention", "LukeEntityEmbeddings"]
def _init_weights(self, module: nn.Module):
"""Initialize the weights"""
......@@ -1264,6 +1265,10 @@ class LukeLMHead(nn.Module):
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
......@@ -1746,9 +1751,15 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
hidden_size = outputs.last_hidden_state.size(-1)
entity_start_positions = entity_start_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
if entity_start_positions.device != outputs.last_hidden_state.device:
entity_start_positions = entity_start_positions.to(outputs.last_hidden_state.device)
start_states = torch.gather(outputs.last_hidden_state, -2, entity_start_positions)
entity_end_positions = entity_end_positions.unsqueeze(-1).expand(-1, -1, hidden_size)
if entity_end_positions.device != outputs.last_hidden_state.device:
entity_end_positions = entity_end_positions.to(outputs.last_hidden_state.device)
end_states = torch.gather(outputs.last_hidden_state, -2, entity_end_positions)
feature_vector = torch.cat([start_states, end_states, outputs.entity_last_hidden_state], dim=2)
feature_vector = self.dropout(feature_vector)
......
......@@ -584,6 +584,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
config_class = RobertaConfig
base_model_prefix = "roberta"
supports_gradient_checkpointing = True
_no_split_modules = []
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
......@@ -1146,6 +1147,10 @@ class RobertaLMHead(nn.Module):
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
......
......@@ -586,6 +586,7 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
config_class = XLMRobertaConfig
base_model_prefix = "roberta"
supports_gradient_checkpointing = True
_no_split_modules = []
# Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
def _init_weights(self, module):
......@@ -1155,6 +1156,10 @@ class XLMRobertaLMHead(nn.Module):
def _tie_weights(self):
# To tie those two weights if they get disconnected (on TPU or when the bias is resized)
# For accelerate compatibility and to not break backward compatibility
if self.decoder.bias.device.type == "meta":
self.decoder.bias = self.bias
else:
self.bias = self.decoder.bias
......
......@@ -2312,11 +2312,11 @@ class ModelTesterMixin:
if model_class._no_split_modules is None:
continue
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
base_output = model(**inputs_dict_class)
model_size = compute_module_sizes(model)[""]
max_size = int(self.model_split_percents[0] * model_size)
......@@ -2334,7 +2334,7 @@ class ModelTesterMixin:
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
......@@ -2347,12 +2347,12 @@ class ModelTesterMixin:
if model_class._no_split_modules is None:
continue
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
base_output = model(**inputs_dict_class)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
......@@ -2369,7 +2369,7 @@ class ModelTesterMixin:
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
......@@ -2382,12 +2382,12 @@ class ModelTesterMixin:
if model_class._no_split_modules is None:
continue
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config).eval()
model = model.to(torch_device)
torch.manual_seed(0)
base_output = model(**inputs_dict)
base_output = model(**inputs_dict_class)
model_size = compute_module_sizes(model)[""]
# We test several splits of sizes to make sure it works.
......@@ -2404,7 +2404,7 @@ class ModelTesterMixin:
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
torch.manual_seed(0)
new_output = new_model(**inputs_dict)
new_output = new_model(**inputs_dict_class)
self.assertTrue(torch.allclose(base_output[0], new_output[0]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment