Commit fbb248a2 authored by thomwolf

examples testing

parent 5ff0c605
@@ -4,7 +4,9 @@ import argparse
 import logging
 import torch
+import torch.nn.functional as F
 import numpy as np
+from tqdm import trange
 from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
@@ -23,18 +25,20 @@ def top_k_logits(logits, k):
 def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
-        context = torch.tensor(context, device=device)
+        context = torch.tensor(context, device=device, dtype=torch.long)
     else:
         assert context is None, 'Specify exactly one of start_token and context!'
-        context = torch.full((batch_size, 1), start_token, device=device)
+        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
     prev = context
     output = context
+    past = None
     with torch.no_grad():
-        for i in range(length):
+        for i in trange(length):
             logits, past = model(prev, past=past)
             logits = logits[:, -1, :] / temperature
             logits = top_k_logits(logits, k=top_k)
-            prev = torch.multinomial(logits, 1)
+            log_probs = F.softmax(logits, dim=-1)
+            prev = torch.multinomial(log_probs, num_samples=1)
             output = torch.cat((output, prev), dim=1)
     return output
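Note: the switch to F.softmax before torch.multinomial matters because multinomial expects non-negative weights, while top_k_logits leaves the discarded positions at a very negative value. A minimal standalone sketch of the same sampling step, with a hypothetical helper name and toy sizes (not the script's exact code):

import torch
import torch.nn.functional as F

def sample_next_token(logits, temperature=1.0, top_k=0):
    # logits: (batch, vocab) raw scores for the next token
    logits = logits / temperature
    if top_k > 0:
        # keep only the top-k logits, mask the rest to -inf
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float('-inf'))
    probs = F.softmax(logits, dim=-1)                  # multinomial needs non-negative weights
    return torch.multinomial(probs, num_samples=1)     # (batch, 1) sampled token ids

next_token = sample_next_token(torch.randn(2, 5), temperature=0.9, top_k=2)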
@@ -57,6 +61,8 @@ def sample_model():
     enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
+    model.to(device)
+    model.eval()
     if args.length == -1:
         args.length = model.config.n_ctx
@@ -71,6 +77,7 @@ def sample_model():
             batch_size=args.batch_size,
             temperature=args.temperature, top_k=args.top_k, device=device
         )
+        out = out.tolist()
         for i in range(args.batch_size):
             generated += args.batch_size
             text = enc.decode(out[i])
...
@@ -2,8 +2,10 @@
 import argparse
 import logging
+from tqdm import trange
 import torch
+import torch.nn.functional as F
 import numpy as np
 from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer
@@ -23,18 +25,20 @@ def top_k_logits(logits, k):
 def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
-        context = torch.tensor(context, device=device)
+        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
     else:
         assert context is None, 'Specify exactly one of start_token and context!'
-        context = torch.full((batch_size, 1), start_token, device=device)
+        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
     prev = context
     output = context
+    past = None
     with torch.no_grad():
-        for i in range(length):
+        for i in trange(length):
             logits, past = model(prev, past=past)
             logits = logits[:, -1, :] / temperature
             logits = top_k_logits(logits, k=top_k)
-            prev = torch.multinomial(logits, 1)
+            log_probs = F.softmax(logits, dim=-1)
+            prev = torch.multinomial(log_probs, num_samples=1)
             output = torch.cat((output, prev), dim=1)
     return output
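Note: the .unsqueeze(0).repeat(batch_size, 1) added here broadcasts a single prompt to every row of the batch. A toy illustration with made-up token ids (assumed values, not the script's):

import torch

context_tokens = [15496, 11, 995]                       # hypothetical token ids for one prompt
batch_size = 4
context = torch.tensor(context_tokens, dtype=torch.long)
context = context.unsqueeze(0).repeat(batch_size, 1)    # (batch_size, prompt_length)
assert context.shape == (4, 3)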
@@ -50,7 +54,7 @@ def interact_model():
     args = parser.parse_args()
     print(args)
-    if args.batch_size is None:
+    if args.batch_size == -1:
         args.batch_size = 1
     assert args.nsamples % args.batch_size == 0
@@ -61,6 +65,8 @@ def interact_model():
     enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
+    model.to(device)
+    model.eval()
     if args.length == -1:
         args.length = model.config.n_ctx // 2
@@ -81,7 +87,7 @@ def interact_model():
             batch_size=args.batch_size,
             temperature=args.temperature, top_k=args.top_k, device=device
         )
-        out = out[:, len(context_tokens):]
+        out = out[:, len(context_tokens):].tolist()
         for i in range(args.batch_size):
             generated += 1
             text = enc.decode(out[i])
...
@@ -244,10 +244,10 @@ class Attention(nn.Module):
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = layer_past[0], layer_past[1]
-            key = torch.cat((past_key, key), dim=-2)
+            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose to have same shapes
+            key = torch.cat((past_key, key), dim=-1)
             value = torch.cat((past_value, value), dim=-2)
-        present = torch.stack((key, value))
+        present = torch.stack((key.transpose(-2, -1), value))
         a = self._attn(query, key, value)
         a = self.merge_heads(a)
         a = self.c_proj(a)
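Note on the transposes: after split_heads(key, k=True) the key is laid out as (batch, head, head_dim, seq) while the value is (batch, head, seq, head_dim), so the cached key is transposed back on the way in and transposed again before it can be stacked with the value into present. A toy shape check under those assumed layouts (not the library's code):

import torch

batch, heads, head_dim = 2, 3, 4
past_seq, new_seq = 5, 1

past_key = torch.randn(batch, heads, past_seq, head_dim)     # stored un-transposed in "present"
past_value = torch.randn(batch, heads, past_seq, head_dim)
key = torch.randn(batch, heads, head_dim, new_seq)           # split_heads(key, k=True) layout
value = torch.randn(batch, heads, new_seq, head_dim)

key = torch.cat((past_key.transpose(-2, -1), key), dim=-1)   # concat along the key's seq axis
value = torch.cat((past_value, value), dim=-2)               # concat along the value's seq axis
present = torch.stack((key.transpose(-2, -1), value))        # same shapes, so stacking works
assert present.shape == (2, batch, heads, past_seq + new_seq, head_dim)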
@@ -278,7 +278,7 @@ class Block(nn.Module):
         self.mlp = MLP(4 * nx, config)

     def forward(self, x, layer_past=None):
-        a, present = self.attn(self.ln_1(x), layer_past=past)
+        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
         x = x + a
         m = self.mlp(self.ln_2(x))
         x = x + m
@@ -531,7 +531,7 @@ class GPT2Model(GPT2PreTrainedModel):
             past_length = 0
             past = [None] * len(self.h)
         else:
-            past[0][0].size(-2)
+            past_length = past[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
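Note: before this fix the else branch computed the cached sequence length but discarded the result, so past_length was never set when generating with a past cache. A small sketch of the position-id bookkeeping the assignment enables (assumed values):

import torch

past_length = 6                      # e.g. what past[0][0].size(-2) returns for a 6-token cache
input_ids = torch.tensor([[50256]])  # one new token per batch element (hypothetical id)

position_ids = torch.arange(past_length, input_ids.size(-1) + past_length,
                            dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
print(position_ids)                  # tensor([[6]]) -- the new token sits at position 6, not 0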
...