"git@developer.sourcefind.cn:wqshmzh/ktransformers.git" did not exist on "ca9695b488f2167c8d79a112785c872a853cb408"
Commit b860e47c authored by thomwolf

add head masking and pruning to gpt-2

parent 7220d47a
@@ -44,6 +44,30 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.hugging
PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
"gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
def prune_conv1d_layer(layer, index, dim=1):
""" Prune a Conv1D layer (a model parameters) to keep only entries in index.
A Conv1D work as a Linear layer (see e.g. BERT) but the weights are transposed.
Return the pruned layer as a new layer with requires_grad=True.
Used to remove heads.
"""
index = index.to(layer.weight.device)
W = layer.weight.index_select(dim, index).clone().detach()
if dim == 0:
b = layer.bias.clone().detach()
else:
b = layer.bias[index].clone().detach()
new_size = list(layer.weight.size())
new_size[dim] = len(index)
new_layer = Conv1D(new_size[1], new_size[0])
new_layer.weight.requires_grad = False
new_layer.weight.copy_(W.contiguous())
new_layer.weight.requires_grad = True
new_layer.bias.requires_grad = False
new_layer.bias.copy_(b.contiguous())
new_layer.bias.requires_grad = True
return new_layer
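For intuition, here is a minimal standalone sketch (not part of the commit) of what the dim argument selects: Conv1D stores its weight as [nx, nf], i.e. transposed relative to nn.Linear, so dim=1 prunes output columns together with the bias, while dim=0 prunes input rows and leaves the bias as is.

```
# Illustrative only: mimics the index_select calls above on a Conv1D-shaped weight.
import torch

nx, nf = 8, 12                       # input features, output features
weight = torch.randn(nx, nf)         # Conv1D.weight layout (transposed vs nn.Linear)
bias = torch.randn(nf)

index = torch.tensor([0, 1, 2, 3])   # entries to keep
out_pruned_w = weight.index_select(1, index)   # dim=1: keep 4 output columns
out_pruned_b = bias[index]                     # bias follows the outputs
in_pruned_w = weight.index_select(0, index)    # dim=0: keep 4 input rows, bias unchanged

assert out_pruned_w.shape == (nx, 4) and out_pruned_b.shape == (4,)
assert in_pruned_w.shape == (4, nf)
```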
def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
""" Load tf checkpoints in a pytorch model
"""
@@ -223,7 +247,7 @@ class Conv1D(nn.Module):
class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
super(Attention, self).__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
@@ -232,13 +256,31 @@ class Attention(nn.Module):
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale
self.output_attentions = output_attentions
self.keep_multihead_output = keep_multihead_output
self.multihead_output = None
self.c_attn = Conv1D(n_state * 3, nx)
self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
def prune_heads(self, heads):
mask = torch.ones(self.n_head, self.split_size // self.n_head)
for head in heads:
mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()
index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])
# Prune conv1d layers
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
# Update hyper params
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
self.n_head = self.n_head - len(heads)
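A toy illustration (hypothetical numbers, not from the commit) of why index_attn repeats index at offsets of split_size: c_attn packs the query, key and value projections side by side along its output dimension, so each surviving head column must be kept in all three blocks.

```
# Hypothetical: 4 heads of size 2, so split_size = 8 and c_attn outputs 3 * 8 columns.
import torch

n_head, head_size = 4, 2
split_size = n_head * head_size
heads = [1, 2]                                   # heads to prune

mask = torch.ones(n_head, head_size)
for head in heads:
    mask[head] = 0
mask = mask.view(-1).contiguous().eq(1)
index = torch.arange(len(mask))[mask].long()     # [0, 1, 6, 7]: heads 0 and 3 survive
index_attn = torch.cat([index, index + split_size, index + (2 * split_size)])

print(index_attn.tolist())   # the same kept columns, repeated for the packed Q, K and V blocks
```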
def _attn(self, q, k, v, head_mask=None):
w = torch.matmul(q, k)
if self.scale:
w = w / math.sqrt(v.size(-1))
@@ -248,6 +290,11 @@ class Attention(nn.Module):
w = nn.Softmax(dim=-1)(w)
w = self.attn_dropout(w)
# Mask heads if we want to
if head_mask is not None:
w = w * head_mask
if self.output_attentions:
return w, torch.matmul(w, v)
return torch.matmul(w, v)
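The head_mask multiplication above zeroes the post-softmax attention probabilities of masked heads, so their contribution to torch.matmul(w, v) vanishes. A small self-contained sketch of that broadcast (toy shapes, not the model's real tensors):

```
# Hypothetical shapes: batch 2, 3 heads, sequence length 5.
import torch

w = torch.softmax(torch.randn(2, 3, 5, 5), dim=-1)      # attention probs: bsz x n_head x N x N
keep = torch.tensor([1.0, 0.0, 1.0]).view(1, -1, 1, 1)  # already-inverted mask: 0 == masked head
w = w * keep

assert torch.all(w[:, 1] == 0)   # head 1 now contributes nothing to the context vectors
```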
@@ -265,7 +312,7 @@ class Attention(nn.Module):
else:
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def forward(self, x, layer_past=None, head_mask=None):
x = self.c_attn(x)
query, key, value = x.split(self.split_size, dim=2)
query = self.split_heads(query)
@@ -276,7 +323,12 @@ class Attention(nn.Module):
key = torch.cat((past_key, key), dim=-1)
value = torch.cat((past_value, value), dim=-2)
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
a = self._attn(query, key, value, head_mask)
if self.keep_multihead_output:
self.multihead_output = a
self.multihead_output.retain_grad()
if self.output_attentions:
attentions, a = a
a = self.merge_heads(a)
@@ -303,17 +355,17 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False, output_attentions=False, keep_multihead_output=False):
super(Block, self).__init__()
nx = config.n_embd
self.output_attentions = output_attentions
self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.attn = Attention(nx, n_ctx, config, scale, output_attentions, keep_multihead_output)
self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config)
def forward(self, x, layer_past=None, head_mask=None):
output_attn = self.attn(self.ln_1(x), layer_past=layer_past, head_mask=head_mask)
if self.output_attentions:
attentions, a, present = output_attn
else:
@@ -593,13 +645,14 @@ class GPT2Model(GPT2PreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(GPT2Model, self).__init__(config)
self.output_attentions = output_attentions
self.wte = nn.Embedding(config.total_tokens_embeddings, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
block = Block(config.n_ctx, config, scale=True, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
@@ -619,7 +672,20 @@ class GPT2Model(GPT2PreTrainedModel):
# Copy word embeddings from the previous weights
self.wte.weight.data[:self.config.vocab_size, :] = old_embed.weight.data[:self.config.vocab_size, :]
def prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for layer, heads in heads_to_prune.items():
self.h[layer].attn.prune_heads(heads)
def get_multihead_outputs(self):
""" Gather all multi-head outputs.
Return: list (layers) of multihead module outputs with gradients
"""
return [h.attn.multihead_output for h in self.h]
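A hedged usage sketch for these two helpers (assumes this branch of pytorch-pretrained-bert is installed, since keep_multihead_output and prune_heads are introduced by this commit; the sizes are only GPT2Config defaults):

```
# Hypothetical usage, not taken from the commit.
import torch
from pytorch_pretrained_bert.modeling_gpt2 import GPT2Config, GPT2Model

config = GPT2Config()                              # defaults: 12 layers, 12 heads, n_embd=768
model = GPT2Model(config, keep_multihead_output=True)
model.eval()

model.prune_heads({0: [0, 2], 11: [11]})           # drop two heads in layer 0, one in layer 11

input_ids = torch.randint(0, config.vocab_size, (1, 16))
hidden_states, presents = model(input_ids)

per_head = model.get_multihead_outputs()           # one per-head context tensor per layer
print(len(per_head), per_head[0].shape)            # 12, torch.Size([1, 10, 16, 64])
```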
def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None, head_mask=None):
if past is None:
past_length = 0
past = [None] * len(self.h)
@@ -629,6 +695,17 @@ class GPT2Model(GPT2PreTrainedModel):
position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
# Prepare head mask if needed
# 1.0 in head_mask indicates we mask the head
# attention_probs has shape bsz x n_heads x N x N
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each instance in batch
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to float if needed + fp16 compatibility
head_mask = (1.0 - head_mask)
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_ids.size(-1))
position_ids = position_ids.view(-1, position_ids.size(-1))
@@ -646,11 +723,12 @@ class GPT2Model(GPT2PreTrainedModel):
presents = []
all_attentions = []
for block, layer_past in zip(self.h, past):
outputs = block(hidden_states, layer_past, head_mask)
if self.output_attentions:
attentions, hidden_states, present = outputs
all_attentions.append(attentions)
else:
hidden_states, present = outputs
presents.append(present)
hidden_states = self.ln_f(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
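The head-mask preparation at the top of forward only reshapes and inverts whatever the caller passes; a small sketch of the two accepted shapes (hypothetical values, mirroring the unsqueeze calls above):

```
# Hypothetical: 12 heads, batch of 4. 1.0 in the user-facing mask means "mask this head".
import torch

n_head = 12
mask_1d = torch.ones(n_head)          # one mask shared by every instance and layer
mask_2d = torch.ones(4, n_head)       # one mask row per instance in the batch

m1 = mask_1d.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
m2 = mask_2d.unsqueeze(-1).unsqueeze(-1)
print(m1.shape, m2.shape)             # torch.Size([1, 12, 1, 1]) torch.Size([4, 12, 1, 1])

keep = 1.0 - m1                       # what actually multiplies the attention probabilities
```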
@@ -703,9 +781,10 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(GPT2LMHeadModel, self).__init__(config)
self.transformer = GPT2Model(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
self.apply(self.init_weights)
@@ -717,8 +796,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, position_ids=None, token_type_ids=None, lm_labels=None, past=None, head_mask=None):
transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
if self.transformer.output_attentions:
all_attentions, hidden_states, presents = transformer_output
else:
@@ -787,9 +866,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
```
"""
def __init__(self, config, output_attentions=False, keep_multihead_output=False):
super(GPT2DoubleHeadsModel, self).__init__(config)
self.transformer = GPT2Model(config, output_attentions=output_attentions,
keep_multihead_output=keep_multihead_output)
self.lm_head = GPT2LMHead(self.transformer.wte.weight, config)
self.multiple_choice_head = GPT2MultipleChoiceHead(config)
self.apply(self.init_weights)
@@ -802,8 +882,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
self.transformer.set_num_special_tokens(num_special_tokens)
self.lm_head.set_embeddings_weights(self.transformer.wte.weight, predict_special_tokens=predict_special_tokens)
def forward(self, input_ids, mc_token_ids, lm_labels=None, mc_labels=None, token_type_ids=None,
position_ids=None, past=None, head_mask=None):
transformer_output = self.transformer(input_ids, position_ids, token_type_ids, past, head_mask)
if self.transformer.output_attentions:
all_attentions, hidden_states, presents = transformer_output
else:
...
@@ -209,6 +209,73 @@ class GPT2ModelTest(unittest.TestCase):
[list(l.size()) for l in result["loss"]],
[[], []])
def create_and_check_gpt2_for_headmasking(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
head_mask = torch.ones(self.n_head).to(input_ids.device)
head_mask[0] = 0.0
head_mask[-1] = 0.0 # Mask all but the first and last heads
if isinstance(model, GPT2DoubleHeadsModel):
output = model(input_ids, mc_token_ids, head_mask=head_mask)
else:
output = model(input_ids, head_mask=head_mask)
output = sum(t.sum() for t in output[:-1])
output = output.sum()
output.backward()
multihead_outputs = (model if isinstance(model, GPT2Model) else model.transformer).get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.n_layer)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertEqual(
len(multihead_outputs[0][:, 1:(self.n_head-1), :, :].nonzero()),
0)
self.parent.assertEqual(
len(multihead_outputs[0][:, 0, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
self.parent.assertEqual(
len(multihead_outputs[0][:, self.n_head-1, :, :].nonzero()),
self.batch_size * self.n_choices * self.seq_length * self.n_embd // self.n_head)
def create_and_check_gpt2_for_head_pruning(self, config, input_ids, token_type_ids, position_ids,
mc_labels, lm_labels, mc_token_ids):
for model_class in (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel):
model = model_class(config=config, keep_multihead_output=True)
model.eval()
transformer = model if isinstance(model, GPT2Model) else model.transformer
heads_to_prune = {0: list(range(1, self.n_head)),
-1: [0]}
transformer.prune_heads(heads_to_prune)
if isinstance(model, GPT2DoubleHeadsModel):
output = model(input_ids, mc_token_ids)
else:
output = model(input_ids)
output = sum(t.sum() for t in output[:-1])
output = output.sum()
output.backward()
multihead_outputs = transformer.get_multihead_outputs()
self.parent.assertEqual(len(multihead_outputs), self.n_layer)
self.parent.assertListEqual(
list(multihead_outputs[0].size()),
[self.batch_size * self.n_choices, 1,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertListEqual(
list(multihead_outputs[1].size()),
[self.batch_size * self.n_choices, self.n_head,
self.seq_length, self.n_embd // self.n_head])
self.parent.assertListEqual(
list(multihead_outputs[-1].size()),
[self.batch_size * self.n_choices, self.n_head-1,
self.seq_length, self.n_embd // self.n_head])
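Two side notes on the pruning test above (observations, not part of the commit): the key -1 in heads_to_prune is an ordinary nn.ModuleList index, so it addresses the last block, and pruning shrinks the packed projections themselves. Rough shape arithmetic, using GPT-2's default 12-head, 768-dim sizes rather than the tester's small config:

```
# Hypothetical arithmetic: layer 0 keeps a single head after pruning heads 1..11.
n_embd, n_head, kept = 768, 12, 1
head_size = n_embd // n_head                 # 64

print(3 * kept * head_size)   # 192: new c_attn output size (Q, K, V packed); its input stays 768
print(kept * head_size)       # 64: new c_proj input size; its output stays 768
```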
def test_default(self):
self.run_tester(GPT2ModelTest.GPT2ModelTester(self))
@@ -247,6 +314,9 @@ class GPT2ModelTest(unittest.TestCase):
tester.check_gpt2_double_heads_output(output_result)
tester.check_gpt2_double_heads_loss_output(output_result)
tester.create_and_check_gpt2_for_headmasking(*config_and_inputs)
tester.create_and_check_gpt2_for_head_pruning(*config_and_inputs)
@classmethod
def ids_tensor(cls, shape, vocab_size, rng=None, name=None):
"""Creates a random int32 tensor of the shape within the vocab size."""
...