Commit 88e5bef5 authored by thomwolf

share position biases

parent 568c0ffb
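
For context, the change can be summed up as follows: T5's relative position bias is a small learned table (an nn.Embedding over relative-position buckets, one scalar per head) whose output is added directly to the attention scores. After this commit only the first block of a stack owns that table; the bias tensor it computes is handed to every later block instead of being recomputed per layer. Below is a minimal, self-contained sketch of the mechanism with illustrative names and a simplified clamp-style bucketing (T5 itself uses log-spaced buckets); it is not code from this commit.

```python
import torch
import torch.nn as nn


class TinyRelativeBias(nn.Module):
    """Illustrative stand-in for the shared relative attention bias (not the commit's code)."""

    def __init__(self, num_buckets=32, n_heads=8):
        super().__init__()
        self.num_buckets = num_buckets
        self.bias = nn.Embedding(num_buckets, n_heads)  # one learned scalar per (bucket, head)

    def forward(self, qlen, klen):
        # Relative position of every key w.r.t. every query, folded into the bucket range.
        rel = torch.arange(klen)[None, :] - torch.arange(qlen)[:, None]   # (qlen, klen)
        half = self.num_buckets // 2
        buckets = rel.clamp(-half, half - 1) + half                       # values in [0, num_buckets)
        values = self.bias(buckets)                                       # (qlen, klen, n_heads)
        return values.permute(2, 0, 1).unsqueeze(0)                       # (1, n_heads, qlen, klen)
```

The returned tensor has the same (1, n_heads, qlen, klen) shape as the attention scores, so once the first block has produced it, every later block can simply add the same tensor, which is what the T5Stack changes below arrange.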
@@ -154,9 +154,10 @@ class T5LayerFF(nn.Module):
 class T5Attention(nn.Module):
     NEW_ID = itertools.count()

-    def __init__(self, config):
+    def __init__(self, config, has_relative_attention_bias=False):
         super(T5Attention, self).__init__()
         self.layer_id = next(T5Attention.NEW_ID)
+        self.has_relative_attention_bias = has_relative_attention_bias
         self.output_attentions = config.output_attentions
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
@@ -170,6 +171,7 @@ class T5Attention(nn.Module):
         self.v = nn.Linear(self.dim, self.dim, bias=False)
         self.o = nn.Linear(self.dim, self.dim, bias=False)
-        self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
         self.pruned_heads = set()
@@ -304,6 +306,8 @@ class T5Attention(nn.Module):
         scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
         if position_bias is None:
+            if not self.has_relative_attention_bias:
+                raise ValueError("No position_bias provided and no weights to compute position_bias")
             position_bias = self.compute_bias(qlen, klen)
         scores += position_bias
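
With the guard above, an attention module built with has_relative_attention_bias=False can no longer compute its own bias; it must be handed one. A hypothetical, pseudocode-style usage of the new flag (not a test from the repo; it assumes a T5 config object, input tensors and the call signature shown in the surrounding file):

```python
# the first layer owns the bias table, later layers only consume a precomputed bias
attn_first = T5Attention(config, has_relative_attention_bias=True)
attn_later = T5Attention(config, has_relative_attention_bias=False)

out_first = attn_first(hidden_states, attention_mask=mask)
position_bias = out_first[-1]   # appended last when the layer owns the bias table

out_later = attn_later(hidden_states, attention_mask=mask, position_bias=position_bias)
# calling attn_later without position_bias would raise the ValueError added above
```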
@@ -325,20 +329,23 @@ class T5Attention(nn.Module):
         outputs = (context,)
         if self.output_attentions:
             outputs = outputs + (weights,)
+        if self.has_relative_attention_bias:
+            outputs = outputs + (position_bias,)
         return outputs


 class T5LayerSelfAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, has_relative_attention_bias=False):
         super(T5LayerSelfAttention, self).__init__()
-        self.SelfAttention = T5Attention(config)
+        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
         self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout)

-    def forward(self, hidden_states, attention_mask=None, head_mask=None):
+    def forward(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None):
         norm_x = self.layer_norm(hidden_states)
         attention_output = self.SelfAttention(norm_x,
                                               attention_mask=attention_mask,
+                                              position_bias=position_bias,
                                               head_mask=head_mask)
         y = attention_output[0]
         layer_output = hidden_states + self.dropout(y)
@@ -347,17 +354,18 @@ class T5LayerSelfAttention(nn.Module):
 class T5LayerCrossAttention(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, has_relative_attention_bias=False):
         super(T5LayerCrossAttention, self).__init__()
-        self.EncDecAttention = T5Attention(config)
+        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
         self.layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout)

-    def forward(self, hidden_states, kv, attention_mask=None, head_mask=None):
+    def forward(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None):
         norm_x = self.layer_norm(hidden_states)
         attention_output = self.EncDecAttention(norm_x,
                                                 kv=kv,
                                                 attention_mask=attention_mask,
+                                                position_bias=position_bias,
                                                 head_mask=head_mask)
         y = attention_output[0]
         layer_output = hidden_states + self.dropout(y)
@@ -366,20 +374,22 @@ class T5LayerCrossAttention(nn.Module):
 class T5Block(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, has_relative_attention_bias=False):
         super(T5Block, self).__init__()
         self.is_decoder = config.is_decoder
-        self.layer_000 = T5LayerSelfAttention(config)
+        self.layer_000 = T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)
         if self.is_decoder:
-            self.layer_001 = T5LayerCrossAttention(config)
+            self.layer_001 = T5LayerCrossAttention(config, has_relative_attention_bias=has_relative_attention_bias)
             self.layer_002 = T5LayerFF(config)
         else:
             self.layer_001 = T5LayerFF(config)

-    def forward(self, hidden_states, attention_mask=None,
-                encoder_hidden_states=None, encoder_attention_mask=None, head_mask=None):
+    def forward(self, hidden_states, attention_mask=None, position_bias=None,
+                encoder_hidden_states=None, encoder_attention_mask=None, encoder_decoder_position_bias=None,
+                head_mask=None):
         self_attention_outputs = self.layer_000(hidden_states,
                                                 attention_mask=attention_mask,
+                                                position_bias=position_bias,
                                                 head_mask=head_mask)
         hidden_states = self_attention_outputs[0]
         outputs = self_attention_outputs[1:]
@@ -388,6 +398,7 @@ class T5Block(nn.Module):
             cross_attention_outputs = self.layer_001(hidden_states,
                                                      kv=encoder_hidden_states,
                                                      attention_mask=encoder_attention_mask,
+                                                     position_bias=encoder_decoder_position_bias,
                                                      head_mask=head_mask)
             hidden_states = cross_attention_outputs[0]
             outputs = cross_attention_outputs[1:] + outputs
@@ -402,7 +413,8 @@ class T5Block(nn.Module):
 class T5Stack(nn.Module):
     def __init__(self, config):
         super(T5Stack, self).__init__()
-        self.blocks = nn.ModuleList([T5Block(config) for _ in range(config.num_layers)])
+        self.blocks = nn.ModuleList([T5Block(config, has_relative_attention_bias=bool(i == 0))
+                                     for i in range(config.num_layers)])
         self.final_layer_norm = nn.LayerNorm(config.layer_norm_epsilon)
         self.dropout = nn.Dropout(config.dropout)
@@ -413,8 +425,12 @@ class T5Stack(nn.Module):
                 encoder_attention_mask=None,
                 head_mask=None):
+        batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
+        encoder_seq_length = encoder_hidden_states.shape[1] if encoder_hidden_states is not None else 0
         if attention_mask is None:
-            attention_mask = torch.ones_like(input_ids)
+            attention_mask = torch.ones(batch_size, seq_length).to(hidden_states.device)
+        if encoder_attention_mask is None:
+            encoder_attention_mask = torch.ones(batch_size, encoder_seq_length).to(hidden_states.device)

         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
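
Note the reason for the new default mask: the stack's forward now works from already-embedded hidden_states rather than input_ids, so the mask has to be built from the hidden states' shape and device. The comment above refers to the usual broadcasting trick for padding masks; a small self-contained sketch of that standard pattern (shown for illustration, not copied from this commit):

```python
import torch

bs, n_heads, qlen, klen = 2, 8, 5, 7
scores = torch.zeros(bs, n_heads, qlen, klen)   # stand-in attention scores
attention_mask = torch.ones(bs, klen)           # 1 = real token, 0 = padding
attention_mask[:, -2:] = 0                      # pretend the last two positions are padding

# make the 2D padding mask broadcastable over heads and query positions,
# then convert it to an additive mask that the softmax effectively zeroes out
extended_attention_mask = attention_mask[:, None, None, :]          # (bs, 1, 1, klen)
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9    # 0 keep, -1e9 drop
scores = scores + extended_attention_mask                           # (bs, n_heads, qlen, klen)
```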
@@ -426,8 +442,7 @@ class T5Stack(nn.Module):
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
         if attention_mask.dim() == 2:
             if self.config.is_decoder:
-                batch_size, seq_length = input_ids.size()
-                seq_ids = torch.arange(seq_length, device=input_ids.device)
+                seq_ids = torch.arange(seq_length, device=hidden_states.device)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
                 extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
             else:
@@ -469,16 +484,22 @@ class T5Stack(nn.Module):
         all_hidden_states = ()
         all_attentions = ()
         position_bias = None
+        encoder_decoder_position_bias = None
         for i, layer_module in enumerate(self.layer):
             if self.output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
             layer_outputs = layer_module(hidden_states,
                                          attention_mask=extended_attention_mask,
+                                         position_bias=position_bias,
                                          encoder_hidden_states=encoder_hidden_states,
                                          encoder_attention_mask=encoder_extended_attention_mask,
+                                         encoder_decoder_position_bias=encoder_decoder_position_bias,
                                          head_mask=head_mask[i])
             hidden_states = layer_outputs[0]
+            if i == 0:
+                position_bias = layer_outputs[2] if len(layer_outputs) > 3 else None
+                encoder_decoder_position_bias = layer_outputs[4] if len(layer_outputs) > 5 else None
             if self.output_attentions:
                 all_attentions = all_attentions + (layer_outputs[1],)
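
Reduced to its core, the loop change above implements the following sharing pattern. This is a schematic only: the class, names, and return values are made up for illustration, and the real code also threads masks, head masks and the separate encoder-decoder bias.

```python
import torch
import torch.nn as nn


class StubBlock(nn.Module):
    """Toy block: computes a 'bias' only if it owns the table, otherwise expects one."""

    def __init__(self, owns_bias):
        super().__init__()
        self.owns_bias = owns_bias
        self.bias_table = nn.Parameter(torch.zeros(1)) if owns_bias else None

    def forward(self, hidden_states, position_bias=None):
        if position_bias is None:
            assert self.owns_bias, "later blocks must be given the shared bias"
            position_bias = self.bias_table  # stand-in for the (1, n_heads, qlen, klen) tensor
        return hidden_states + position_bias, position_bias


blocks = nn.ModuleList([StubBlock(owns_bias=(i == 0)) for i in range(4)])
hidden_states = torch.zeros(2, 5, 8)

position_bias = None
for i, block in enumerate(blocks):
    hidden_states, bias = block(hidden_states, position_bias=position_bias)
    if i == 0:
        position_bias = bias  # computed once by the first block, reused by all others
```

The effect is a single set of relative-bias parameters per stack and a single bias computation per forward pass, instead of one per layer.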
@@ -700,14 +721,8 @@ class T5WithLMHead(T5PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head

-    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
-                lm_labels=None):
-        outputs = self.transformer(input_ids,
-                                   attention_mask=attention_mask,
-                                   token_type_ids=token_type_ids,
-                                   position_ids=position_ids,
-                                   head_mask=head_mask)
+    def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
+        outputs = self.transformer(encoder_input_ids, decoder_input_ids, **kwargs)
         sequence_output = outputs[0]
         lm_logits = self.cls(sequence_output)