Commit 690a0dbf authored by thomwolf

fix example - masking

parent fbb248a2
@@ -22,7 +22,7 @@ def top_k_logits(logits, k):
     min_values = values[:, -1]
     return torch.where(logits < min_values, torch.ones_like(logits, dtype=logits.dtype) * -1e10, logits)
 
-def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
+def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
         context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
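The context lines above are the tail of top_k_logits: every logit below the k-th largest is pushed down to -1e10 so the subsequent softmax effectively zeroes it out. A minimal standalone sketch of that filtering, using made-up logits and a single-row batch:

import torch

# Toy illustration of the top-k filtering shown above
# (assumption: one row of 5 made-up logits, k=2).
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
k = 2
values, _ = torch.topk(logits, k)
min_values = values[:, -1]  # k-th largest logit
filtered = torch.where(logits < min_values,
                       torch.ones_like(logits, dtype=logits.dtype) * -1e10,
                       logits)
print(filtered)  # keeps 2.0 and 1.0, everything else becomes -1e10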
@@ -38,11 +38,14 @@ def sample_sequence(model, length, start_token=None, batch_size=None, context=No
             logits = logits[:, -1, :] / temperature
             logits = top_k_logits(logits, k=top_k)
             log_probs = F.softmax(logits, dim=-1)
-            prev = torch.multinomial(log_probs, num_samples=1)
+            if sample:
+                prev = torch.multinomial(log_probs, num_samples=1)
+            else:
+                _, prev = torch.topk(log_probs, k=1, dim=-1)
             output = torch.cat((output, prev), dim=1)
     return output
 
-def interact_model():
+def run_model():
     parser = argparse.ArgumentParser()
     parser.add_argument('--model_name_or_path', type=str, default='gpt2', help='pretrained model name or path to local checkpoint')
     parser.add_argument("--seed", type=int, default=0)
@@ -51,6 +54,7 @@ def interact_model():
     parser.add_argument("--length", type=int, default=-1)
     parser.add_argument("--temperature", type=int, default=1)
     parser.add_argument("--top_k", type=int, default=0)
+    parser.add_argument('--unconditional', action='store_true', help='If true, unconditional generation.')
     args = parser.parse_args()
     print(args)
 
@@ -73,17 +77,19 @@ def interact_model():
     elif args.length > model.config.n_ctx:
         raise ValueError("Can't get samples longer than window size: %s" % model.config.n_ctx)
 
-    while True:
-        raw_text = input("Model prompt >>> ")
-        while not raw_text:
-            print('Prompt should not be empty!')
-            raw_text = input("Model prompt >>> ")
-        context_tokens = enc.encode(raw_text)
+    while not args.unconditional:
+        if not args.unconditional:
+            raw_text = input("Model prompt >>> ")
+            while not raw_text:
+                print('Prompt should not be empty!')
+                raw_text = input("Model prompt >>> ")
+            context_tokens = enc.encode(raw_text)
         generated = 0
         for _ in range(args.nsamples // args.batch_size):
             out = sample_sequence(
                 model=model, length=args.length,
-                context=context_tokens,
+                context=context_tokens if not args.unconditional else None,
+                start_token=enc.encoder['<|endoftext|>'] if args.unconditional else None,
                 batch_size=args.batch_size,
                 temperature=args.temperature, top_k=args.top_k, device=device
             )
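With the new --unconditional flag, exactly one of context and start_token is passed to sample_sequence, matching its assert: the prompt's BPE ids for conditional generation, or the <|endoftext|> token id for unconditional generation. A toy sketch of that argument selection, with hypothetical stand-ins for the script's enc and prompt:

# Hypothetical stand-ins for the script's GPT-2 BPE encoder and an encoded prompt,
# used only to show how the two modes pick their arguments.
encoder = {'<|endoftext|>': 50256}   # assumption: toy mapping (50256 is GPT-2's end-of-text id)
context_tokens = [15496, 995]        # assumption: pretend BPE ids for a short prompt

for unconditional in (False, True):
    kwargs = dict(
        context=context_tokens if not unconditional else None,
        start_token=encoder['<|endoftext|>'] if unconditional else None,
    )
    print(unconditional, kwargs)
# False {'context': [15496, 995], 'start_token': None}
# True {'context': None, 'start_token': 50256}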
@@ -96,5 +102,4 @@ def interact_model():
             print("=" * 80)
 
 if __name__ == '__main__':
-    interact_model()
+    run_model()
@@ -87,10 +87,6 @@ def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
             if len(l) >= 2:
                 num = int(l[1])
                 pointer = pointer[num]
-            if m_name[-11:] == '_embeddings':
-                pointer = getattr(pointer, 'weight')
-            elif m_name == 'kernel':
-                array = np.transpose(array)
         try:
             assert pointer.shape == array.shape
         except AssertionError as e:
@@ -216,10 +212,9 @@ class Attention(nn.Module):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
-        # w = w * self.bias + -1e9 * (1 - self.bias)  # TF implem method: mask_attn_weights
-        # XD: self.b may be larger than w, so we need to crop it
-        b = self.bias[:, :, : w.size(-2), : w.size(-1)]
-        w = w * b + -1e10 * (1 - b)
+        nd, ns = w.size(-2), w.size(-1)
+        b = self.bias[:, :, ns-nd:ns, :ns]
+        w = w * b - 1e10 * (1 - b)
 
         w = nn.Softmax(dim=-1)(w)
         return torch.matmul(w, v)
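This is the masking fix the commit title refers to. When cached keys from a past are fed in, w has nd query rows but ns = past + new key columns, so the causal mask must be taken from the last nd rows of the triangular bias buffer, not the first nd rows as the old slice did. A small self-contained sketch with toy sizes (assuming the usual torch.tril buffer the Attention module registers as bias):

import torch

# Toy causal-mask buffer (assumption: n_ctx=8 for illustration).
n_ctx = 8
bias = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)

nd, ns = 2, 5  # 2 new query positions, 5 total key positions (3 cached + 2 new)
b = bias[:, :, ns - nd:ns, :ns]  # new slice: rows for absolute positions 3 and 4
print(b[0, 0])
# tensor([[1., 1., 1., 1., 0.],   -> query at position 3 attends to keys 0..3
#         [1., 1., 1., 1., 1.]])  -> query at position 4 attends to keys 0..4

# The old slice bias[:, :, :nd, :ns] would have taken rows 0 and 1 instead,
# wrongly masking keys the new queries are allowed to attend to.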
@@ -233,9 +228,9 @@ class Attention(nn.Module):
         new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
         x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
         if k:
-            return x.permute(0, 2, 3, 1)
+            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
         else:
-            return x.permute(0, 2, 1, 3)
+            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
 
     def forward(self, x, layer_past=None):
         x = self.c_attn(x)
@@ -244,10 +239,10 @@ class Attention(nn.Module):
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose to have same shapes
+            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
             key = torch.cat((past_key, key), dim=-1)
             value = torch.cat((past_value, value), dim=-2)
-        present = torch.stack((key.transpose(-2, -1), value))
+        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
         a = self._attn(query, key, value)
         a = self.merge_heads(a)
         a = self.c_proj(a)
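The comments added above document the caching convention: keys live as (batch, head, head_features, seq_length) while values are (batch, head, seq_length, head_features), so the key is transposed once to share a shape with the value when both are stacked into present, and transposed back when the cache is reused. A toy shape check with made-up sizes:

import torch

# Made-up sizes: batch=2, head=3, seq_length=4, head_features=5.
key = torch.randn(2, 3, 5, 4)    # (batch, head, head_features, seq_length)
value = torch.randn(2, 3, 4, 5)  # (batch, head, seq_length, head_features)

# transpose the key so both tensors share a shape and can be stacked
present = torch.stack((key.transpose(-2, -1), value))
print(present.shape)  # torch.Size([2, 2, 3, 4, 5])

# when present comes back as layer_past, transpose the key back before
# concatenating with new keys along the sequence dimension
past_key = present[0].transpose(-2, -1)
print(past_key.shape)  # torch.Size([2, 3, 5, 4])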
@@ -522,7 +517,7 @@ class GPT2Model(GPT2PreTrainedModel):
         self.wpe = nn.Embedding(config.n_positions, config.n_embd)
         block = Block(config.n_ctx, config, scale=True)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(config.n_layer)])
-        self.ln_f = LayerNorm(config.n_embd)
+        self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
         self.apply(self.init_weights)