Commit dc894411 authored by thomwolf

update CTRL pytorch model

parent 320b7a7e
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch CTRL model."""
+""" PyTorch CTRL model."""
 from __future__ import absolute_import, division, print_function, unicode_literals
@@ -27,7 +27,6 @@ from io import open
 import numpy as np
 import torch
 import torch.nn as nn
-import pdb
 from torch.nn import CrossEntropyLoss
 from torch.nn.parameter import Parameter
@@ -46,7 +45,9 @@ def angle_defn(pos, i, d_model_size):
 def positional_encoding(position, d_model_size, dtype):
     # create the sinusoidal pattern for the positional encoding
-    angle_rads = (angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1), torch.arange(d_model_size, dtype=dtype).unsqueeze(0), d_model_size))
+    angle_rads = (angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1),
+                             torch.arange(d_model_size, dtype=dtype).unsqueeze(0),
+                             d_model_size))
 
     sines = torch.sin(angle_rads[:, 0::2])
     cosines = torch.cos(angle_rads[:, 1::2])
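The reflowed call above builds the angle table; the next hunk turns it into CTRL's positional encoding by taking sines of the even columns and cosines of the odd columns and concatenating them along the feature axis (rather than interleaving them as in the original Transformer). A self-contained sketch of the full computation follows; the body of `angle_defn` is outside this diff, so the usual `pos / 10000**(2*(i//2)/d_model_size)` rate is assumed here:

```python
# Standalone sketch of the positional-encoding construction above.
# Assumes angle_defn(pos, i, d) = pos / 10000**(2 * (i // 2) / d), which is the
# usual CTRL/Transformer formula; the function body is not part of this diff.
import torch

def angle_defn(pos, i, d_model_size):
    angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / d_model_size)
    return pos * angle_rates

def positional_encoding(position, d_model_size, dtype):
    angle_rads = angle_defn(torch.arange(position, dtype=dtype).unsqueeze(1),
                            torch.arange(d_model_size, dtype=dtype).unsqueeze(0),
                            d_model_size)
    sines = torch.sin(angle_rads[:, 0::2])
    cosines = torch.cos(angle_rads[:, 1::2])
    # sines and cosines are concatenated, not interleaved as in
    # "Attention Is All You Need"
    return torch.cat([sines, cosines], dim=-1).unsqueeze(0)

pe = positional_encoding(512, 1280, torch.float)  # CTRL's hidden size is 1280
print(pe.shape)  # torch.Size([1, 512, 1280])
```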
@@ -54,7 +55,7 @@ def positional_encoding(position, d_model_size, dtype):
     pos_encoding = torch.cat([sines, cosines], dim=-1).unsqueeze(0)
     return pos_encoding
 
-def scaled_dot_product_attention(q, k, v, mask):
+def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
     # calculate attention
     matmul_qk = torch.matmul(q, k.permute(0,1,3,2))
@@ -64,15 +65,25 @@ def scaled_dot_product_attention(q, k, v, mask):
     if mask is not None:
         scaled_attention_logits += (mask * -1e4)
 
+    if attention_mask is not None:
+        # Apply the attention mask
+        scaled_attention_logits = scaled_attention_logits + attention_mask
+
     attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
 
+    # Mask heads if we want to
+    if head_mask is not None:
+        attention_weights = attention_weights * head_mask
+
     output = torch.matmul(attention_weights, v)
 
     return output, attention_weights
 
 
 class MultiHeadAttention(torch.nn.Module):
-    def __init__(self, d_model_size, num_heads):
+    def __init__(self, d_model_size, num_heads, output_attentions=False):
         super(MultiHeadAttention, self).__init__()
+        self.output_attentions = output_attentions
         self.num_heads = num_heads
         self.d_model_size = d_model_size
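Note the division of labour between the three masks in the updated `scaled_dot_product_attention`: `mask` is the causal triangle added as `mask * -1e4`, `attention_mask` is an additive padding mask already converted to the 0 / -10000.0 convention (see the `CTRLModel.forward` hunk further down), and `head_mask` multiplies the softmax weights. A rough standalone sketch of how the three interact; the tensor sizes are made up for illustration:

```python
# Sketch of how the three masks feed scaled_dot_product_attention above.
# Shapes and the 0 / -10000.0 convention follow the CTRLModel.forward hunk
# further down; the sizes here are made up for illustration.
import torch

batch, heads, seq = 2, 4, 5
scores = torch.randn(batch, heads, seq, seq)          # q @ k^T / sqrt(dk)

# Causal mask: 1.0 above the diagonal, 0.0 elsewhere, added as mask * -1e4.
causal = torch.triu(torch.ones(seq, seq), diagonal=1)
scores = scores + causal * -1e4

# Padding mask: 1 for real tokens, 0 for padding, converted to additive form.
pad = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.float)
additive = (1.0 - pad)[:, None, None, :] * -10000.0   # [batch, 1, 1, seq]
scores = scores + additive

weights = torch.softmax(scores, dim=-1)

# Head mask: 1.0 keeps a head, 0.0 zeroes its attention weights.
head_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])[None, :, None, None]
weights = weights * head_mask
```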
@@ -88,7 +99,7 @@ class MultiHeadAttention(torch.nn.Module):
         x = x.reshape(batch_size, -1, self.num_heads, self.depth)
         return x.permute([0, 2, 1, 3])
 
-    def forward(self, v, k, q, mask):
+    def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None):
         batch_size = q.shape[0]
 
         q = self.Wq(q)
@@ -98,7 +109,13 @@ class MultiHeadAttention(torch.nn.Module):
         q = self.split_into_heads(q, batch_size)
         k = self.split_into_heads(k, batch_size)
         v = self.split_into_heads(v, batch_size)
-        output = scaled_dot_product_attention(q, k, v, mask)
+        if layer_past is not None:
+            past_key, past_value = layer_past[0], layer_past[1]
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+        present = torch.stack((k, v))  # cached keys/values for the next forward pass
+
+        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
 
         scaled_attention = output[0].permute([0, 2, 1, 3])
         attn = output[1]
         original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
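The `layer_past` / `present` pair added above implements key/value caching for incremental decoding: cached keys and values from earlier steps are concatenated along the sequence axis so only the new token's projections need to be computed. A minimal sketch of the pattern outside the model, with made-up shapes:

```python
# Sketch of the key/value cache pattern introduced above, outside the model:
# cached keys/values are appended along the sequence axis (dim=-2) so attention
# can see the full prefix while only the new token's q/k/v are computed.
import torch

batch, heads, depth = 1, 2, 8

def step(k_new, v_new, layer_past=None):
    # k_new, v_new: [batch, heads, new_seq, depth]
    if layer_past is not None:
        past_key, past_value = layer_past[0], layer_past[1]
        k_new = torch.cat((past_key, k_new), dim=-2)
        v_new = torch.cat((past_value, v_new), dim=-2)
    present = torch.stack((k_new, v_new))            # cached for the next step
    return k_new, v_new, present

k1, v1 = torch.randn(batch, heads, 3, depth), torch.randn(batch, heads, 3, depth)
_, _, present = step(k1, v1)                         # 3-token prompt
k2, v2 = torch.randn(batch, heads, 1, depth), torch.randn(batch, heads, 1, depth)
k_all, v_all, present = step(k2, v2, present)        # one generated token
print(k_all.shape)                                   # torch.Size([1, 2, 4, 8])
```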
@@ -109,14 +126,16 @@ class MultiHeadAttention(torch.nn.Module):
 def point_wise_feed_forward_network(d_model_size, dff):
-    return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff), torch.nn.ReLU(), torch.nn.Linear(dff, d_model_size))
+    return torch.nn.Sequential(torch.nn.Linear(d_model_size, dff),
+                               torch.nn.ReLU(),
+                               torch.nn.Linear(dff, d_model_size))
 
 
 class EncoderLayer(torch.nn.Module):
-    def __init__(self, d_model_size, num_heads, dff, rate=0.1):
+    def __init__(self, d_model_size, num_heads, dff, rate=0.1, output_attentions=False):
         super(EncoderLayer, self).__init__()
 
-        self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads)
+        self.multi_head_attention = MultiHeadAttention(d_model_size, num_heads, output_attentions)
         self.ffn = point_wise_feed_forward_network(d_model_size, dff)
 
         self.layernorm1 = torch.nn.LayerNorm(d_model_size, eps=1e-6)
@@ -125,9 +144,12 @@ class EncoderLayer(torch.nn.Module):
         self.dropout1 = torch.nn.Dropout(rate)
         self.dropout2 = torch.nn.Dropout(rate)
 
-    def forward(self, x, mask):
+    def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
         normed = self.layernorm1(x)
-        attn_output, attn = self.multi_head_attention(normed, normed, normed, mask)
+        attn_output, attn = self.multi_head_attention(normed, normed, normed, mask,
+                                                      layer_past=layer_past,
+                                                      attention_mask=attention_mask,
+                                                      head_mask=head_mask)
         attn_output = self.dropout1(attn_output)
         out1 = x + attn_output
@@ -147,9 +169,6 @@ class CTRLPreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"
 
-    def __init__(self, *inputs, **kwargs):
-        super(CTRLPreTrainedModel, self).__init__(*inputs, **kwargs)
-
     def _init_weights(self, module):
         """ Initialize the weights.
         """
@@ -256,7 +275,11 @@ class CTRLModel(CTRLPreTrainedModel):
         self.dropout = nn.Dropout(config.embd_pdrop)
-        self.h = nn.ModuleList([EncoderLayer(config.n_embd, config.n_head, config.dff, config.resid_pdrop) for _ in range(config.n_layer)])
+        self.h = nn.ModuleList([EncoderLayer(config.n_embd,
+                                             config.n_head,
+                                             config.dff,
+                                             config.resid_pdrop,
+                                             config.output_attentions) for _ in range(config.n_layer)])
         self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
 
         self.init_weights()
@@ -272,8 +295,54 @@ class CTRLModel(CTRLPreTrainedModel):
         for layer, heads in heads_to_prune.items():
             self.h[layer].attn.prune_heads(heads)
 
-    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
-                labels=None):
+    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+
+        if past is None:
+            past_length = 0
+            past = [None] * len(self.h)
+        else:
+            past_length = past[0][0].size(-2)
+        if position_ids is None:
+            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+
+        # Attention mask.
+        if attention_mask is not None:
+            attention_mask = attention_mask.view(-1, input_shape[-1])
+            # We create a 3D attention mask from a 2D tensor mask.
+            # Sizes are [batch_size, 1, 1, to_seq_length]
+            # so we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length].
+            # This attention mask is simpler than the triangular masking of causal attention
+            # used in OpenAI GPT; we just need to prepare the broadcast dimension here.
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+            # masked positions, this operation will create a tensor which is 0.0 for
+            # positions we want to attend and -10000.0 for masked positions.
+            # Since we are adding it to the raw scores before the softmax, this is
+            # effectively the same as removing these entirely.
+            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+            attention_mask = (1.0 - attention_mask) * -10000.0
+
+        # Prepare head mask if needed:
+        # 1.0 in head_mask indicates we keep the head.
+        # attention_probs has shape bsz x n_heads x N x N
+        # head_mask has shape n_layer x batch x n_heads x N x N
+        if head_mask is not None:
+            if head_mask.dim() == 1:
+                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
+            elif head_mask.dim() == 2:
+                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # we can specify head_mask for each layer
+            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
+        else:
+            head_mask = [None] * self.config.n_layer
+
         embedded = self.w(input_ids)
         x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
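The head-mask preparation above accepts either a 1-D mask (one entry per head, shared across layers) or a 2-D mask (one row per layer) and reshapes it so it broadcasts against attention weights of shape `[batch, n_heads, seq, seq]` in every layer. A standalone sketch of the same reshaping, with made-up layer and head counts:

```python
# Standalone sketch of the head_mask normalisation above.
# n_layer / n_head values are made up for illustration.
import torch

n_layer, n_head = 4, 8

def prepare_head_mask(head_mask, n_layer):
    if head_mask.dim() == 1:                # one value per head, shared by all layers
        head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
        head_mask = head_mask.expand(n_layer, -1, -1, -1, -1)
    elif head_mask.dim() == 2:              # one row per layer
        head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
    return head_mask.float()

shared = prepare_head_mask(torch.ones(n_head), n_layer)
per_layer = prepare_head_mask(torch.ones(n_layer, n_head), n_layer)
print(shared.shape)     # torch.Size([4, 1, 8, 1, 1]) -> broadcasts over batch and seq
print(per_layer.shape)  # torch.Size([4, 1, 8, 1, 1])
```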
@@ -282,26 +351,40 @@ class CTRLModel(CTRLPreTrainedModel):
         x *= np.sqrt(self.d_model_size)
 
-        x += self.pos_encoding[:, :seq_len, :].to(x.device)
+        x += self.pos_encoding[:, position_ids, :].to(x.device)
 
         x = self.dropout(x)
 
+        output_shape = input_shape + (x.size(-1),)
+        presents = ()
         all_hidden_states = ()
         all_attentions = []
-        for i in range(self.num_layers):
+        for i, (h, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states = all_hidden_states + (x,)
-            x, attn = self.h[i](x, mask)
+                all_hidden_states = all_hidden_states + (x.view(*output_shape),)
+            outputs = h(x,
+                        mask,
+                        layer_past=layer_past,
+                        attention_mask=attention_mask,
+                        head_mask=head_mask[i])
+            x, present = outputs[:2]
+            presents = presents + (present,)
+
             if self.output_attentions:
-                all_attentions.append(attn)
+                all_attentions.append(outputs[2])
 
         x = self.layernorm(x)
+        x = x.view(*output_shape)
         if self.output_hidden_states:
             all_hidden_states = all_hidden_states + (x,)
 
-        outputs = (x, None)
+        outputs = (x, presents)
         if self.output_hidden_states:
             outputs = outputs + (all_hidden_states,)
         if self.output_attentions:
+            # let the number of heads free (-1) so we can extract attention even after head pruning
+            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
+            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
             outputs = outputs + (all_attentions,)
         return outputs
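With the loop above returning `presents`, the model can be driven incrementally: feed the prompt once, then feed only each newly generated token together with the cached `past`. The sketch below assumes the `CTRLConfig` / `CTRLModel` classes as exported by the transformers 2.x package this file belongs to (later releases renamed `past` to `past_key_values`); it uses tiny made-up config values so it is cheap to run and illustrates the intended calling pattern rather than this exact commit:

```python
# Illustrative use of the past/presents plumbing for incremental decoding.
# Assumes the transformers 2.x CTRL API; config values below are made up so
# the snippet is cheap to run (the real CTRL checkpoint is far larger).
import torch
from transformers import CTRLConfig, CTRLModel

config = CTRLConfig(n_layer=2, n_embd=64, n_head=4, dff=256)
model = CTRLModel(config)
model.eval()

input_ids = torch.tensor([[5, 17, 42]])          # a 3-token "prompt"
with torch.no_grad():
    hidden, presents = model(input_ids)[:2]      # presents caches k/v per layer

next_token = torch.tensor([[7]])                 # only the new token is fed in
with torch.no_grad():
    hidden_step, presents = model(next_token, past=presents)[:2]

print(hidden.shape, hidden_step.shape)           # [1, 3, 64] and [1, 1, 64]
```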
@@ -359,13 +442,17 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         """ Make sure we are sharing the input and output embeddings.
             Export to TorchScript can't handle parameter sharing so we are cloning them instead.
         """
-        self._tie_or_clone_weights(self.lm_head,
-                                   self.transformer.w)
-        #self._tie_or_clone_weights(self.lm_head.bias,
-        #                           self.transformer.w.bias)
+        self._tie_or_clone_weights(self.lm_head, self.transformer.w)
 
     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 labels=None):
-        transformer_outputs = self.transformer(input_ids)
+        transformer_outputs = self.transformer(input_ids,
+                                               past=past,
+                                               attention_mask=attention_mask,
+                                               token_type_ids=token_type_ids,
+                                               position_ids=position_ids,
+                                               head_mask=head_mask)
         hidden_states = transformer_outputs[0]
 
         lm_logits = self.lm_head(hidden_states)
@@ -383,5 +470,3 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
             outputs = (loss,) + outputs
 
         return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
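`tie_weights` above shares the input embedding matrix `transformer.w` with the `lm_head` projection (cloning instead of sharing when exporting to TorchScript). A minimal sketch of what that sharing amounts to in plain PyTorch; the module names below are generic stand-ins, not the library's:

```python
# Minimal sketch of input/output embedding weight tying, in plain PyTorch.
# Names are generic; the library's _tie_or_clone_weights additionally clones
# the tensor when the model is being exported to TorchScript.
import torch

vocab_size, d_model = 1000, 64
embedding = torch.nn.Embedding(vocab_size, d_model)      # input embeddings (like transformer.w)
lm_head = torch.nn.Linear(d_model, vocab_size, bias=False)

lm_head.weight = embedding.weight                        # share one [vocab, d_model] matrix

hidden = torch.randn(2, 5, d_model)
logits = lm_head(hidden)                                 # [2, 5, vocab_size]
print(logits.shape, lm_head.weight is embedding.weight)  # torch.Size([2, 5, 1000]) True
```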