Commit 893d0d64 authored by w4nderlust, committed by Julien Chaumond

Changed order of some parameters to be more consistent. Identical results.

parent f42816e7
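
Note: every affected call site passes its arguments by keyword, so reordering the parameters in these signatures cannot change behavior, which is why the results are identical. A minimal sketch (hypothetical function names, not code from run_pplm.py) of why keyword-argument call sites are insensitive to parameter declaration order:

# Hypothetical illustration, not part of this commit: two declarations of the
# same parameters in different orders behave identically under keyword calls.
def generate_old(stepsize=0.01, kl_scale=0.01, window_length=0, device="cuda"):
    return stepsize, kl_scale, window_length, device

def generate_new(stepsize=0.01, window_length=0, kl_scale=0.01, device="cuda"):
    # Same parameters as generate_old, declared in a different order.
    return stepsize, kl_scale, window_length, device

# Calls that pass arguments by keyword get the same result from either signature.
assert generate_old(stepsize=0.02, kl_scale=0.1) == generate_new(stepsize=0.02, kl_scale=0.1)
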
@@ -121,17 +121,17 @@ def perturb_past(
         accumulated_hidden=None,
         grad_norms=None,
         stepsize=0.01,
-        one_hot_bows_vectors=None,
         classifier=None,
         class_label=None,
+        one_hot_bows_vectors=None,
         loss_type=0,
         num_iterations=3,
-        kl_scale=0.01,
-        window_length=0,
         horizon_length=1,
+        window_length=0,
         decay=False,
         gamma=1.5,
-        device='cuda'
+        kl_scale=0.01,
+        device='cuda',
 ):
     # Generate inital perturbed past
     grad_accumulator = [
@@ -351,8 +351,7 @@ def get_classifier(
 def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
-        List[
-            List[List[int]]]:
+        List[List[List[int]]]:
     bow_indices = []
     for id_or_path in bag_of_words_ids_or_paths:
         if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
@@ -388,22 +387,22 @@ def full_text_generation(
         context=None,
         num_samples=1,
         device="cuda",
-        sample=False,
+        bag_of_words=None,
         discrim=None,
         class_label=None,
-        bag_of_words=None,
         length=100,
-        grad_length=10000,
         stepsize=0.02,
-        num_iterations=3,
         temperature=1.0,
-        gm_scale=0.9,
-        kl_scale=0.01,
         top_k=10,
-        window_length=0,
+        sample=False,
+        num_iterations=3,
+        grad_length=10000,
         horizon_length=1,
+        window_length=0,
         decay=False,
         gamma=1.5,
+        gm_scale=0.9,
+        kl_scale=0.01,
         **kwargs
 ):
     classifier, class_id = get_classifier(
@@ -454,24 +453,24 @@ def full_text_generation(
             tokenizer=tokenizer,
             context=context,
             device=device,
-            sample=sample,
             perturb=True,
             bow_indices=bow_indices,
             classifier=classifier,
             class_label=class_id,
             loss_type=loss_type,
             length=length,
-            grad_length=grad_length,
             stepsize=stepsize,
-            num_iterations=num_iterations,
             temperature=temperature,
-            gm_scale=gm_scale,
-            kl_scale=kl_scale,
             top_k=top_k,
-            window_length=window_length,
+            sample=sample,
+            num_iterations=num_iterations,
+            grad_length=grad_length,
             horizon_length=horizon_length,
+            window_length=window_length,
             decay=decay,
             gamma=gamma,
+            gm_scale=gm_scale,
+            kl_scale=kl_scale,
         )
         pert_gen_tok_texts.append(pert_gen_tok_text)
         if classifier is not None:
@@ -490,24 +489,24 @@ def generate_text_pplm(
         context=None,
         past=None,
         device="cuda",
-        sample=False,
         perturb=True,
-        bow_indices=None,
         classifier=None,
         class_label=None,
+        bow_indices=None,
         loss_type=0,
         length=100,
-        grad_length=10000,
         stepsize=0.02,
-        num_iterations=3,
         temperature=1.0,
-        gm_scale=0.9,
-        kl_scale=0.01,
         top_k=10,
-        window_length=0,
+        sample=False,
+        num_iterations=3,
+        grad_length=10000,
         horizon_length=1,
+        window_length=0,
         decay=False,
         gamma=1.5,
+        gm_scale=0.9,
+        kl_scale=0.01,
 ):
     output_so_far = (
         torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
@@ -561,17 +560,17 @@ def generate_text_pplm(
                 accumulated_hidden=accumulated_hidden,
                 grad_norms=grad_norms,
                 stepsize=current_stepsize,
-                one_hot_bows_vectors=one_hot_bows_vectors,
                 classifier=classifier,
                 class_label=class_label,
+                one_hot_bows_vectors=one_hot_bows_vectors,
                 loss_type=loss_type,
                 num_iterations=num_iterations,
-                kl_scale=kl_scale,
-                window_length=window_length,
                 horizon_length=horizon_length,
+                window_length=window_length,
                 decay=decay,
                 gamma=gamma,
-                device=device
+                kl_scale=kl_scale,
+                device=device,
             )
             loss_in_time.append(loss_this_iter)
         else:
@@ -685,7 +684,7 @@ def run_pplm_example(
         pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
             "pretrained_model"
         ]
-        print("discrim = {}, setting pretrained_model "
+        print("discrim = {}, pretrained_model set "
               "to discriminator's = {}".format(discrim, pretrained_model))

     # load pretrained model
@@ -810,6 +809,20 @@ if __name__ == '__main__':
         default="gpt2-medium",
         help="pretrained model name or path to local checkpoint",
     )
+    parser.add_argument(
+        "--cond_text", type=str, default="The lake",
+        help="Prefix texts to condition on"
+    )
+    parser.add_argument(
+        "--uncond", action="store_true",
+        help="Generate from end-of-text as prefix"
+    )
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=1,
+        help="Number of samples to generate from the modified latents",
+    )
     parser.add_argument(
         "--bag_of_words",
         "-B",
@@ -837,33 +850,22 @@ if __name__ == '__main__':
         default=-1,
         help="Class label used for the discriminator",
     )
-    parser.add_argument("--stepsize", type=float, default=0.02)
     parser.add_argument("--length", type=int, default=100)
-    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--stepsize", type=float, default=0.02)
     parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument("--top_k", type=int, default=10)
-    parser.add_argument("--gm_scale", type=float, default=0.9)
-    parser.add_argument("--kl_scale", type=float, default=0.01)
-    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
     parser.add_argument(
         "--sample", action="store_true",
         help="Generate from end-of-text as prefix"
     )
-    parser.add_argument(
-        "--uncond", action="store_true",
-        help="Generate from end-of-text as prefix"
-    )
-    parser.add_argument(
-        "--cond_text", type=str, default="The lake",
-        help="Prefix texts to condition on"
-    )
     parser.add_argument("--num_iterations", type=int, default=3)
     parser.add_argument("--grad_length", type=int, default=10000)
     parser.add_argument(
-        "--num_samples",
+        "--window_length",
         type=int,
-        default=1,
-        help="Number of samples to generate from the modified latents",
+        default=0,
+        help="Length of past which is being optimized; "
+             "0 corresponds to infinite window length",
     )
     parser.add_argument(
         "--horizon_length",
@@ -871,16 +873,13 @@ if __name__ == '__main__':
         default=1,
         help="Length of future to optimize over",
     )
-    parser.add_argument(
-        "--window_length",
-        type=int,
-        default=0,
-        help="Length of past which is being optimized; "
-             "0 corresponds to infinite window length",
-    )
     parser.add_argument("--decay", action="store_true",
                         help="whether to decay or not")
     parser.add_argument("--gamma", type=float, default=1.5)
+    parser.add_argument("--gm_scale", type=float, default=0.9)
+    parser.add_argument("--kl_scale", type=float, default=0.01)
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
     parser.add_argument("--colorama", action="store_true",
                         help="colors keywords")
...