Commit ad88563b authored by thomwolf

WIP GPT-2

parent 64d83c7a
@@ -28,7 +28,7 @@ from io import open
import numpy as np
import tensorflow as tf
from .modeling_tf_utils import TFPreTrainedModel
from .modeling_tf_utils import TFPreTrainedModel, TFConv1D
from .configuration_gpt2 import GPT2Config
from .file_utils import add_start_docstrings
@@ -116,97 +116,100 @@ def gelu(x):
class TFAttention(tf.keras.layers.Layer):
def __init__(self, nx, n_ctx, config, scale=False):
super(Attention, self).__init__()
def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
super(TFAttention, self).__init__(**kwargs)
self.output_attentions = config.output_attentions
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % config.n_head == 0
self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
self.n_ctx = n_ctx
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale
self.c_attn = Conv1D(n_state * 3, nx)
self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.c_attn = TFConv1D(n_state * 3, nx)
self.c_proj = TFConv1D(n_state, nx)
self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
self.pruned_heads = set()
def prune_heads(self, heads):
if len(heads) == 0:
return
mask = torch.ones(self.n_head, self.split_size // self.n_head)
heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
for head in heads:
# Compute how many pruned heads are before the head and move the index accordingly
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
# Prune conv1d layers
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
# Update hyper params
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
self.n_head = self.n_head - len(heads)
self.pruned_heads = self.pruned_heads.union(heads)
def _attn(self, q, k, v, head_mask=None):
w = torch.matmul(q, k)
pass
@staticmethod
@tf.function
def attention_mask(nd, ns, *, dtype):
"""1's in the lower triangle, counting from the lower right corner.
Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
"""
i = tf.range(nd)[:,None]
j = tf.range(ns)
m = i >= j - ns + nd
return tf.cast(m, dtype)
@tf.function
def _attn(self, inputs, training=False):
q, k, v, head_mask = inputs
# q, k, v have shape [batch, heads, sequence, features]
w = tf.matmul(q, k, transpose_b=True)
if self.scale:
w = w / math.sqrt(v.size(-1))
nd, ns = w.size(-2), w.size(-1)
b = self.bias[:, :, ns-nd:ns, :ns]
n_state = shape_list(v)[-1]
w = w * tf.rsqrt(tf.cast(n_state, w.dtype))
# w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
_, _, nd, ns = shape_list(w)
b = self.attention_mask(nd, ns, dtype=w.dtype)
b = tf.reshape(b, [1, 1, nd, ns])
w = w * b - 1e4 * (1 - b)
w = nn.Softmax(dim=-1)(w)
w = self.attn_dropout(w)
w = tf.nn.softmax(w)
w = self.attn_dropout(w, training=training)
# Mask heads if we want to
if head_mask is not None:
w = w * head_mask
outputs = [torch.matmul(w, v)]
outputs = [tf.matmul(w, v)]
if self.output_attentions:
outputs.append(w)
return outputs
@tf.function
def merge_heads(self, x):
x = x.permute(0, 2, 1, 3).contiguous()
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
def split_heads(self, x, k=False):
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
if k:
return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
else:
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
x = tf.transpose(x, [0, 2, 1, 3])
x_shape = shape_list(x)
new_x_shape = x_shape[:-2] + [x_shape[-2] * x_shape[-1]]
return tf.reshape(x, new_x_shape)
@tf.function
def split_heads(self, x):
x_shape = shape_list(x)
new_x_shape = x_shape[:-1] + [self.n_head, x_shape[-1] // self.n_head]
x = tf.reshape(x, new_x_shape)
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
@tf.function
def call(self, inputs, training=False):
x, layer_past, head_mask = inputs
def forward(self, x, layer_past=None, head_mask=None):
x = self.c_attn(x)
query, key, value = x.split(self.split_size, dim=2)
query, key, value = tf.split(x, 3, axis=2)
query = self.split_heads(query)
key = self.split_heads(key, k=True)
key = self.split_heads(key)
value = self.split_heads(value)
if layer_past is not None:
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
key = torch.cat((past_key, key), dim=-1)
value = torch.cat((past_value, value), dim=-2)
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
past_key, past_value = tf.unstack(layer_past, axis=1)
key = tf.concat([past_key, key], axis=-2)
value = tf.concat([past_value, value], axis=-2)
present = tf.stack([key, value], axis=1)
attn_outputs = self._attn([query, key, value, head_mask], training=training)
a = attn_outputs[0]
a = self.merge_heads(a)
a = self.c_proj(a)
a = self.resid_dropout(a)
a = self.resid_dropout(a, training=training)
outputs = [a, present] + attn_outputs[1:]
return outputs # a, present, (attentions)
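For orientation, here is a small standalone illustration (not part of the commit) of what the attention_mask helper above computes. It assumes TF 2.x eager execution and compares against tf.linalg.band_part, the current name for the tf.matrix_band_part equivalence mentioned in the docstring.

# Standalone illustration (assumption: TF 2.x eager); mirrors TFAttention.attention_mask above.
import tensorflow as tf

def causal_mask(nd, ns, dtype=tf.float32):
    # 1's in the lower triangle, counting from the lower-right corner,
    # so the newest query position can attend to every key position.
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    return tf.cast(i >= j - ns + nd, dtype)

m = causal_mask(3, 5)
ref = tf.linalg.band_part(tf.ones([3, 5]), -1, 5 - 3)  # the band_part formulation from the docstring
assert tf.reduce_all(tf.equal(m, ref))
print(m.numpy())
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]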
@@ -216,8 +219,8 @@ class MLP(nn.Module):
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
super(MLP, self).__init__()
nx = config.n_embd
self.c_fc = Conv1D(n_state, nx)
self.c_proj = Conv1D(nx, n_state)
self.c_fc = TFConv1D(n_state, nx)
self.c_proj = TFConv1D(nx, n_state)
self.act = gelu
self.dropout = nn.Dropout(config.resid_pdrop)
@@ -227,9 +230,9 @@ class MLP(nn.Module):
return self.dropout(h2)
class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False):
super(Block, self).__init__()
class TFBlock(tf.keras.layers.Layer):
def __init__(self, n_ctx, config, scale=False, **kwargs):
super(TFBlock, self).__init__(**kwargs)
nx = config.n_embd
self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
self.attn = Attention(nx, n_ctx, config, scale)
......
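The MLP and Block hunks above are still mid-port: TFConv1D is swapped in and the class is renamed to TFBlock, but nn.LayerNorm and nn.Dropout remain. As a rough sketch of where this is heading, not the commit's code, the fully Keras-ified layers might look as follows, assuming tf.keras.layers.LayerNormalization stands in for nn.LayerNorm and reusing gelu, TFConv1D and TFAttention from this diff:

# Sketch only (assumptions: TF 2.x Keras API; gelu, TFConv1D and TFAttention as shown
# elsewhere in this diff). Mirrors the structure of the PyTorch MLP/Block being replaced.
class TFMLP(tf.keras.layers.Layer):
    def __init__(self, n_state, config, **kwargs):  # in MLP: n_state=3072 (4 * n_embd)
        super(TFMLP, self).__init__(**kwargs)
        nx = config.n_embd
        self.c_fc = TFConv1D(n_state, nx)
        self.c_proj = TFConv1D(nx, n_state)
        self.act = gelu
        self.dropout = tf.keras.layers.Dropout(config.resid_pdrop)

    def call(self, x, training=False):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2, training=training)


class TFBlock(tf.keras.layers.Layer):
    def __init__(self, n_ctx, config, scale=False, **kwargs):
        super(TFBlock, self).__init__(**kwargs)
        nx = config.n_embd
        self.ln_1 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon)
        self.attn = TFAttention(nx, n_ctx, config, scale)
        self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon)
        self.mlp = TFMLP(4 * nx, config)

    def call(self, inputs, training=False):
        x, layer_past, head_mask = inputs
        output_attn = self.attn([self.ln_1(x), layer_past, head_mask], training=training)
        a = output_attn[0]  # output_attn: a, present, (attentions)
        x = x + a
        m = self.mlp(self.ln_2(x), training=training)
        x = x + m
        return [x] + output_attn[1:]  # x, present, (attentions)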
@@ -263,13 +263,25 @@ class TFConv1D(tf.keras.layers.Layer):
"""
super(TFConv1D, self).__init__()
self.nf = nf
w = torch.empty(nx, nf)
nn.init.normal_(w, std=0.02)
self.weight = nn.Parameter(w)
self.bias = nn.Parameter(torch.zeros(nf))
self.nx = nx
def build(self, input_shape):
self.weight = self.add_weight(
"weight",
shape=[self.nx, self.nf],
initializer=tf.random_normal_initializer(
mean=0., stddev=0.02))
self.bias = self.add_weight(
"bias",
shape=[1, self.nf],
initializer=tf.zeros_initializer())
@tf.function
def call(self, x):
size_out = t.shape(x)[:-1] + (self.nf,)
x = tf.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
x = x.view(*size_out)
size_out = tf.concat([tf.shape(x)[:-1], [self.nf]], axis=0)
x = tf.reshape(x, [-1, tf.shape(x)[-1]])
x = tf.matmul(x, self.weight) + self.bias
x = tf.reshape(x, size_out)
return x
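As a usage note (illustration only, not part of the diff): TFConv1D is just an affine projection over the last axis, so with its built weights it should match a flatten, matmul, reshape done by hand. A minimal check, assuming TF 2.x eager execution and the TFConv1D(nf, nx) argument order used above:

# Minimal check (assumptions: TF 2.x eager; TFConv1D(nf, nx) as in the diff above).
import tensorflow as tf

nx, nf = 8, 24                    # e.g. n_embd = 8, projected to 3 * n_embd for q/k/v
layer = TFConv1D(nf, nx)
x = tf.random.normal([2, 5, nx])  # [batch, sequence, features]
y = layer(x)                      # first call builds weight [nx, nf] and bias [1, nf]

# Same result as flattening to 2-D, applying one matmul + bias, then reshaping back.
ref = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), layer.weight) + layer.bias, [2, 5, nf])
tf.debugging.assert_near(y, ref)
print(y.shape)                    # (2, 5, 24)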