"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f773faa25871357cc25fb04d65cc0d4bcb1364da"
Commit 893d0d64 authored by w4nderlust, committed by Julien Chaumond

Changed order of some parameters to be more consistent. Identical results.

parent f42816e7
@@ -121,17 +121,17 @@ def perturb_past(
     accumulated_hidden=None,
     grad_norms=None,
     stepsize=0.01,
-    one_hot_bows_vectors=None,
     classifier=None,
     class_label=None,
+    one_hot_bows_vectors=None,
     loss_type=0,
     num_iterations=3,
-    kl_scale=0.01,
-    window_length=0,
     horizon_length=1,
+    window_length=0,
     decay=False,
     gamma=1.5,
-    device='cuda'
+    kl_scale=0.01,
+    device='cuda',
 ):
     # Generate inital perturbed past
     grad_accumulator = [
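Side note (not part of the commit): the reorder is behaviour-preserving because every call site in the script passes these parameters by keyword, so the declaration order never affects the result. A minimal sketch with toy functions, purely to illustrate why keyword calls make the reorder safe:

# Toy illustration only; these are not the PPLM functions.
def before(stepsize=0.01, kl_scale=0.01, window_length=0, device='cuda'):
    return stepsize, kl_scale, window_length, device

def after(stepsize=0.01, window_length=0, kl_scale=0.01, device='cuda'):
    return stepsize, kl_scale, window_length, device

# Keyword calls give identical results regardless of parameter order.
assert before(stepsize=0.03, kl_scale=0.02) == after(stepsize=0.03, kl_scale=0.02)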
@@ -351,8 +351,7 @@ def get_classifier(
 def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
-        List[
-            List[List[int]]]:
+        List[List[List[int]]]:
     bow_indices = []
     for id_or_path in bag_of_words_ids_or_paths:
         if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
@@ -388,22 +387,22 @@ def full_text_generation(
     context=None,
     num_samples=1,
     device="cuda",
-    sample=False,
+    bag_of_words=None,
     discrim=None,
     class_label=None,
-    bag_of_words=None,
     length=100,
-    grad_length=10000,
     stepsize=0.02,
-    num_iterations=3,
     temperature=1.0,
-    gm_scale=0.9,
-    kl_scale=0.01,
     top_k=10,
-    window_length=0,
+    sample=False,
+    num_iterations=3,
+    grad_length=10000,
     horizon_length=1,
+    window_length=0,
     decay=False,
     gamma=1.5,
+    gm_scale=0.9,
+    kl_scale=0.01,
     **kwargs
 ):
     classifier, class_id = get_classifier(
@@ -454,24 +453,24 @@ def full_text_generation(
             tokenizer=tokenizer,
             context=context,
             device=device,
-            sample=sample,
             perturb=True,
             bow_indices=bow_indices,
             classifier=classifier,
             class_label=class_id,
             loss_type=loss_type,
             length=length,
-            grad_length=grad_length,
             stepsize=stepsize,
-            num_iterations=num_iterations,
             temperature=temperature,
-            gm_scale=gm_scale,
-            kl_scale=kl_scale,
             top_k=top_k,
-            window_length=window_length,
+            sample=sample,
+            num_iterations=num_iterations,
+            grad_length=grad_length,
             horizon_length=horizon_length,
+            window_length=window_length,
             decay=decay,
             gamma=gamma,
+            gm_scale=gm_scale,
+            kl_scale=kl_scale,
         )
         pert_gen_tok_texts.append(pert_gen_tok_text)
         if classifier is not None:
@@ -490,24 +489,24 @@ def generate_text_pplm(
     context=None,
     past=None,
     device="cuda",
-    sample=False,
     perturb=True,
-    bow_indices=None,
     classifier=None,
     class_label=None,
+    bow_indices=None,
     loss_type=0,
     length=100,
-    grad_length=10000,
     stepsize=0.02,
-    num_iterations=3,
     temperature=1.0,
-    gm_scale=0.9,
-    kl_scale=0.01,
     top_k=10,
-    window_length=0,
+    sample=False,
+    num_iterations=3,
+    grad_length=10000,
     horizon_length=1,
+    window_length=0,
     decay=False,
     gamma=1.5,
+    gm_scale=0.9,
+    kl_scale=0.01,
 ):
     output_so_far = (
         torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
@@ -561,17 +560,17 @@ def generate_text_pplm(
                     accumulated_hidden=accumulated_hidden,
                     grad_norms=grad_norms,
                     stepsize=current_stepsize,
-                    one_hot_bows_vectors=one_hot_bows_vectors,
                     classifier=classifier,
                     class_label=class_label,
+                    one_hot_bows_vectors=one_hot_bows_vectors,
                     loss_type=loss_type,
                     num_iterations=num_iterations,
-                    kl_scale=kl_scale,
-                    window_length=window_length,
                     horizon_length=horizon_length,
+                    window_length=window_length,
                     decay=decay,
                     gamma=gamma,
-                    device=device
+                    kl_scale=kl_scale,
+                    device=device,
                 )
                 loss_in_time.append(loss_this_iter)
             else:
@@ -685,7 +684,7 @@ def run_pplm_example(
         pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
             "pretrained_model"
         ]
-        print("discrim = {}, setting pretrained_model "
+        print("discrim = {}, pretrained_model set "
              "to discriminator's = {}".format(discrim, pretrained_model))

     # load pretrained model
@@ -810,6 +809,20 @@ if __name__ == '__main__':
         default="gpt2-medium",
         help="pretrained model name or path to local checkpoint",
     )
+    parser.add_argument(
+        "--cond_text", type=str, default="The lake",
+        help="Prefix texts to condition on"
+    )
+    parser.add_argument(
+        "--uncond", action="store_true",
+        help="Generate from end-of-text as prefix"
+    )
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=1,
+        help="Number of samples to generate from the modified latents",
+    )
     parser.add_argument(
         "--bag_of_words",
         "-B",
@@ -837,33 +850,22 @@ if __name__ == '__main__':
         default=-1,
         help="Class label used for the discriminator",
     )
-    parser.add_argument("--stepsize", type=float, default=0.02)
     parser.add_argument("--length", type=int, default=100)
-    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--stepsize", type=float, default=0.02)
     parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument("--top_k", type=int, default=10)
-    parser.add_argument("--gm_scale", type=float, default=0.9)
-    parser.add_argument("--kl_scale", type=float, default=0.01)
-    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
     parser.add_argument(
         "--sample", action="store_true",
         help="Generate from end-of-text as prefix"
     )
-    parser.add_argument(
-        "--uncond", action="store_true",
-        help="Generate from end-of-text as prefix"
-    )
-    parser.add_argument(
-        "--cond_text", type=str, default="The lake",
-        help="Prefix texts to condition on"
-    )
     parser.add_argument("--num_iterations", type=int, default=3)
     parser.add_argument("--grad_length", type=int, default=10000)
     parser.add_argument(
-        "--num_samples",
+        "--window_length",
         type=int,
-        default=1,
-        help="Number of samples to generate from the modified latents",
+        default=0,
+        help="Length of past which is being optimized; "
+        "0 corresponds to infinite window length",
     )
     parser.add_argument(
         "--horizon_length",
@@ -871,16 +873,13 @@ if __name__ == '__main__':
         default=1,
         help="Length of future to optimize over",
     )
-    parser.add_argument(
-        "--window_length",
-        type=int,
-        default=0,
-        help="Length of past which is being optimized; "
-        "0 corresponds to infinite window length",
-    )
     parser.add_argument("--decay", action="store_true",
                         help="whether to decay or not")
     parser.add_argument("--gamma", type=float, default=1.5)
+    parser.add_argument("--gm_scale", type=float, default=0.9)
+    parser.add_argument("--kl_scale", type=float, default=0.01)
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
     parser.add_argument("--colorama", action="store_true",
                         help="colors keywords")