"vscode:/vscode.git/clone" did not exist on "7060766490240f7f1a63dce4c1ca6d0abfd8555d"
Commit 61399e5a authored by piero, committed by Julien Chaumond

Cleaned perturb_past. Identical output as before.

parent ffc29354
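For context, perturb_past is the core PPLM control step: it treats the model's cached key/value pairs (past) as free variables, takes a few gradient steps on an attribute loss (bag-of-words log-likelihood and/or discriminator cross-entropy, plus a KL term against the unperturbed next-token distribution), and returns the perturbed cache. The sketch below is a simplified, self-contained rendering of that loop, not this repository's code; perturb_past_sketch, model_step, and bow_indices are illustrative names, and the discriminator loss, window mask, and decay handled in the diff are omitted.

import torch
import torch.nn.functional as F


def perturb_past_sketch(past, model_step, bow_indices,
                        num_iterations=3, stepsize=0.03, kl_scale=0.01):
    """Schematic PPLM perturbation: gradient steps on the cached activations.

    past        -- list of tensors standing in for the key/value cache
    model_step  -- callable(past) -> next-token logits of shape [1, vocab] (placeholder)
    bow_indices -- token ids defining the bag-of-words attribute
    """
    # running sum of the perturbation, mirroring grad_accumulator in the diff
    accumulated = [torch.zeros_like(p) for p in past]

    for _ in range(num_iterations):
        # fresh leaf tensors each iteration so gradients do not pile up in the graph
        perturbation = [a.clone().detach().requires_grad_(True) for a in accumulated]
        perturbed_past = [p + d for p, d in zip(past, perturbation)]

        logits = model_step(perturbed_past)                 # [1, vocab]
        probs = F.softmax(logits, dim=-1)

        # bag-of-words loss: raise the probability mass of the attribute words
        bow_loss = -torch.log(probs[0, bow_indices].sum())

        # KL term keeping the perturbed distribution close to the unperturbed one
        with torch.no_grad():
            unpert_probs = F.softmax(model_step(past), dim=-1)
        kl_loss = kl_scale * (probs * (probs / unpert_probs).log()).sum()

        loss = bow_loss + kl_loss
        loss.backward()

        # normalized gradient step, accumulated across iterations
        for acc, d in zip(accumulated, perturbation):
            acc -= stepsize * d.grad / (d.grad.norm() + 1e-10)

    # the caller decodes the next token from past + the accumulated perturbation
    return [p + acc for p, acc in zip(past, accumulated)]

Accumulating the step in a plain buffer and rebuilding the trainable copy each iteration is the structure the commit makes explicit by renaming past_perturb_orig to grad_accumulator and past_perturb to curr_perturbation.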
@@ -112,7 +112,7 @@ def top_k_filter(logits, k, probs=False):
def perturb_past(
past,
model,
prev,
last,
unpert_past=None,
unpert_logits=None,
accumulated_hidden=None,
@@ -128,156 +128,174 @@ def perturb_past(
horizon_length=1,
decay=False,
gamma=1.5,
device='cuda'
):
# Generate initial perturbed past
past_perturb_orig = [
(np.random.uniform(0.0, 0.0, p.shape).astype('float32'))
for p in past]
grad_accumulator = [
(np.zeros(p.shape).astype("float32"))
for p in past
]
if accumulated_hidden is None:
accumulated_hidden = 0
if decay:
decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / (window_length))[
1:]
decay_mask = torch.arange(
0.,
1.0 + SMALL_CONST,
1.0 / (window_length)
)[1:]
else:
decay_mask = 1.0
# TODO fix this comment (SUMANTH)
# Generate a mask so that the perturbation is limited to a window of the past
_, _, _, current_length, _ = past[0].shape
_, _, _, curr_length, _ = past[0].shape
if current_length > window_length and window_length > 0:
ones_key_val_shape = tuple(past[0].shape[:-2]) + tuple(
[window_length]) + tuple(
past[0].shape[-1:])
if curr_length > window_length and window_length > 0:
ones_key_val_shape = (
tuple(past[0].shape[:-2])
+ tuple([window_length])
+ tuple(past[0].shape[-1:])
)
zeros_key_val_shape = tuple(past[0].shape[:-2]) + tuple(
[current_length - window_length]) + tuple(
past[0].shape[-1:])
zeros_key_val_shape = (
tuple(past[0].shape[:-2])
+ tuple([curr_length - window_length])
+ tuple(past[0].shape[-1:])
)
ones_mask = torch.ones(ones_key_val_shape)
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
window_mask = torch.cat((ones_mask, torch.zeros(zeros_key_val_shape)),
dim=-2).cuda()
window_mask = torch.cat(
(ones_mask, torch.zeros(zeros_key_val_shape)),
dim=-2
).to(device)
else:
window_mask = torch.ones_like(past[0]).cuda()
window_mask = torch.ones_like(past[0]).to(device)
# accumulate perturbations for num_iterations
loss_per_iter = []
new_accumulated_hidden = None
for i in range(num_iterations):
print("Iteration ", i + 1)
past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
perturbed_past = list(map(add, past, past_perturb))
_, _, _, current_length, _ = past_perturb[0].shape
# _, future_past = model(prev, past=perturbed_past)
# hidden = model.hidden_states
# Piero modified model call
logits, _, all_hidden = model(prev, past=perturbed_past)
curr_perturbation = [
to_var(torch.from_numpy(p_), requires_grad=True)
for p_ in grad_accumulator
]
# Compute hidden using perturbed past
perturbed_past = list(map(add, past, curr_perturbation))
_, _, _, curr_length, _ = curr_perturbation[0].shape
all_logits, _, all_hidden = model(last, past=perturbed_past)
hidden = all_hidden[-1]
new_accumulated_hidden = accumulated_hidden + torch.sum(hidden,
dim=1).detach()
new_accumulated_hidden = accumulated_hidden + torch.sum(
hidden,
dim=1
).detach()
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
logits = all_logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
# TODO: Check the layer-norm consistency of this with trained discriminator
logits = logits[:, -1, :]
probabs = F.softmax(logits, dim=-1)
loss = 0.0
loss_list = []
if loss_type == 1 or loss_type == 3:
for one_hot_good in one_hot_bows_vectors:
good_logits = torch.mm(probabs, torch.t(one_hot_good))
loss_word = good_logits
loss_word = torch.sum(loss_word)
loss_word = -torch.log(loss_word)
# loss_word = torch.sum(loss_word) /torch.sum(one_hot_good)
loss += loss_word
loss_list.append(loss_word)
if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
for one_hot_bow in one_hot_bows_vectors:
bow_logits = torch.mm(probs, torch.t(one_hot_bow))
bow_loss = -torch.log(torch.sum(bow_logits))
loss += bow_loss
loss_list.append(bow_loss)
print(" pplm_bow_loss:", loss.data.cpu().numpy())
if loss_type == 2 or loss_type == 3:
ce_loss = torch.nn.CrossEntropyLoss()
new_true_past = unpert_past
for i in range(horizon_length):
future_probabs = F.softmax(logits, dim=-1) # Get softmax
future_probabs = torch.unsqueeze(future_probabs, dim=1)
# _, new_true_past = model(future_probabs, past=new_true_past)
# future_hidden = model.hidden_states # Get expected hidden states
# Piero modified model call
# TODO why we need to do this assignment and not just using unpert_past? (Sumanth)
curr_unpert_past = unpert_past
curr_probs = torch.unsqueeze(probs, dim=1)
wte = model.resize_token_embeddings()
inputs_embeds = torch.matmul(future_probabs, wte.weight.data)
_, new_true_past, future_hidden = model(
past=new_true_past,
for _ in range(horizon_length):
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
_, curr_unpert_past, curr_all_hidden = model(
past=curr_unpert_past,
inputs_embeds=inputs_embeds
)
future_hidden = future_hidden[-1]
curr_hidden = curr_all_hidden[-1]
new_accumulated_hidden = new_accumulated_hidden + torch.sum(
future_hidden, dim=1)
curr_hidden, dim=1)
predicted_sentiment = classifier(new_accumulated_hidden / (
current_length + 1 + horizon_length))
prediction = classifier(new_accumulated_hidden /
(curr_length + 1 + horizon_length))
label = torch.tensor([label_class], device='cuda',
label = torch.tensor([label_class], device=device,
dtype=torch.long)
discrim_loss = ce_loss(predicted_sentiment, label)
discrim_loss = ce_loss(prediction, label)
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
loss += discrim_loss
loss_list.append(discrim_loss)
kl_loss = 0.0
if kl_scale > 0.0:
p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
p = p + SMALL_CONST * (p <= SMALL_CONST).type(
torch.FloatTensor).cuda().detach()
correction = SMALL_CONST * (probabs <= SMALL_CONST).type(
torch.FloatTensor).cuda().detach()
corrected_probabs = probabs + correction.detach()
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = (
unpert_probs + SMALL_CONST *
(unpert_probs <= SMALL_CONST).float().to(device).detach()
)
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
corrected_probs = probs + correction.detach()
kl_loss = kl_scale * (
(corrected_probabs * (corrected_probabs / p).log()).sum())
print(' kl_loss', (kl_loss).data.cpu().numpy())
loss += kl_loss # + discrim_loss
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
)
print(' kl_loss', kl_loss.data.cpu().numpy())
loss += kl_loss
loss_per_iter.append(loss.data.cpu().numpy())
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
# compute gradients
loss.backward()
if grad_norms is not None and loss_type == 1:
# calculate gradient norms
if grad_norms is not None and loss_type == PPLM_BOW:
grad_norms = [
torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
for index, p_ in
enumerate(past_perturb)]
for index, p_ in enumerate(curr_perturbation)
]
else:
grad_norms = [(torch.norm(p_.grad * window_mask) + SMALL_CONST) for
index, p_ in enumerate(past_perturb)]
grad_norms = [
(torch.norm(p_.grad * window_mask) + SMALL_CONST)
for index, p_ in enumerate(curr_perturbation)
]
# normalize gradients
grad = [
-stepsize * (p_.grad * window_mask / grad_norms[
index] ** gamma).data.cpu().numpy()
for index, p_ in enumerate(past_perturb)]
past_perturb_orig = list(map(add, grad, past_perturb_orig))
-stepsize *
(p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
for index, p_ in enumerate(curr_perturbation)
]
for p_ in past_perturb:
# accumulate gradient
grad_accumulator = list(map(add, grad, grad_accumulator))
# reset gradients, just to make sure
for p_ in curr_perturbation:
p_.grad.data.zero_()
# removing past from the graph
new_past = []
for p in past:
new_past.append(p.detach())
for p_ in past:
new_past.append(p_.detach())
past = new_past
past_perturb = [torch.from_numpy(p_) for p_ in past_perturb_orig]
past_perturb = [to_var(p_, requires_grad=True) for p_ in past_perturb]
perturbed_past = list(map(add, past, past_perturb))
# apply the accumulated perturbations to the past
grad_accumulator = [
to_var(torch.from_numpy(p_), requires_grad=True)
for p_ in grad_accumulator
]
pert_past = list(map(add, past, grad_accumulator))
return perturbed_past, new_accumulated_hidden, grad_norms, loss_per_iter
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
def get_classifier(
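Besides the renaming, the hunk above also replaces hard-coded .cuda() calls with the new device argument, which the next hunk then threads through the perturb_past call inside generate_text_pplm. A minimal illustration of that pattern outside this script (build_window_mask is an invented name for illustration):

import torch


def build_window_mask(shape, device="cpu"):
    # Accepting a device instead of calling .cuda() keeps the code runnable on
    # CPU-only machines and lets the caller pick the target, as the refactor does.
    return torch.ones(shape).to(device)


device = "cuda" if torch.cuda.is_available() else "cpu"
mask = build_window_mask((1, 2, 3), device=device)
print(mask.device)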
@@ -532,6 +550,7 @@ def generate_text_pplm(
horizon_length=horizon_length,
decay=decay,
gamma=gamma,
device=device
)
loss_in_time.append(loss_this_iter)
else:
@@ -662,7 +681,8 @@ def run_model():
parser.add_argument("--decay", action="store_true",
help="whether to decay or not")
parser.add_argument("--gamma", type=float, default=1.5)
parser.add_argument("--colorama", action="store_true", help="colors keywords")
parser.add_argument("--colorama", action="store_true",
help="colors keywords")
args = parser.parse_args()
......