Commit 33e72b08 authored by thomwolf

fix inner dimensions for 3B/11B models

parent f19dad61
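
For context (not part of the diff): in the 3B/11B configurations num_heads * d_kv no longer equals d_model, so the square (d_model, d_model) projections used before this commit produce the wrong shapes. Below is a minimal sketch of the corrected projection shapes, with illustrative values assumed to be in the spirit of the t5-3b config (d_model=1024, num_heads=32, d_kv=128); the numbers are assumptions for illustration, not taken from this diff.

# Illustrative sketch only -- assumed 3B-style values where num_heads * d_kv != d_model.
import torch
from torch import nn

d_model, n_heads, d_kv = 1024, 32, 128
inner_dim = n_heads * d_kv                       # 4096, not equal to d_model

# Old layout: nn.Linear(d_model, d_model) for q/k/v/o, which only works when
# n_heads * d_kv == d_model (as in small/base/large) and breaks for 3B/11B.
# Fixed layout: project d_model -> inner_dim for q/k/v and inner_dim -> d_model for o.
q = nn.Linear(d_model, inner_dim, bias=False)
o = nn.Linear(inner_dim, d_model, bias=False)

x = torch.randn(2, 7, d_model)                              # (bs, qlen, d_model)
heads = q(x).view(2, -1, n_heads, d_kv).transpose(1, 2)     # (bs, n_heads, qlen, d_kv)
context = heads.transpose(1, 2).contiguous().view(2, -1, inner_dim)
print(o(context).shape)                                     # torch.Size([2, 7, 1024])

The inner_dim -> d_model output projection is what lets the attention block keep d_model as its external interface while the per-head width is decoupled from it.
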
@@ -30,7 +30,7 @@ from torch import nn
 import torch.nn.functional as F
 from torch.nn import CrossEntropyLoss, MSELoss
-from .modeling_utils import PreTrainedModel
+from .modeling_utils import PreTrainedModel, prune_linear_layer
 from .configuration_t5 import T5Config
 from .file_utils import add_start_docstrings, DUMMY_INPUTS, DUMMY_MASK
@@ -191,28 +191,26 @@ class T5Attention(nn.Module):
         self.output_attentions = config.output_attentions
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
-        self.dim = config.d_model
+        self.d_model = config.d_model
         self.d_kv = config.d_kv
         self.n_heads = config.num_heads
         self.dropout = config.dropout_rate
-        assert self.dim % self.n_heads == 0
-        assert self.dim // self.n_heads == self.d_kv
+        self.inner_dim = self.n_heads * self.d_kv

         # Mesh TensorFlow initialization to avoid scaling before softmax
-        self.q = nn.Linear(self.dim, self.dim, bias=False)
-        self.k = nn.Linear(self.dim, self.dim, bias=False)
-        self.v = nn.Linear(self.dim, self.dim, bias=False)
-        self.o = nn.Linear(self.dim, self.dim, bias=False)
+        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
         if self.has_relative_attention_bias:
             self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
         self.pruned_heads = set()

     def prune_heads(self, heads):
-        attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
-        mask = torch.ones(self.n_heads, attention_head_size)
+        mask = torch.ones(self.n_heads, self.d_kv)
         heads = set(heads) - self.pruned_heads
         for head in heads:
             head -= sum(1 if h < head else 0 for h in self.pruned_heads)
@@ -226,7 +224,7 @@ class T5Attention(nn.Module):
         self.o = prune_linear_layer(self.o, index, dim=1)
         # Update hyper params
         self.n_heads = self.n_heads - len(heads)
-        self.dim = attention_head_size * self.n_heads
+        self.inner_dim = self.d_kv * self.n_heads
         self.pruned_heads = self.pruned_heads.union(heads)

     @staticmethod
@@ -303,17 +301,14 @@ class T5Attention(nn.Module):
             klen = qlen if cache is None else cache['slen'] + qlen
         else:
             klen = kv.size(1)

-        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
-        n_heads = self.n_heads
-        dim_per_head = self.dim // n_heads

         def shape(x):
             """ projection """
-            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)
+            return x.view(bs, -1, self.n_heads, self.d_kv).transpose(1, 2)

         def unshape(x):
             """ compute context """
-            return x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
+            return x.transpose(1, 2).contiguous().view(bs, -1, self.inner_dim)

         q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)

         if kv is None:
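
The prune_heads change in the PyTorch file above follows the same decoupling: the mask now has d_kv columns per head, and inner_dim rather than d_model is updated after pruning. Here is a standalone sketch of that bookkeeping with assumed example values; it only mimics the index computation that prune_linear_layer consumes, without calling the real helper.

# Standalone sketch with assumed values (n_heads=32, d_kv=128, two heads pruned).
import torch

n_heads, d_kv = 32, 128
heads_to_prune = {0, 5}

mask = torch.ones(n_heads, d_kv)
for head in heads_to_prune:
    mask[head] = 0                                          # zero out this head's slots
index = torch.arange(n_heads * d_kv)[mask.view(-1).eq(1)]   # rows of q/k/v to keep

n_heads_left = n_heads - len(heads_to_prune)
inner_dim_left = d_kv * n_heads_left                        # matches the updated self.inner_dim
print(index.numel(), inner_dim_left)                        # 3840 3840
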
@@ -108,17 +108,16 @@ class TFT5Attention(tf.keras.layers.Layer):
         self.output_attentions = config.output_attentions
         self.relative_attention_num_buckets = config.relative_attention_num_buckets
-        self.dim = config.d_model
+        self.d_model = config.d_model
         self.d_kv = config.d_kv
         self.n_heads = config.num_heads
-        assert self.dim % self.n_heads == 0
-        assert self.dim // self.n_heads == self.d_kv
+        self.inner_dim = self.n_heads * self.d_kv

         # Mesh TensorFlow initialization to avoid scaling before softmax
-        self.q = tf.keras.layers.Dense(self.dim, use_bias=False, name='q')
-        self.k = tf.keras.layers.Dense(self.dim, use_bias=False, name='k')
-        self.v = tf.keras.layers.Dense(self.dim, use_bias=False, name='v')
-        self.o = tf.keras.layers.Dense(self.dim, use_bias=False, name='o')
+        self.q = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name='q')
+        self.k = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name='k')
+        self.v = tf.keras.layers.Dense(self.inner_dim, use_bias=False, name='v')
+        self.o = tf.keras.layers.Dense(self.d_model, use_bias=False, name='o')
         self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
         if self.has_relative_attention_bias:
@@ -199,17 +198,14 @@ class TFT5Attention(tf.keras.layers.Layer):
             klen = qlen if cache is None else cache['slen'] + qlen
         else:
             klen = shape_list(kv)[1]

-        # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
-        n_heads = self.n_heads
-        dim_per_head = self.dim // n_heads

         def shape(x):
             """ projection """
-            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))
+            return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, self.d_kv)), perm=(0, 2, 1, 3))

         def unshape(x):
             """ compute context """
-            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))
+            return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.inner_dim))

         q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)

         if kv is None:
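
On the TF side only the output units of each Dense layer need to change, because tf.keras.layers.Dense infers its input dimension on first call: q/k/v get units=inner_dim and o gets units=d_model. A quick shape check under the same assumed 3B-style values as above:

# Quick shape check with assumed 3B-style values; mirrors the Keras changes in this file.
import tensorflow as tf

d_model, n_heads, d_kv = 1024, 32, 128
inner_dim = n_heads * d_kv

q = tf.keras.layers.Dense(inner_dim, use_bias=False, name='q')   # input dim inferred lazily
o = tf.keras.layers.Dense(d_model, use_bias=False, name='o')

x = tf.random.normal((2, 7, d_model))
heads = tf.transpose(tf.reshape(q(x), (2, -1, n_heads, d_kv)), perm=(0, 2, 1, 3))
out = o(tf.reshape(tf.transpose(heads, perm=(0, 2, 1, 3)), (2, -1, inner_dim)))
print(out.shape)   # (2, 7, 1024)
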