Unverified Commit e81d8d7f authored by Patrick von Platen, committed by GitHub

[Bert2Bert] allow bert2bert + relative embeddings (#14324)

* [Bert2Bert] allow bert2bert + relative embeddings

* up

* Update README_ko.md

* up

* up
parent e4d8f517
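
What this change enables, as a hedged usage sketch (the config sizes below are illustrative and not taken from the diff; only `position_embedding_type`, `EncoderDecoderConfig.from_encoder_decoder_configs`, and `EncoderDecoderModel` come from the patch itself):

```python
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

# Two small BERT configs that use relative position embeddings instead of absolute ones.
encoder_config = BertConfig(
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    position_embedding_type="relative_key_query",
)
decoder_config = BertConfig(
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    position_embedding_type="relative_key_query",
    is_decoder=True,
    add_cross_attention=True,
)

# Tie the two configs together into a bert2bert encoder-decoder model.
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = EncoderDecoderModel(config)
```

The diff below makes this combination work by threading an optional `position_embedding_type` argument through the attention constructors: self-attention keeps the relative variant from the config, while cross-attention is explicitly constructed with `position_embedding_type="absolute"`.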
modeling_bert.py

@@ -224,7 +224,7 @@ class BertEmbeddings(nn.Module):
 class BertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -241,7 +241,9 @@ class BertSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -363,9 +365,9 @@ class BertSelfOutput(nn.Module):
 class BertAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = BertSelfAttention(config)
+        self.self = BertSelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = BertSelfOutput(config)
         self.pruned_heads = set()
@@ -451,7 +453,7 @@ class BertLayer(nn.Module):
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = BertAttention(config)
+            self.crossattention = BertAttention(config, position_embedding_type="absolute")
         self.intermediate = BertIntermediate(config)
         self.output = BertOutput(config)
...
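
For context on the fields the constructor sets up (`max_position_embeddings`, `distance_embedding`): when a relative variant is selected, the self-attention forward pass adds a content-position term to the attention logits. Below is a minimal, self-contained sketch of that term with toy tensors; the names, shapes, and omitted scaling are chosen for illustration and only loosely mirror `BertSelfAttention.forward`:

```python
import torch
import torch.nn as nn

# Toy dimensions, for illustration only.
batch, heads, seq_len, head_dim, max_pos = 2, 4, 7, 16, 512

query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
distance_embedding = nn.Embedding(2 * max_pos - 1, head_dim)

# Standard content-content term of dot-product attention (scaling omitted here).
attention_scores = torch.matmul(query, key.transpose(-1, -2))

# Signed distance between every query position l and key position r, shifted into
# [0, 2 * max_pos - 2] so it can index the (2 * max_pos - 1)-row embedding table.
position_ids = torch.arange(seq_len)
distance = position_ids.view(-1, 1) - position_ids.view(1, -1)
positional_embedding = distance_embedding(distance + max_pos - 1)  # (seq_len, seq_len, head_dim)

# "relative_key": add a query . embedding(distance) term to the logits.
relative_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
scores_relative_key = attention_scores + relative_scores_query

# "relative_key_query": additionally add a key . embedding(distance) term.
relative_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
scores_relative_key_query = scores_relative_key + relative_scores_key

print(scores_relative_key.shape, scores_relative_key_query.shape)  # both (batch, heads, seq_len, seq_len)
```

Because this distance term assumes queries and keys index positions in the same sequence, which does not hold when a decoder attends over encoder states, this is presumably why the cross-attention block above is pinned to `position_embedding_type="absolute"` instead of inheriting the relative setting from the config.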
modeling_electra.py

@@ -216,7 +216,7 @@ class ElectraEmbeddings(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Electra
 class ElectraSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -233,7 +233,9 @@ class ElectraSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -357,9 +359,9 @@ class ElectraSelfOutput(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Electra
 class ElectraAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = ElectraSelfAttention(config)
+        self.self = ElectraSelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = ElectraSelfOutput(config)
         self.pruned_heads = set()
@@ -448,7 +450,7 @@ class ElectraLayer(nn.Module):
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = ElectraAttention(config)
+            self.crossattention = ElectraAttention(config, position_embedding_type="absolute")
         self.intermediate = ElectraIntermediate(config)
         self.output = ElectraOutput(config)
...
modeling_layoutlm.py

@@ -132,7 +132,7 @@ class LayoutLMEmbeddings(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->LayoutLM
 class LayoutLMSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -149,7 +149,9 @@ class LayoutLMSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -273,9 +275,9 @@ class LayoutLMSelfOutput(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->LayoutLM
 class LayoutLMAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = LayoutLMSelfAttention(config)
+        self.self = LayoutLMSelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = LayoutLMSelfOutput(config)
         self.pruned_heads = set()
@@ -364,7 +366,7 @@ class LayoutLMLayer(nn.Module):
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = LayoutLMAttention(config)
+            self.crossattention = LayoutLMAttention(config, position_embedding_type="absolute")
         self.intermediate = LayoutLMIntermediate(config)
         self.output = LayoutLMOutput(config)
...
modeling_megatron_bert.py

@@ -195,7 +195,7 @@ class MegatronBertEmbeddings(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MegatronBert
 class MegatronBertSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -212,7 +212,9 @@ class MegatronBertSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
...
modeling_rembert.py

@@ -328,7 +328,6 @@ class RemBertSelfOutput(nn.Module):
         return hidden_states
-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->RemBert
 class RemBertAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -336,6 +335,7 @@ class RemBertAttention(nn.Module):
         self.output = RemBertSelfOutput(config)
         self.pruned_heads = set()
+    # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
@@ -354,6 +354,7 @@ class RemBertAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
+    # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
     def forward(
         self,
         hidden_states,
@@ -409,7 +410,6 @@ class RemBertOutput(nn.Module):
         return hidden_states
-# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->RemBert
 class RemBertLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -425,6 +425,7 @@ class RemBertLayer(nn.Module):
         self.intermediate = RemBertIntermediate(config)
         self.output = RemBertOutput(config)
+    # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
     def forward(
         self,
         hidden_states,
@@ -489,6 +490,7 @@ class RemBertLayer(nn.Module):
         return outputs
+    # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
     def feed_forward_chunk(self, attention_output):
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
...
modeling_roberta.py

@@ -159,7 +159,7 @@ class RobertaEmbeddings(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
 class RobertaSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -176,7 +176,9 @@ class RobertaSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -300,9 +302,9 @@ class RobertaSelfOutput(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
 class RobertaAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = RobertaSelfAttention(config)
+        self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = RobertaSelfOutput(config)
         self.pruned_heads = set()
@@ -391,7 +393,7 @@ class RobertaLayer(nn.Module):
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = RobertaAttention(config)
+            self.crossattention = RobertaAttention(config, position_embedding_type="absolute")
         self.intermediate = RobertaIntermediate(config)
         self.output = RobertaOutput(config)
...
modeling_roformer.py

@@ -367,14 +367,12 @@ class RoFormerSelfOutput(nn.Module):
 class RoFormerAttention(nn.Module):
-    # Copied from transformers.models.bert.modeling_bert.BertAttention.__init__ with Bert->RoFormer
     def __init__(self, config):
         super().__init__()
         self.self = RoFormerSelfAttention(config)
         self.output = RoFormerSelfOutput(config)
         self.pruned_heads = set()
-    # End Copy
     # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
     def prune_heads(self, heads):
         if len(heads) == 0:
@@ -453,7 +451,6 @@ class RoFormerOutput(nn.Module):
 class RoFormerLayer(nn.Module):
-    # Copied from transformers.models.bert.modeling_bert.BertLayer.__init__ with Bert->RoFormer
     def __init__(self, config):
         super().__init__()
         self.chunk_size_feed_forward = config.chunk_size_feed_forward
@@ -468,7 +465,6 @@ class RoFormerLayer(nn.Module):
         self.intermediate = RoFormerIntermediate(config)
         self.output = RoFormerOutput(config)
-    # End Copy
     def forward(
         self,
         hidden_states,
...
modeling_splinter.py

@@ -99,7 +99,7 @@ class SplinterEmbeddings(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Splinter
 class SplinterSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -116,7 +116,9 @@ class SplinterSelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(
+            config, "position_embedding_type", "absolute"
+        )
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -240,9 +242,9 @@ class SplinterSelfOutput(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Splinter
 class SplinterAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = SplinterSelfAttention(config)
+        self.self = SplinterSelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = SplinterSelfOutput(config)
         self.pruned_heads = set()
@@ -331,7 +333,7 @@ class SplinterLayer(nn.Module):
         if self.add_cross_attention:
             if not self.is_decoder:
                 raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
-            self.crossattention = SplinterAttention(config)
+            self.crossattention = SplinterAttention(config, position_embedding_type="absolute")
         self.intermediate = SplinterIntermediate(config)
         self.output = SplinterOutput(config)
...
modeling_tapas.py

@@ -456,7 +456,6 @@ class TapasSelfOutput(nn.Module):
         return hidden_states
-# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Tapas
 class TapasAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -464,6 +463,7 @@ class TapasAttention(nn.Module):
         self.output = TapasSelfOutput(config)
         self.pruned_heads = set()
+    # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
     def prune_heads(self, heads):
         if len(heads) == 0:
             return
@@ -482,6 +482,7 @@ class TapasAttention(nn.Module):
         self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
+    # Copied from transformers.models.bert.modeling_bert.BertAttention.forward
     def forward(
         self,
         hidden_states,
@@ -537,7 +538,6 @@ class TapasOutput(nn.Module):
         return hidden_states
-# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Tapas
 class TapasLayer(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -553,6 +553,7 @@ class TapasLayer(nn.Module):
         self.intermediate = TapasIntermediate(config)
         self.output = TapasOutput(config)
+    # Copied from transformers.models.bert.modeling_bert.BertLayer.forward
     def forward(
         self,
         hidden_states,
@@ -617,6 +618,7 @@ class TapasLayer(nn.Module):
         return outputs
+    # Copied from transformers.models.bert.modeling_bert.BertLayer.feed_forward_chunk
     def feed_forward_chunk(self, attention_output):
         intermediate_output = self.intermediate(attention_output)
         layer_output = self.output(intermediate_output, attention_output)
...
modeling_{{cookiecutter.lowercase_modelname}}.py (new-model template)

@@ -203,7 +203,7 @@ class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->{{cookiecutter.camelcase_modelname}}
 class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
@@ -220,7 +220,7 @@ class {{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
         self.value = nn.Linear(config.hidden_size, self.all_head_size)
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+        self.position_embedding_type = position_embedding_type or getattr(config, "position_embedding_type", "absolute")
         if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
@@ -344,9 +344,9 @@ class {{cookiecutter.camelcase_modelname}}SelfOutput(nn.Module):
 # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->{{cookiecutter.camelcase_modelname}}
 class {{cookiecutter.camelcase_modelname}}Attention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, position_embedding_type=None):
         super().__init__()
-        self.self = {{cookiecutter.camelcase_modelname}}SelfAttention(config)
+        self.self = {{cookiecutter.camelcase_modelname}}SelfAttention(config, position_embedding_type=position_embedding_type)
         self.output = {{cookiecutter.camelcase_modelname}}SelfOutput(config)
         self.pruned_heads = set()
@@ -434,7 +434,7 @@ class {{cookiecutter.camelcase_modelname}}Layer(nn.Module):
         self.add_cross_attention = config.add_cross_attention
         if self.add_cross_attention:
             assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
-            self.crossattention = {{cookiecutter.camelcase_modelname}}Attention(config)
+            self.crossattention = {{cookiecutter.camelcase_modelname}}Attention(config, position_embedding_type="absolute")
         self.intermediate = {{cookiecutter.camelcase_modelname}}Intermediate(config)
         self.output = {{cookiecutter.camelcase_modelname}}Output(config)
...
test_modeling_encoder_decoder.py

@@ -567,6 +567,24 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
             "labels": decoder_token_labels,
         }
+
+    def test_relative_position_embeds(self):
+        config_and_inputs = self.prepare_config_and_inputs()
+
+        encoder_config = config_and_inputs["config"]
+        decoder_config = config_and_inputs["decoder_config"]
+
+        encoder_config.position_embedding_type = "relative_key_query"
+        decoder_config.position_embedding_type = "relative_key_query"
+
+        config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
+        model = EncoderDecoderModel(config).eval().to(torch_device)
+
+        logits = model(
+            input_ids=config_and_inputs["input_ids"], decoder_input_ids=config_and_inputs["decoder_input_ids"]
+        ).logits
+
+        self.assertTrue(logits.shape, (13, 7))
+
     @slow
     def test_bert2bert_summarization(self):
         model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
...
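
For a quick manual check of the same code path outside the test harness, one could reuse the toy `model` from the sketch after the commit message; everything below is illustrative and assumes that earlier block has been run:

```python
import torch

# Assumes `model`, `encoder_config`, and `decoder_config` from the earlier sketch.
input_ids = torch.randint(0, encoder_config.vocab_size, (2, 7))
decoder_input_ids = torch.randint(0, decoder_config.vocab_size, (2, 5))

with torch.no_grad():
    logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

# (batch, decoder sequence length, decoder vocab size), e.g. torch.Size([2, 5, 30522])
# with the default BERT vocabulary size.
print(logits.shape)
```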