"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1d5c3a3d966f2c6459d6a275296a5526af3e5563"
Commit ad88563b authored by thomwolf's avatar thomwolf
Browse files

WIP GPT-2

parent 64d83c7a
@@ -28,7 +28,7 @@ from io import open
 import numpy as np
 import tensorflow as tf
 
-from .modeling_tf_utils import TFPreTrainedModel
+from .modeling_tf_utils import TFPreTrainedModel, TFConv1D
 from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings
@@ -116,97 +116,100 @@ def gelu(x):
 class TFAttention(tf.keras.layers.Layer):
-    def __init__(self, nx, n_ctx, config, scale=False):
-        super(Attention, self).__init__()
+    def __init__(self, nx, n_ctx, config, scale=False, **kwargs):
+        super(TFAttention, self).__init__(**kwargs)
         self.output_attentions = config.output_attentions
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
         assert n_state % config.n_head == 0
-        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
+        self.n_ctx = n_ctx
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
-        self.c_attn = Conv1D(n_state * 3, nx)
-        self.c_proj = Conv1D(n_state, nx)
-        self.attn_dropout = nn.Dropout(config.attn_pdrop)
-        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        self.c_attn = TFConv1D(n_state * 3, nx)
+        self.c_proj = TFConv1D(n_state, nx)
+        self.attn_dropout = tf.keras.layers.Dropout(config.attn_pdrop)
+        self.resid_dropout = tf.keras.layers.Dropout(config.resid_pdrop)
         self.pruned_heads = set()
 
     def prune_heads(self, heads):
-        if len(heads) == 0:
-            return
-        mask = torch.ones(self.n_head, self.split_size // self.n_head)
-        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
-        for head in heads:
-            # Compute how many pruned heads are before the head and move the index accordingly
-            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
-            mask[head] = 0
-        mask = mask.view(-1).contiguous().eq(1)
-        index = torch.arange(len(mask))[mask].long()
-        index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
-
-        # Prune conv1d layers
-        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
-        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
-
-        # Update hyper params
-        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
-        self.n_head = self.n_head - len(heads)
-        self.pruned_heads = self.pruned_heads.union(heads)
-
-    def _attn(self, q, k, v, head_mask=None):
-        w = torch.matmul(q, k)
+        pass
+
+    @staticmethod
+    @tf.function
+    def attention_mask(nd, ns, *, dtype):
+        """1's in the lower triangle, counting from the lower right corner.
+        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
+        """
+        i = tf.range(nd)[:,None]
+        j = tf.range(ns)
+        m = i >= j - ns + nd
+        return tf.cast(m, dtype)
+
+    @tf.function
+    def _attn(self, inputs, training=False):
+        q, k, v, head_mask = inputs
+        # q, k, v have shape [batch, heads, sequence, features]
+        w = tf.matmul(q, k, transpose_b=True)
         if self.scale:
-            w = w / math.sqrt(v.size(-1))
-        nd, ns = w.size(-2), w.size(-1)
-        b = self.bias[:, :, ns-nd:ns, :ns]
+            n_state = shape_list(v)[-1]
+            w = w * tf.rsqrt(tf.cast(v.shape[-1].value, w.dtype))
+
+        # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
+        _, _, nd, ns = shape_list(w)
+        b = self.attention_mask(nd, ns, dtype=w.dtype)
+        b = tf.reshape(b, [1, 1, nd, ns])
         w = w * b - 1e4 * (1 - b)
 
-        w = nn.Softmax(dim=-1)(w)
-        w = self.attn_dropout(w)
+        w = tf.nn.softmax(w)
+        w = self.attn_dropout(w, training=training)
 
         # Mask heads if we want to
         if head_mask is not None:
             w = w * head_mask
 
-        outputs = [torch.matmul(w, v)]
+        outputs = [tf.matmul(w, v)]
         if self.output_attentions:
             outputs.append(w)
         return outputs
 
+    @tf.function
     def merge_heads(self, x):
-        x = x.permute(0, 2, 1, 3).contiguous()
-        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
-        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
+        x = tf.transpose(x, [0, 2, 1, 3])
+        x_shape = tf.shape(x)
+        new_x_shape = x_shape[:-2] + (x_shape[-2] * x_shape[-1],)
+        return tf.reshape(x, new_x_shape)
 
-    def split_heads(self, x, k=False):
-        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
-        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
-        if k:
-            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
-        else:
-            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
-
-    def forward(self, x, layer_past=None, head_mask=None):
+    @tf.function
+    def split_heads(self, x):
+        x_shape = tf.shape(x)
+        new_x_shape = x_shape[:-1] + (self.n_head, x_shape[-1] // self.n_head)
+        x = tf.reshape(x, new_x_shape)
+        return tf.transpose(x, (0, 2, 1, 3))  # (batch, head, seq_length, head_features)
+
+    @tf.function
+    def call(self, inputs, training=False):
+        x, layer_past, head_mask = inputs
+
         x = self.c_attn(x)
-        query, key, value = x.split(self.split_size, dim=2)
+        query, key, value = tf.split(x, 3, axis=2)
         query = self.split_heads(query)
-        key = self.split_heads(key, k=True)
+        key = self.split_heads(key)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
-            key = torch.cat((past_key, key), dim=-1)
-            value = torch.cat((past_value, value), dim=-2)
-        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+            past_key, past_value = tf.unstack(layer_past, axis=1)
+            key = tf.concat([past_key, key], axis=-2)
+            value = tf.concat([past_value, value], axis=-2)
+        present = tf.stack([key, value], axis=1)
 
         attn_outputs = self._attn(query, key, value, head_mask)
         a = attn_outputs[0]
 
         a = self.merge_heads(a)
         a = self.c_proj(a)
-        a = self.resid_dropout(a)
+        a = self.resid_dropout(a, training=training)
 
         outputs = [a, present] + attn_outputs[1:]
         return outputs  # a, present, (attentions)
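Note: the new attention_mask helper plus the additive `w * b - 1e4 * (1 - b)` step replaces the registered `bias` buffer of the PyTorch version. As a standalone illustration only (not part of the commit, assuming TF 2.x where tf.matrix_band_part lives at tf.linalg.band_part), the sketch below checks that the mask matches band_part and shows how the -1e4 term zeroes out masked positions after the softmax:

import tensorflow as tf

def causal_mask(nd, ns, dtype=tf.float32):
    # Illustrative copy of TFAttention.attention_mask: 1s in the lower
    # triangle, counting from the lower-right corner.
    i = tf.range(nd)[:, None]
    j = tf.range(ns)
    return tf.cast(i >= j - ns + nd, dtype)

mask = causal_mask(3, 5)
# [[1. 1. 1. 0. 0.]
#  [1. 1. 1. 1. 0.]
#  [1. 1. 1. 1. 1.]]
# i.e. with 2 cached positions, query t may attend to keys 0..t+2.

# Same mask via band_part of a ones matrix (the equivalence stated in the docstring).
band = tf.linalg.band_part(tf.ones([3, 5]), -1, 5 - 3)
tf.debugging.assert_equal(mask, band)

# Additive masking as in _attn: masked logits drop to about -1e4, so their
# softmax weight is effectively zero.
logits = tf.zeros([3, 5])
probs = tf.nn.softmax(logits * mask - 1e4 * (1 - mask), axis=-1)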
@@ -216,8 +219,8 @@ class MLP(nn.Module):
     def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
         super(MLP, self).__init__()
         nx = config.n_embd
-        self.c_fc = Conv1D(n_state, nx)
-        self.c_proj = Conv1D(nx, n_state)
+        self.c_fc = TFConv1D(n_state, nx)
+        self.c_proj = TFConv1D(nx, n_state)
         self.act = gelu
         self.dropout = nn.Dropout(config.resid_pdrop)
@@ -227,9 +230,9 @@ class MLP(nn.Module):
         return self.dropout(h2)
 
-class Block(nn.Module):
-    def __init__(self, n_ctx, config, scale=False):
-        super(Block, self).__init__()
+class TFBlock(tf.keras.layers.Layer):
+    def __init__(self, n_ctx, config, scale=False, **kwargs):
+        super(TFBlock, self).__init__(**kwargs)
         nx = config.n_embd
         self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
         self.attn = Attention(nx, n_ctx, config, scale)
...
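The TFBlock hunk renames the class but, being work in progress, still builds nn.LayerNorm and Attention. For the layer-norm piece, a minimal runnable sketch of the Keras counterpart, assuming TF 2.x (the 1e-5 epsilon is the usual GPT-2 default, an assumption here, not something this hunk shows):

import tensorflow as tf

# PyTorch: nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
# Keras counterpart, normalizing over the last (embedding) axis:
ln_1 = tf.keras.layers.LayerNormalization(epsilon=1e-5, axis=-1)
x = tf.random.normal([2, 5, 768])  # [batch, seq, n_embd] with n_embd = 768
print(ln_1(x).shape)               # (2, 5, 768)

The self.attn = Attention(nx, n_ctx, config, scale) line would likewise point at the TFAttention layer from this diff once the conversion is finished.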
@@ -263,13 +263,25 @@ class TFConv1D(tf.keras.layers.Layer):
         """
         super(TFConv1D, self).__init__()
         self.nf = nf
-        w = torch.empty(nx, nf)
-        nn.init.normal_(w, std=0.02)
-        self.weight = nn.Parameter(w)
-        self.bias = nn.Parameter(torch.zeros(nf))
+        self.nx = nx
+
+    def build(self, input_shape):
+        self.weight = self.add_weight(
+            "weight",
+            shape=[self.nx, self.nf],
+            initializer=tf.random_normal_initializer(
+                mean=0., stddev=0.02))
+        self.bias = self.add_weight(
+            "bias",
+            shape=[self.nx, self.nf],
+            initializer=tf.zeros_initializer())
 
+    @tf.function
     def call(self, x):
-        size_out = t.shape(x)[:-1] + (self.nf,)
-        x = tf.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
-        x = x.view(*size_out)
+        size_out = tf.shape(x)[:-1] + (self.nf,)
+        x = tf.reshape(x, [-1, tf.shape(x)[-1]])
+        x = tf.matmul(x, self.weight) + self.bias
+        x = tf.reshape(x, size_out)
         return x
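As a sanity check (not part of the commit), TFConv1D is just an affine map over the last axis, i.e. it behaves like a dense layer with nf output units. Below is a minimal standalone sketch assuming TF 2.x eager execution; note that the diff declares the bias with shape [self.nx, self.nf], whereas a per-output bias of shape [self.nf] is assumed here so the addition broadcasts for any input size:

import tensorflow as tf

class Conv1DSketch(tf.keras.layers.Layer):
    """Illustrative standalone version of TFConv1D: y = reshape(reshape(x, [-1, nx]) @ W + b, [..., nf])."""

    def __init__(self, nf, nx, **kwargs):
        super(Conv1DSketch, self).__init__(**kwargs)
        self.nf = nf
        self.nx = nx

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf],
            initializer=tf.random_normal_initializer(stddev=0.02))
        # Per-output bias [nf] assumed here (the diff uses [nx, nf]).
        self.bias = self.add_weight(
            "bias", shape=[self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        bz, sl = tf.shape(x)[0], tf.shape(x)[1]
        x = tf.reshape(x, [-1, self.nx])
        x = tf.matmul(x, self.weight) + self.bias
        return tf.reshape(x, [bz, sl, self.nf])

layer = Conv1DSketch(nf=3 * 768, nx=768)    # e.g. c_attn in TFAttention
out = layer(tf.random.normal([2, 5, 768]))  # -> shape (2, 5, 2304)

Functionally this matches tf.keras.layers.Dense(nf) with a normal(stddev=0.02) kernel initializer; keeping a dedicated class mirrors the Conv1D naming used by the original GPT-2 weights.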