Commit 61a12f79, authored by piero and committed by Julien Chaumond
Browse files

Renamed SmallConst to SMALL_CONST and introduced BIG_CONST. Identical output as before.

parent ef47b2c0
...@@ -43,7 +43,7 @@ PPLM_BOW = 1 ...@@ -43,7 +43,7 @@ PPLM_BOW = 1
PPLM_DISCRIM = 2 PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3 PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15 SMALL_CONST = 1e-15
SmallConst = 1e-15 BIG_CONST = 1e10
TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium") TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")
BAG_OF_WORDS_ARCHIVE_MAP = { BAG_OF_WORDS_ARCHIVE_MAP = {
...@@ -104,7 +104,8 @@ def top_k_filter(logits, k, probs=False): ...@@ -104,7 +104,8 @@ def top_k_filter(logits, k, probs=False):
if probs: if probs:
return torch.where(logits < batch_mins, return torch.where(logits < batch_mins,
torch.ones_like(logits) * 0.0, logits) torch.ones_like(logits) * 0.0, logits)
return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10, return torch.where(logits < batch_mins,
torch.ones_like(logits) * -BIG_CONST,
logits) logits)
...@@ -137,7 +138,7 @@ def perturb_past( ...@@ -137,7 +138,7 @@ def perturb_past(
accumulated_hidden = 0 accumulated_hidden = 0
if decay: if decay:
decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[ decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / (window_length))[
1:] 1:]
else: else:
decay_mask = 1.0 decay_mask = 1.0
...@@ -233,9 +234,9 @@ def perturb_past( ...@@ -233,9 +234,9 @@ def perturb_past(
kl_loss = 0.0 kl_loss = 0.0
if kl_scale > 0.0: if kl_scale > 0.0:
p = (F.softmax(unpert_logits[:, -1, :], dim=-1)) p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
p = p + SmallConst * (p <= SmallConst).type( p = p + SMALL_CONST * (p <= SMALL_CONST).type(
torch.FloatTensor).cuda().detach() torch.FloatTensor).cuda().detach()
correction = SmallConst * (probabs <= SmallConst).type( correction = SMALL_CONST * (probabs <= SMALL_CONST).type(
torch.FloatTensor).cuda().detach() torch.FloatTensor).cuda().detach()
corrected_probabs = probabs + correction.detach() corrected_probabs = probabs + correction.detach()
kl_loss = kl_scale * ( kl_loss = kl_scale * (
...@@ -254,7 +255,7 @@ def perturb_past( ...@@ -254,7 +255,7 @@ def perturb_past(
for index, p_ in for index, p_ in
enumerate(past_perturb)] enumerate(past_perturb)]
else: else:
grad_norms = [(torch.norm(p_.grad * window_mask) + SmallConst) for grad_norms = [(torch.norm(p_.grad * window_mask) + SMALL_CONST) for
index, p_ in enumerate(past_perturb)] index, p_ in enumerate(past_perturb)]
grad = [ grad = [
...@@ -560,31 +561,31 @@ def generate_text_pplm( ...@@ -560,31 +561,31 @@ def generate_text_pplm(
# Piero modified model call # Piero modified model call
# hidden = model.hidden_states # update hidden # hidden = model.hidden_states # update hidden
# logits = model.forward_hidden(hidden) # logits = model.forward_hidden(hidden)
logits = logits[:, -1, :] / temperature # + SmallConst logits = logits[:, -1, :] / temperature # + SMALL_CONST
# logits = top_k_filter(logits, k=args.top_k) # + SmallConst # logits = top_k_filter(logits, k=args.top_k) # + SMALL_CONST
log_probs = F.softmax(logits, dim=-1) log_probs = F.softmax(logits, dim=-1)
# Fuse the modified model and original model # Fuse the modified model and original model
if perturb: if perturb:
# original_probs = top_k_filter(original_probs[:, -1, :]) #+ SmallConst # original_probs = top_k_filter(original_probs[:, -1, :]) #+ SMALL_CONST
unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1) unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
# likelywords = torch.topk(original_probs, k=10, dim=-1) # likelywords = torch.topk(original_probs, k=10, dim=-1)
# print(TOKENIZER.decode(likelywords[1].tolist()[0])) # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
log_probs = ((log_probs ** gm_scale) * ( log_probs = ((log_probs ** gm_scale) * (
unpert_logits ** (1 - gm_scale))) # + SmallConst unpert_logits ** (1 - gm_scale))) # + SMALL_CONST
log_probs = top_k_filter(log_probs, k=top_k, log_probs = top_k_filter(log_probs, k=top_k,
probs=True) # + SmallConst probs=True) # + SMALL_CONST
if torch.sum(log_probs) <= 1: if torch.sum(log_probs) <= 1:
log_probs = log_probs / torch.sum(log_probs) log_probs = log_probs / torch.sum(log_probs)
else: else:
logits = top_k_filter(logits, k=top_k) # + SmallConst logits = top_k_filter(logits, k=top_k) # + SMALL_CONST
log_probs = F.softmax(logits, dim=-1) log_probs = F.softmax(logits, dim=-1)
if sample: if sample:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment