Commit 61399e5a authored by piero, committed by Julien Chaumond

Cleaned perturb_past. Identical output as before.

parent ffc29354
@@ -112,7 +112,7 @@ def top_k_filter(logits, k, probs=False):
def perturb_past(
        past,
        model,
        last,
        unpert_past=None,
        unpert_logits=None,
        accumulated_hidden=None,
@@ -128,156 +128,174 @@ def perturb_past(
        horizon_length=1,
        decay=False,
        gamma=1.5,
        device='cuda'
):
    # Generate initial perturbed past
    grad_accumulator = [
        (np.zeros(p.shape).astype("float32"))
        for p in past
    ]
    if accumulated_hidden is None:
        accumulated_hidden = 0
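    # when decay is on, decay_mask is a linear ramp from 1/window_length up to
    # 1.0 that is folded into the window mask built below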
    if decay:
        decay_mask = torch.arange(
            0.,
            1.0 + SMALL_CONST,
            1.0 / (window_length)
        )[1:]
    else:
        decay_mask = 1.0
    # Generate a mask that restricts the gradient perturbation to a window over the past
    _, _, _, curr_length, _ = past[0].shape

    if curr_length > window_length and window_length > 0:
        ones_key_val_shape = (
            tuple(past[0].shape[:-2])
            + tuple([window_length])
            + tuple(past[0].shape[-1:])
        )

        zeros_key_val_shape = (
            tuple(past[0].shape[:-2])
            + tuple([curr_length - window_length])
            + tuple(past[0].shape[-1:])
        )
        ones_mask = torch.ones(ones_key_val_shape)
        ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
        ones_mask = ones_mask.permute(0, 1, 2, 4, 3)

        window_mask = torch.cat(
            (ones_mask, torch.zeros(zeros_key_val_shape)),
            dim=-2
        ).to(device)
    else:
        window_mask = torch.ones_like(past[0]).to(device)
    # accumulate perturbations for num_iterations
    loss_per_iter = []
    new_accumulated_hidden = None
    for i in range(num_iterations):
        print("Iteration ", i + 1)
        curr_perturbation = [
            to_var(torch.from_numpy(p_), requires_grad=True)
            for p_ in grad_accumulator
        ]
        # Compute hidden using perturbed past
        perturbed_past = list(map(add, past, curr_perturbation))
        _, _, _, curr_length, _ = curr_perturbation[0].shape
        all_logits, _, all_hidden = model(last, past=perturbed_past)
        hidden = all_hidden[-1]
        new_accumulated_hidden = accumulated_hidden + torch.sum(
            hidden,
            dim=1
        ).detach()
        # TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
        logits = all_logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        loss = 0.0
        loss_list = []
        if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
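            # bag-of-words loss: negative log of the total probability mass
            # assigned to the bag-of-words tokens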
            for one_hot_bow in one_hot_bows_vectors:
                bow_logits = torch.mm(probs, torch.t(one_hot_bow))
                bow_loss = -torch.log(torch.sum(bow_logits))
                loss += bow_loss
                loss_list.append(bow_loss)
            print(" pplm_bow_loss:", loss.data.cpu().numpy())
        if loss_type == 2 or loss_type == 3:
            ce_loss = torch.nn.CrossEntropyLoss()
            # TODO: why do we need this assignment instead of just using unpert_past? (Sumanth)
            curr_unpert_past = unpert_past
            curr_probs = torch.unsqueeze(probs, dim=1)
            wte = model.resize_token_embeddings()
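            # resize_token_embeddings() with no arguments returns the model's
            # input embedding matrix; matmul with curr_probs below gives the
            # probability-weighted ("soft") embedding of the next token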
            for _ in range(horizon_length):
                inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
                _, curr_unpert_past, curr_all_hidden = model(
                    past=curr_unpert_past,
                    inputs_embeds=inputs_embeds
                )
                curr_hidden = curr_all_hidden[-1]
                new_accumulated_hidden = new_accumulated_hidden + torch.sum(
                    curr_hidden, dim=1)
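            # classify the mean hidden state over the sequence so far plus the
            # horizon steps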
            prediction = classifier(new_accumulated_hidden /
                                    (curr_length + 1 + horizon_length))

            label = torch.tensor([label_class], device=device,
                                 dtype=torch.long)
            discrim_loss = ce_loss(prediction, label)
            print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
            loss += discrim_loss
            loss_list.append(discrim_loss)
        kl_loss = 0.0
        if kl_scale > 0.0:
            unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
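            # add SMALL_CONST to (near-)zero probabilities so the ratio and the
            # log in the KL term below stay finite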
            unpert_probs = (
                unpert_probs + SMALL_CONST *
                (unpert_probs <= SMALL_CONST).float().to(device).detach()
            )
            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
            corrected_probs = probs + correction.detach()
            kl_loss = kl_scale * (
                (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
            )
            print(' kl_loss', kl_loss.data.cpu().numpy())
            loss += kl_loss
        loss_per_iter.append(loss.data.cpu().numpy())
        print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())

        # compute gradients
        loss.backward()
        # calculate gradient norms
        if grad_norms is not None and loss_type == PPLM_BOW:
            grad_norms = [
                torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
                for index, p_ in enumerate(curr_perturbation)
            ]
        else:
            grad_norms = [
                (torch.norm(p_.grad * window_mask) + SMALL_CONST)
                for index, p_ in enumerate(curr_perturbation)
            ]
        # normalize gradients
        grad = [
            -stepsize *
            (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
            for index, p_ in enumerate(curr_perturbation)
        ]
        # accumulate gradient
        grad_accumulator = list(map(add, grad, grad_accumulator))

        # reset gradients, just to make sure
        for p_ in curr_perturbation:
            p_.grad.data.zero_()
        # removing past from the graph
        new_past = []
        for p_ in past:
            new_past.append(p_.detach())
        past = new_past
    # apply the accumulated perturbations to the past
    grad_accumulator = [
        to_var(torch.from_numpy(p_), requires_grad=True)
        for p_ in grad_accumulator
    ]
    pert_past = list(map(add, past, grad_accumulator))

    return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter

def get_classifier(
@@ -532,6 +550,7 @@ def generate_text_pplm(
                horizon_length=horizon_length,
                decay=decay,
                gamma=gamma,
                device=device
            )
            loss_in_time.append(loss_this_iter)
        else:
@@ -662,7 +681,8 @@ def run_model():
parser.add_argument("--decay", action="store_true", parser.add_argument("--decay", action="store_true",
help="whether to decay or not") help="whether to decay or not")
parser.add_argument("--gamma", type=float, default=1.5) parser.add_argument("--gamma", type=float, default=1.5)
parser.add_argument("--colorama", action="store_true", help="colors keywords") parser.add_argument("--colorama", action="store_true",
help="colors keywords")
args = parser.parse_args() args = parser.parse_args()
...