Commit 3edfa1d6 authored by thomwolf

update model to use past

parent bd5363cc
@@ -52,7 +52,7 @@ def positional_encoding(position, d_model_size, dtype):
     sines = torch.sin(angle_rads[:, 0::2])
     cosines = torch.cos(angle_rads[:, 1::2])
-    pos_encoding = torch.cat([sines, cosines], dim=-1).unsqueeze(0)
+    pos_encoding = torch.cat([sines, cosines], dim=-1)
     return pos_encoding

 def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
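With the leading batch dimension dropped, positional_encoding now returns a plain (position, d_model_size) table that can be indexed row-wise with position_ids. A minimal stand-alone sketch of the shapes involved, assuming an angle_defn helper like the one used earlier in this file:

import torch

def angle_defn(pos, i, d_model_size):
    # standard sinusoidal rates: 1 / 10000^(2*(i//2)/d_model)
    angle_rates = 1 / torch.pow(10000, (2 * (i // 2)) / d_model_size)
    return pos * angle_rates

position, d_model_size = 50, 16
angle_rads = angle_defn(torch.arange(position, dtype=torch.float).unsqueeze(1),
                        torch.arange(d_model_size, dtype=torch.float).unsqueeze(0),
                        d_model_size)
sines = torch.sin(angle_rads[:, 0::2])
cosines = torch.cos(angle_rads[:, 1::2])
pos_encoding = torch.cat([sines, cosines], dim=-1)    # shape: (position, d_model_size)
print(pos_encoding.shape)                             # torch.Size([50, 16])

# With the leading batch dim gone, a 1-D position_ids tensor indexes rows directly:
position_ids = torch.arange(10)
print(pos_encoding[position_ids, :].shape)            # torch.Size([10, 16])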
@@ -110,18 +110,21 @@ class MultiHeadAttention(torch.nn.Module):
         k = self.split_into_heads(k, batch_size)
         v = self.split_into_heads(v, batch_size)
         if layer_past is not None:
-            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
+            past_key, past_value = layer_past[0], layer_past[1]
             k = torch.cat((past_key, k), dim=-1)
             v = torch.cat((past_value, v), dim=-2)
-        present = torch.stack((k.transpose(-2, -1), v))  # transpose to have same shapes for stacking
+        present = torch.stack((k, v))

-        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask, output_attentions)
+        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
         scaled_attention = output[0].permute([0, 2, 1, 3])
         attn = output[1]
         original_size_attention = scaled_attention.reshape(batch_size, -1, self.d_model_size)
         output = self.dense(original_size_attention)

-        return output, attn
+        outputs = (output, present)
+        if self.output_attentions:
+            outputs = outputs + (attn,)
+        return outputs
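MultiHeadAttention now also returns present, the stacked key/value pair that the caller can hand back as layer_past on the next decoding step, and it only appends the attention weights when self.output_attentions is set. A conceptual sketch of how such a cache accumulates; the shapes and the sequence-axis concatenation are illustrative, not a quote of the lines above:

import torch

# Conceptual key/value cache; shapes are hypothetical, not the exact CTRL code.
batch, heads, depth = 2, 4, 16

# First pass: no past, attend over a 5-token prompt.
k = torch.randn(batch, heads, 5, depth)
v = torch.randn(batch, heads, 5, depth)
present = torch.stack((k, v))               # (2, batch, heads, 5, depth), returned to the caller

# Next step: one new token arrives, the cached keys/values are prepended.
layer_past = present
k_new = torch.randn(batch, heads, 1, depth)
v_new = torch.randn(batch, heads, 1, depth)
past_key, past_value = layer_past[0], layer_past[1]
k = torch.cat((past_key, k_new), dim=-2)    # grow along the sequence axis -> length 6
v = torch.cat((past_value, v_new), dim=-2)
present = torch.stack((k, v))               # cache handed back for the step after this one
print(k.shape, present.shape)               # torch.Size([2, 4, 6, 16]) torch.Size([2, 2, 4, 6, 16])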
@@ -146,10 +149,11 @@ class EncoderLayer(torch.nn.Module):
     def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
         normed = self.layernorm1(x)
-        attn_output, attn = self.multi_head_attention(normed, normed, normed, mask,
+        attn_outputs = self.multi_head_attention(normed, normed, normed, mask,
                                                  layer_past=layer_past,
                                                  attention_mask=attention_mask,
                                                  head_mask=head_mask)
+        attn_output = attn_outputs[0]
         attn_output = self.dropout1(attn_output)
         out1 = x + attn_output

@@ -158,7 +162,8 @@ class EncoderLayer(torch.nn.Module):
         ffn_output = self.dropout2(ffn_output)
         out2 = out1 + ffn_output
-        return out2, attn
+        outputs = (out2,) + attn_outputs[1:]
+        return outputs

 class CTRLPreTrainedModel(PreTrainedModel):
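EncoderLayer mirrors the new tuple convention: position 0 is the hidden state, position 1 the layer's present, and a trailing entry carries the attention weights only when they were requested. A hedged sketch of how a caller might unpack it; run_layer, layer, x and mask are placeholder names, not code from this diff:

def run_layer(layer, x, mask, layer_past=None):
    # outputs[0] -> hidden state handed to the next layer
    # outputs[1] -> present, the stacked (key, value) cache for this layer
    # outputs[2] -> attention weights, only there when output_attentions is enabled
    outputs = layer(x, mask, layer_past=layer_past)
    hidden_state, present = outputs[0], outputs[1]
    attn = outputs[2] if len(outputs) > 2 else None
    return hidden_state, present, attn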
@@ -344,14 +349,15 @@ class CTRLModel(CTRLPreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer

-        embedded = self.w(input_ids)
-        x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
+        x = self.w(input_ids)
+        # x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
         seq_len = input_ids.shape[1]
         mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)

         x *= np.sqrt(self.d_model_size)
-        x += self.pos_encoding[:, position_ids, :].to(x.device)
+        pos_x = self.pos_encoding[position_ids, :].to(x.device)
+        x += pos_x
         x = self.dropout(x)
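The causal mask built with torch.triu above keeps its role: ones strictly above the diagonal flag future positions, which the attention scores then push toward minus infinity before the softmax. A small sketch of the pattern; the -1e4 fill value is illustrative rather than a quote of this file:

import torch

seq_len = 4
mask = torch.triu(torch.ones(seq_len, seq_len), 1)
print(mask)
# tensor([[0., 1., 1., 1.],
#         [0., 0., 1., 1.],
#         [0., 0., 0., 1.],
#         [0., 0., 0., 0.]])

scores = torch.zeros(seq_len, seq_len)      # stand-in attention logits
probs = torch.softmax(scores + mask * -1e4, dim=-1)
print(probs[0])                             # the first token can only attend to itself
# tensor([1., 0., 0., 0.])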
@@ -144,14 +144,16 @@ class CTRLModelTest(CommonTestCases.CommonModelTester):
             model(input_ids, token_type_ids=token_type_ids, head_mask=head_mask)
             model(input_ids, token_type_ids=token_type_ids)
-            sequence_output, _ = model(input_ids)
+            sequence_output, presents = model(input_ids)

             result = {
                 "sequence_output": sequence_output,
+                "presents": presents,
             }
             self.parent.assertListEqual(
                 list(result["sequence_output"].size()),
                 [self.batch_size, self.seq_length, self.hidden_size])
+            self.parent.assertEqual(len(result["presents"]), config.n_layer)

         def create_and_check_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
             model = CTRLLMHeadModel(config)
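The test now asserts that the model returns one present per layer. In use, those tensors go back in through the past argument so that each generation step only embeds the newly sampled token. A minimal sketch under stated assumptions: the forward is assumed to accept past=... and to return (logits, presents) in that order, mirroring the GPT-2 models in this repository; none of this is quoted from the diff above:

import torch

def generate_one_token(model, input_ids, past=None):
    # Assumption: model(input_ids, past=...) returns (logits, presents, ...).
    with torch.no_grad():
        logits, presents = model(input_ids, past=past)[:2]
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    return next_token, presents

# Intended use (model and prompt_ids are placeholders, e.g. a CTRLLMHeadModel and a prompt):
#   token, presents = generate_one_token(model, prompt_ids)            # run the full prompt once
#   token, presents = generate_one_token(model, token, past=presents)  # then one token at a time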