Commit 61a12f79 authored by piero, committed by Julien Chaumond

Renamed SmallConst to SMALL_CONST and introduced BIG_CONST. Output is identical to before.

parent ef47b2c0
@@ -43,7 +43,8 @@
 PPLM_BOW = 1
 PPLM_DISCRIM = 2
 PPLM_BOW_DISCRIM = 3
-SmallConst = 1e-15
+SMALL_CONST = 1e-15
+BIG_CONST = 1e10
 TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")
 BAG_OF_WORDS_ARCHIVE_MAP = {
@@ -104,7 +104,8 @@ def top_k_filter(logits, k, probs=False):
     if probs:
         return torch.where(logits < batch_mins,
                            torch.ones_like(logits) * 0.0, logits)
-    return torch.where(logits < batch_mins, torch.ones_like(logits) * -1e10,
+    return torch.where(logits < batch_mins,
+                       torch.ones_like(logits) * -BIG_CONST,
                        logits)
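For context, the -BIG_CONST branch fills pruned logits with a very large negative number so that a subsequent softmax assigns them effectively zero probability, while the probs=True branch zeroes pruned probabilities directly. A minimal standalone sketch of that behavior; only the torch.where calls appear in this hunk, the top-k setup lines are reconstructed from context:

    import torch

    BIG_CONST = 1e10

    def top_k_filter(logits, k, probs=False):
        # Keep each row's k largest entries; mask everything below the
        # k-th largest value.
        values, _ = torch.topk(logits, k)
        batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
        if probs:
            # Input is already a probability distribution: zero pruned entries.
            return torch.where(logits < batch_mins,
                               torch.ones_like(logits) * 0.0, logits)
        # Input is raw logits: -BIG_CONST drives pruned entries to ~0
        # probability after softmax.
        return torch.where(logits < batch_mins,
                           torch.ones_like(logits) * -BIG_CONST, logits)

    x = torch.tensor([[2.0, 1.0, 0.5, -0.3]])
    print(top_k_filter(x, k=2))  # the two smallest entries become -1e10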
@@ -137,7 +138,7 @@ def perturb_past(
     accumulated_hidden = 0
     if decay:
-        decay_mask = torch.arange(0., 1.0 + SmallConst, 1.0 / (window_length))[
+        decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / (window_length))[
             1:]
     else:
         decay_mask = 1.0
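The SMALL_CONST nudge matters here because torch.arange excludes its end point: without it, floating-point rounding can drop the final 1.0 and leave the mask one element short. A quick sketch under that reading:

    import torch

    SMALL_CONST = 1e-15
    window_length = 5

    # arange yields values in [start, end); 1.0 + SMALL_CONST keeps 1.0
    # inside the range, and [1:] drops the leading 0.0, leaving exactly
    # window_length weights rising linearly to 1.0.
    decay_mask = torch.arange(0., 1.0 + SMALL_CONST, 1.0 / window_length)[1:]
    print(decay_mask)  # tensor([0.2000, 0.4000, 0.6000, 0.8000, 1.0000])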
@@ -233,9 +234,9 @@ def perturb_past(
         kl_loss = 0.0
         if kl_scale > 0.0:
             p = (F.softmax(unpert_logits[:, -1, :], dim=-1))
-            p = p + SmallConst * (p <= SmallConst).type(
+            p = p + SMALL_CONST * (p <= SMALL_CONST).type(
                 torch.FloatTensor).cuda().detach()
-            correction = SmallConst * (probabs <= SmallConst).type(
+            correction = SMALL_CONST * (probabs <= SMALL_CONST).type(
                 torch.FloatTensor).cuda().detach()
             corrected_probabs = probabs + correction.detach()
             kl_loss = kl_scale * (
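Both SMALL_CONST corrections guard the KL term against log(0): the epsilon is added only where a probability is numerically zero, leaving the rest of the distribution untouched. A CPU-runnable sketch with toy distributions; it uses .float() in place of the .type(torch.FloatTensor).cuda() calls, and since the kl_loss expression is cut off at the end of this hunk, its form below is a reconstruction:

    import torch

    SMALL_CONST = 1e-15
    kl_scale = 0.01

    # Toy stand-ins for the unperturbed (p) and perturbed (probabs)
    # next-token distributions; note each contains an exact zero.
    p = torch.tensor([[0.6, 0.4, 0.0]])
    probabs = torch.tensor([[0.7, 0.3, 0.0]])

    # Add SMALL_CONST only where entries are <= SMALL_CONST, so log()
    # stays finite everywhere.
    p = p + SMALL_CONST * (p <= SMALL_CONST).float().detach()
    correction = SMALL_CONST * (probabs <= SMALL_CONST).float().detach()
    corrected_probabs = probabs + correction.detach()

    # KL(corrected_probabs || p), scaled by kl_scale.
    kl_loss = kl_scale * (
        (corrected_probabs * (corrected_probabs / p).log()).sum())
    print(kl_loss)  # a small finite value instead of nan/inf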
@@ -254,7 +255,7 @@ def perturb_past(
                           for index, p_ in
                           enumerate(past_perturb)]
         else:
-            grad_norms = [(torch.norm(p_.grad * window_mask) + SmallConst) for
+            grad_norms = [(torch.norm(p_.grad * window_mask) + SMALL_CONST) for
                           index, p_ in enumerate(past_perturb)]
         grad = [
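Adding SMALL_CONST to each norm keeps the subsequent gradient normalization (dividing by grad_norms) safe even when a gradient is all zeros. A toy sketch; past_perturb, window_mask, and the loss are made up for illustration:

    import torch

    SMALL_CONST = 1e-15
    window_mask = torch.ones(2, 3)

    # Two toy tensors standing in for past_perturb; the second gets a zero
    # gradient, the degenerate case the epsilon protects against.
    past_perturb = [torch.zeros(2, 3, requires_grad=True) for _ in range(2)]
    loss = (past_perturb[0] * 2.0).sum() + (past_perturb[1] * 0.0).sum()
    loss.backward()

    grad_norms = [(torch.norm(p_.grad * window_mask) + SMALL_CONST) for
                  index, p_ in enumerate(past_perturb)]
    print(grad_norms)  # second norm is SMALL_CONST, not 0, so dividing is safe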
@@ -560,31 +561,31 @@ def generate_text_pplm(
         # Piero modified model call
         # hidden = model.hidden_states  # update hidden
         # logits = model.forward_hidden(hidden)
-        logits = logits[:, -1, :] / temperature  # + SmallConst
-        # logits = top_k_filter(logits, k=args.top_k)  # + SmallConst
+        logits = logits[:, -1, :] / temperature  # + SMALL_CONST
+        # logits = top_k_filter(logits, k=args.top_k)  # + SMALL_CONST
         log_probs = F.softmax(logits, dim=-1)
         # Fuse the modified model and original model
         if perturb:
-            # original_probs = top_k_filter(original_probs[:, -1, :])  # + SmallConst
+            # original_probs = top_k_filter(original_probs[:, -1, :])  # + SMALL_CONST
             unpert_logits = F.softmax(unpert_logits[:, -1, :], dim=-1)
             # likelywords = torch.topk(original_probs, k=10, dim=-1)
             # print(TOKENIZER.decode(likelywords[1].tolist()[0]))
             log_probs = ((log_probs ** gm_scale) * (
-                unpert_logits ** (1 - gm_scale)))  # + SmallConst
+                unpert_logits ** (1 - gm_scale)))  # + SMALL_CONST
             log_probs = top_k_filter(log_probs, k=top_k,
-                                     probs=True)  # + SmallConst
+                                     probs=True)  # + SMALL_CONST
             if torch.sum(log_probs) <= 1:
                 log_probs = log_probs / torch.sum(log_probs)
         else:
-            logits = top_k_filter(logits, k=top_k)  # + SmallConst
+            logits = top_k_filter(logits, k=top_k)  # + SMALL_CONST
             log_probs = F.softmax(logits, dim=-1)
         if sample:
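In the perturb branch, the perturbed and unperturbed next-token distributions are fused by a weighted geometric mean (gm_scale near 1 trusts the perturbed model), filtered in probability space, and renormalized, since both the fusion and the probs=True filtering can leave total mass below 1. A toy sketch of the fusion step; note that in the diff the variable unpert_logits actually holds probabilities after the softmax applied to it:

    import torch
    import torch.nn.functional as F

    gm_scale = 0.9

    # Toy next-token distributions for the perturbed and unperturbed model.
    pert_probs = F.softmax(torch.tensor([[2.0, 1.0, 0.1]]), dim=-1)
    unpert_probs = F.softmax(torch.tensor([[1.5, 1.2, 0.2]]), dim=-1)

    # Weighted geometric mean of the two distributions.
    fused = (pert_probs ** gm_scale) * (unpert_probs ** (1 - gm_scale))

    # The fused vector sums to <= 1, so renormalize it into a distribution.
    if torch.sum(fused) <= 1:
        fused = fused / torch.sum(fused)
    print(fused, fused.sum())  # sums to 1.0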