Commit dc894411 authored by thomwolf

update CTRL pytorch model

parent 320b7a7e
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch CTRL model."""
""" PyTorch CTRL model."""
from __future__ import absolute_import, division, print_function, unicode_literals
@@ -27,7 +27,6 @@ from io import open
import numpy as np
import torch
import torch.nn as nn
import pdb
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter
@@ -46,7 +45,9 @@ def angle_defn(pos, i, d_model_size):
def positional_encoding(position, d_model_size, dtype):
# create the sinusoidal pattern for the positional encoding
angle_rads = (angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1), torch.arange(d_model_size, dtype=dtype).unsqueeze(0), d_model_size))
angle_rads = (angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1),
torch.arange(d_model_size, dtype=dtype).unsqueeze(0),
d_model_size))
sines = torch.sin(angle_rads[:, 0::2])
cosines = torch.cos(angle_rads[:, 1::2])
@@ -54,7 +55,7 @@ def positional_encoding(position, d_model_size, dtype):
pos_encoding = torch.cat([sines, cosines], dim=-1).unsqueeze(0)
return pos_encoding
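A minimal, hypothetical sketch of exercising this helper (illustrative only, not part of the commit); it assumes the angle_defn and positional_encoding definitions above:

import torch

pos_enc = positional_encoding(10, 16, torch.float)  # 10 positions, model size 16
print(pos_enc.shape)  # torch.Size([1, 10, 16])
# the first 8 features per position are sines of the even angle columns,
# the last 8 are cosines of the odd angle columns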
def scaled_dot_product_attention(q, k, v, mask):
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
# calculate attention
matmul_qk = torch.matmul(q, k.permute(0,1,3,2))
@@ -64,15 +65,25 @@ def scaled_dot_product_attention(q, k, v, mask):
if mask is not None:
scaled_attention_logits += (mask * -1e4)
if attention_mask is not None:
# Apply the attention mask
scaled_attention_logits = scaled_attention_logits + attention_mask
attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
# Mask heads if we want to
if head_mask is not None:
attention_weights = attention_weights * head_mask
output = torch.matmul(attention_weights, v)
return output, attention_weights
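A small, hypothetical sketch of how the causal mask, the additive attention_mask and the multiplicative head_mask interact (illustrative only; it assumes scaled_dot_product_attention as defined above):

import torch

q = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, depth)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
causal = torch.triu(torch.ones(4, 4), diagonal=1)    # 1.0 above the diagonal marks future positions
padding = torch.tensor([[[[0., 0., 0., -10000.]]]])  # additive mask hiding the last (padded) position
head_mask = torch.tensor([1., 0.]).view(1, 2, 1, 1)  # multiplicative mask dropping the second head

out, weights = scaled_dot_product_attention(q, k, v, causal, attention_mask=padding, head_mask=head_mask)
print(weights[0, 0].sum(-1))      # each row of the kept head still sums to 1
print(weights[0, 1].abs().max())  # the dropped head's weights are all zero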
class MultiHeadAttention(torch.nn.Module):
def __init__(self, d_model_size, num_heads):
def __init__(self, d_model_size, num_heads, output_attentions=False):
super(MultiHeadAttention, self).__init__()
self.output_attentions = output_attentions
self.num_heads = num_heads
self.d_model_size = d_model_size
@@ -88,7 +99,7 @@ class MultiHeadAttention(torch.nn.Module):
x = x.reshape(batch_size, -1, self.num_heads, self.depth)
return x.permute([0, 2, 1, 3])
def forward(self, v, k, q, mask):
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None):
batch_size = q.shape[0]
q = self.Wq(q)
@@ -98,7 +109,13 @@ class MultiHeadAttention(torch.nn.Module):
q = self.split_into_heads(q, batch_size)
k = self.split_into_heads(k, batch_size)
v = self.split_into_heads(v, batch_size)
output = scaled_dot_product_attention(q, k, v, mask)
if layer_past is not None:
    past_key, past_value = layer_past[0], layer_past[1]
    k = torch.cat((past_key, k), dim=-2)  # append the new keys along the sequence dimension
    v = torch.cat((past_value, v), dim=-2)  # append the new values along the sequence dimension
present = torch.stack((k, v))  # cache (key, value) so the next step can pass them back in as layer_past
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
scaled_attention = output[0].permute([0, 2, 1, 3])
attn = output[1]
original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
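A shape-level, hypothetical sketch of the key/value caching above (illustrative only; the tensor sizes are made up):

import torch

past_k = torch.randn(1, 2, 3, 8)  # (batch, heads, past_seq, depth): 3 tokens already processed
past_v = torch.randn(1, 2, 3, 8)
layer_past = torch.stack((past_k, past_v))

k_new = torch.randn(1, 2, 1, 8)   # one new token arrives
v_new = torch.randn(1, 2, 1, 8)

k = torch.cat((layer_past[0], k_new), dim=-2)  # keys now cover 4 positions
v = torch.cat((layer_past[1], v_new), dim=-2)
present = torch.stack((k, v))                  # handed back to the caller for the next step
print(present.shape)                           # torch.Size([2, 1, 2, 4, 8])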
@@ -109,14 +126,16 @@ class MultiHeadAttention(torch.nn.Module):
def point_wise_feed_forward_network(d_model_size, dff):
return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff), torch.nn.ReLU(), torch.nn.Linear(dff, d_model_size))
return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff),
torch.nn.ReLU(),
torch.nn.Linear(dff, d_model_size))
class EncoderLayer(torch.nn.Module):
def __init__(self, d_model_size, num_heads, dff, rate=0.1):
def __init__(self, d_model_size, num_heads, dff, rate=0.1, output_attentions=False):
super(EncoderLayer, self).__init__()
self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads)
self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, output_attentions)
self.ffn = point_wise_feed_forward_network(d_model_size, dff)
self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
@@ -125,9 +144,12 @@ class EncoderLayer(torch.nn.Module):
self.dropout1 = torch.nn.Dropout(rate)
self.dropout2 = torch.nn.Dropout(rate)
def forward(self, x, mask):
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
normed = self.layernorm1(x)
attn_output, attn = self.multi_head_attention(normed, normed, normed, mask)
attn_output, attn = self.multi_head_attention(normed, normed, normed, mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask)
attn_output = self.dropout1(attn_output)
out1 = x + attn_output
@@ -147,9 +169,6 @@ class CTRLPreTrainedModel(PreTrainedModel):
pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "transformer"
def __init__(self, *inputs, **kwargs):
super(CTRLPreTrainedModel, self).__init__(*inputs, **kwargs)
def _init_weights(self, module):
""" Initialize the weights.
"""
@@ -256,7 +275,11 @@ class CTRLModel(CTRLPreTrainedModel):
self.dropout = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop) for _ in range(config.n_layer)])
self.h = nn.ModuleList([EncoderLayer(config.n_embd,
config.n_head,
config.dff,
config.resid_pdrop,
config.output_attentions) for _ in range(config.n_layer)])
self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.init_weights()
@@ -272,8 +295,54 @@ class CTRLModel(CTRLPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
labels=None):
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past is None:
past_length = 0
past = [None] * len(self.h)
else:
past_length = past[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# Attention mask.
if attention_mask is not None:
attention_mask = attention_mask.view(-1, input_shape[-1])
# We create a 4D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is simpler than the triangular masking of causal attention
# used in OpenAI GPT; we just need to prepare the broadcast dimension here.
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
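# For example (illustrative only): attention_mask = [[1., 1., 0.]] becomes
# [[[[0., 0., -10000.]]]] here, so the padded position contributes essentially
# nothing to the softmax inside scaled_dot_product_attention.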
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
else:
head_mask = [None] * self.config.n_layer
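# For example (illustrative only): with 2 heads, head_mask = torch.tensor([1., 0.])
# is expanded to shape (n_layer, 1, n_heads, 1, 1); multiplying the attention
# weights by it zeroes out the second head in every layer.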
embedded = self.w(input_ids)
x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
@@ -282,26 +351,40 @@ class CTRLModel(CTRLPreTrainedModel):
x *= np.sqrt(self.d_model_size)
x += self.pos_encoding[:, :seq_len, :].to(x.device)
x += self.pos_encoding[0, position_ids, :].to(x.device)
x = self.dropout(x)
output_shape = input_shape + (x.size(-1),)
presents = ()
all_hidden_states = ()
all_attentions = []
for i in range(self.num_layers):
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (x,)
x, attn = self.h[i](x, mask)
all_hidden_states = all_hidden_states + (x.view(*output_shape),)
outputs = h(x,
mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i])
x, present = outputs[:2]
presents = presents + (present,)
if self.output_attentions:
all_attentions.append(attn)
all_attentions.append(outputs[2])
x = self.layernorm(x)
x = x.view(*output_shape)
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (x,)
outputs = (x, None)
outputs = (x, presents)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
# leave the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
outputs = outputs + (all_attentions,)
return outputs
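A hypothetical usage sketch of the new forward signature (illustrative only; the tiny config values are made up, and it assumes CTRLConfig plus the parts of CTRLModel elided from this diff behave as in the released implementation):

import torch

config = CTRLConfig(n_embd=64, n_head=4, dff=128, n_layer=2)  # small, made-up sizes
model = CTRLModel(config)
model.eval()

input_ids = torch.tensor([[31, 51, 99, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last token is padding

with torch.no_grad():
    hidden_states, presents = model(input_ids, attention_mask=attention_mask)[:2]

print(hidden_states.shape)  # torch.Size([1, 4, 64])
print(len(presents))        # one cached (key, value) stack per layer

The presents tuple is meant to be fed back through the past argument on a later call so the cached keys and values are reused instead of being recomputed.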
@@ -359,13 +442,17 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self._tie_or_clone_weights(self.lm_head,
self.transformer.w)
#self._tie_or_clone_weights(self.lm_head.bias,
# self.transformer.w.bias)
self._tie_or_clone_weights(self.lm_head, self.transformer.w)
def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
labels=None):
transformer_outputs = self.transformer(input_ids)
transformer_outputs = self.transformer(input_ids,
past=past,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
@@ -383,5 +470,3 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
outputs = (loss,) + outputs
return outputs # (loss), lm_logits, presents, (all hidden_states), (attentions)
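A hypothetical sketch of the labels path in CTRLLMHeadModel (illustrative only; it assumes CTRLConfig and that the loss code elided from this diff follows the usual shifted next-token cross-entropy):

import torch

config = CTRLConfig(n_embd=64, n_head=4, dff=128, n_layer=2)  # small, made-up sizes
lm_model = CTRLLMHeadModel(config)
lm_model.eval()

input_ids = torch.tensor([[31, 51, 99, 12]])
outputs = lm_model(input_ids, labels=input_ids)  # reuse the inputs as next-token targets
loss, lm_logits = outputs[0], outputs[1]
print(loss.item())       # scalar language-modeling loss
print(lm_logits.shape)   # torch.Size([1, 4, vocab_size])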