Commit f10b9250 authored by w4nderlust, committed by Julien Chaumond

Improvements: model_path renamed to pretrained_model; tokenizer loaded from pretrained_model; pretrained_model set to the discriminator's when discrim is specified; sample = False by default, with a CLI parameter introduced. To obtain identical samples, call the CLI with --sample.
parent 75904dae
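
For orientation, a sketch of the resulting CLI behavior (the script path examples/run_pplm.py is an assumption, not stated on this page; "kitchen" is one of the BOW ids defined in the diff below):

    # Greedy decoding is now the default:
    python examples/run_pplm.py -B kitchen --cond_text "The lake" --length 50
    # Pass --sample to reproduce the previous sampling behavior:
    python examples/run_pplm.py -B kitchen --cond_text "The lake" --length 50 --sample
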
@@ -43,7 +43,6 @@ PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
BIG_CONST = 1e10
-TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")
BAG_OF_WORDS_ARCHIVE_MAP = {
'kitchen': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/kitchen.txt",
@@ -65,6 +64,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
"embed_size": 1024,
"class_vocab": {"non_clickbait": 0, "clickbait": 1},
"default_class": 1,
"pretrained_model": "gpt2-medium",
},
"sentiment": {
"url": "http://s.yosinski.com/SST_classifier_head.pt",
@@ -72,6 +72,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
"embed_size": 1024,
"class_vocab": {"very_positive": 2, "very_negative": 3},
"default_class": 3,
"pretrained_model": "gpt2-medium",
},
"toxicity": {
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/toxicity_classifierhead.pt",
@@ -79,6 +80,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
"embed_size": 1024,
"class_vocab": {"non_toxic": 0, "toxic": 1},
"default_class": 0,
"pretrained_model": "gpt2-medium",
},
}
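
Each discriminator entry now records the GPT-2 variant it was trained against, and run_pplm_example (below) overrides the user-selected model with it. A one-line sketch of the lookup this enables:

    # e.g. resolving the backbone for the sentiment discriminator:
    pretrained_model = DISCRIMINATOR_MODELS_PARAMS["sentiment"]["pretrained_model"]  # -> "gpt2-medium"
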
@@ -345,7 +347,8 @@ def get_classifier(
return classifier, label_id
-def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
+def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
+List[
List[List[int]]]:
bow_indices = []
for id_or_path in bag_of_words_ids_or_paths:
@@ -356,12 +359,12 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
with open(filepath, "r") as f:
words = f.read().strip().split("\n")
bow_indices.append(
-[TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in
+[tokenizer.encode(word.strip(), add_prefix_space=True) for word in
words])
return bow_indices
-def build_bows_one_hot_vectors(bow_indices, device='cuda'):
+def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
if bow_indices is None:
return None
@@ -370,7 +373,7 @@ def build_bows_one_hot_vectors(bow_indices, device='cuda'):
single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
single_bow = torch.tensor(single_bow).to(device)
num_words = single_bow.shape[0]
-one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).to(device)
+one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
one_hot_bow.scatter_(1, single_bow, 1)
one_hot_bows_vectors.append(one_hot_bow)
return one_hot_bows_vectors
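
With the module-level TOKENIZER gone, both helpers now take the tokenizer explicitly; a minimal usage sketch (the transformers import path is an assumption for this revision of the repo):

    from transformers import GPT2Tokenizer  # import path assumed; matches the GPT2Tokenizer used above

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
    bow_indices = get_bag_of_words_indices(["kitchen"], tokenizer)   # per-BOW lists of token ids
    one_hot_bows = build_bows_one_hot_vectors(bow_indices, tokenizer,
                                              device="cpu")          # one (num_words, vocab_size) matrix per BOW
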
@@ -378,10 +381,11 @@ def build_bows_one_hot_vectors(bow_indices, device='cuda'):
def full_text_generation(
model,
+tokenizer,
context=None,
num_samples=1,
device="cuda",
-sample=True,
+sample=False,
discrim=None,
class_label=None,
bag_of_words=None,
@@ -407,7 +411,8 @@ def full_text_generation(
bow_indices = []
if bag_of_words:
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
tokenizer)
if bag_of_words and classifier:
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
@@ -426,9 +431,11 @@ def full_text_generation(
unpert_gen_tok_text, _, _ = generate_text_pplm(
model=model,
+tokenizer=tokenizer,
context=context,
device=device,
length=length,
+sample=sample,
perturb=False
)
if device == 'cuda':
@@ -441,6 +448,7 @@ def full_text_generation(
for i in range(num_samples):
pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
model=model,
+tokenizer=tokenizer,
context=context,
device=device,
sample=sample,
@@ -475,10 +483,11 @@ def full_text_generation(
def generate_text_pplm(
model,
+tokenizer,
context=None,
past=None,
device="cuda",
-sample=True,
+sample=False,
perturb=True,
classifier=None,
class_label=None,
@@ -504,7 +513,8 @@ def generate_text_pplm(
)
# collect one hot vectors for bags of words
-one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, device)
+one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
+device)
grad_norms = None
last = None
@@ -612,7 +622,7 @@ def generate_text_pplm(
else torch.cat((output_so_far, last), dim=1)
)
-print(TOKENIZER.decode(output_so_far.tolist()[0]))
+print(tokenizer.decode(output_so_far.tolist()[0]))
return output_so_far, unpert_discrim_loss, loss_in_time
@@ -631,123 +641,79 @@ def set_generic_model_params(discrim_weights, discrim_meta):
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
-def run_model():
-parser = argparse.ArgumentParser()
-parser.add_argument(
-"--model_path",
-"-M",
-type=str,
-default="gpt2-medium",
-help="pretrained model name or path to local checkpoint",
-)
-parser.add_argument(
-"--bag_of_words",
-"-B",
-type=str,
-default=None,
-help="Bags of words used for PPLM-BoW. "
-"Either a BOW id (see list in code) or a filepath. "
-"Multiple BoWs separated by ;",
-)
-parser.add_argument(
-"--discrim",
-"-D",
-type=str,
-default=None,
-choices=("clickbait", "sentiment", "toxicity", "generic"),
-help="Discriminator to use",
-)
-parser.add_argument('--discrim_weights', type=str, default=None,
-help='Weights for the generic discriminator')
-parser.add_argument('--discrim_meta', type=str, default=None,
-help='Meta information for the generic discriminator')
-parser.add_argument(
-"--class_label",
-type=int,
-default=-1,
-help="Class label used for the discriminator",
-)
-parser.add_argument("--stepsize", type=float, default=0.02)
-parser.add_argument("--length", type=int, default=100)
-parser.add_argument("--seed", type=int, default=0)
-parser.add_argument("--temperature", type=float, default=1.0)
-parser.add_argument("--top_k", type=int, default=10)
-parser.add_argument("--gm_scale", type=float, default=0.9)
-parser.add_argument("--kl_scale", type=float, default=0.01)
-parser.add_argument("--no_cuda", action="store_true", help="no cuda")
-parser.add_argument(
-"--uncond", action="store_true",
-help="Generate from end-of-text as prefix"
-)
-parser.add_argument(
-"--cond_text", type=str, default="The lake",
-help="Prefix texts to condition on"
-)
-parser.add_argument("--num_iterations", type=int, default=3)
-parser.add_argument("--grad_length", type=int, default=10000)
-parser.add_argument(
-"--num_samples",
-type=int,
-default=1,
-help="Number of samples to generate from the modified latents",
-)
-parser.add_argument(
-"--horizon_length",
-type=int,
-default=1,
-help="Length of future to optimize over",
-)
-parser.add_argument(
-"--window_length",
-type=int,
-default=0,
-help="Length of past which is being optimized; "
-"0 corresponds to infinite window length",
-)
-parser.add_argument("--decay", action="store_true",
-help="whether to decay or not")
-parser.add_argument("--gamma", type=float, default=1.5)
-parser.add_argument("--colorama", action="store_true",
-help="colors keywords")
-args = parser.parse_args()
+def run_pplm_example(
+pretrained_model="gpt2-medium",
+cond_text="",
+uncond=False,
+num_samples=1,
+bag_of_words=None,
+discrim=None,
+discrim_weights=None,
+discrim_meta=None,
+class_label=-1,
+length=100,
+stepsize=0.02,
+temperature=1.0,
+top_k=10,
+sample=False,
+num_iterations=3,
+grad_length=10000,
+horizon_length=1,
+window_length=0,
+decay=False,
+gamma=1.5,
+gm_scale=0.9,
+kl_scale=0.01,
+seed=0,
+no_cuda=False,
+colorama=False
+):
# set Random seed
-torch.manual_seed(args.seed)
-np.random.seed(args.seed)
+torch.manual_seed(seed)
+np.random.seed(seed)
# set the device
device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
-if args.discrim == 'generic':
-set_generic_model_params(args.discrim_weights, args.discrim_meta)
+if discrim == 'generic':
+set_generic_model_params(discrim_weights, discrim_meta)
+if discrim is not None:
+pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
+"pretrained_model"
+]
+print("discrim = {}, setting pretrained_model "
+"to discriminator's = {}".format(discrim, pretrained_model))
# load pretrained model
model = GPT2LMHeadModel.from_pretrained(
-args.model_path,
+pretrained_model,
output_hidden_states=True
)
model.to(device)
model.eval()
+# load tokenizer
+tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
# Freeze GPT-2 weights
for param in model.parameters():
param.requires_grad = False
# figure out conditioning text
-if args.uncond:
-tokenized_cond_text = TOKENIZER.encode(
-[TOKENIZER.bos_token]
+if uncond:
+tokenized_cond_text = tokenizer.encode(
+[tokenizer.bos_token]
)
else:
-raw_text = args.cond_text
+raw_text = cond_text
while not raw_text:
print("Did you forget to add `--cond_text`? ")
raw_text = input("Model prompt >>> ")
-tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text)
+tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)
print("= Prefix of sentence =")
-print(TOKENIZER.decode(tokenized_cond_text))
+print(tokenizer.decode(tokenized_cond_text))
print()
# generate unperturbed and perturbed texts
@@ -755,11 +721,31 @@ def run_model():
# full_text_generation returns:
# unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
-model=model, context=tokenized_cond_text, device=device, **vars(args)
+model=model,
+tokenizer=tokenizer,
+context=tokenized_cond_text,
+device=device,
+num_samples=num_samples,
+bag_of_words=bag_of_words,
+discrim=discrim,
+class_label=class_label,
+length=length,
+stepsize=stepsize,
+temperature=temperature,
+top_k=top_k,
+sample=sample,
+num_iterations=num_iterations,
+grad_length=grad_length,
+horizon_length=horizon_length,
+window_length=window_length,
+decay=decay,
+gamma=gamma,
+gm_scale=gm_scale,
+kl_scale=kl_scale,
)
# untokenize unperturbed text
-unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0])
+unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])
print("=" * 80)
print("= Unperturbed generated text =")
@@ -769,8 +755,9 @@ def run_model():
generated_texts = []
bow_word_ids = set()
-if args.bag_of_words and args.colorama:
-bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
+if bag_of_words and colorama:
+bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
+tokenizer)
for single_bow_list in bow_indices:
# filtering all words in the list composed of more than 1 token
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
@@ -781,7 +768,7 @@ def run_model():
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
try:
# untokenize unperturbed text
-if args.colorama:
+if colorama:
import colorama
pert_gen_text = ''
@@ -789,13 +776,13 @@ def run_model():
if word_id in bow_word_ids:
pert_gen_text += '{}{}{}'.format(
colorama.Fore.RED,
-TOKENIZER.decode([word_id]),
+tokenizer.decode([word_id]),
colorama.Style.RESET_ALL
)
else:
-pert_gen_text += TOKENIZER.decode([word_id])
+pert_gen_text += tokenizer.decode([word_id])
else:
-pert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0])
+pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
print("= Perturbed generated text {} =".format(i + 1))
print(pert_gen_text)
@@ -812,4 +799,87 @@ def run_model():
if __name__ == '__main__':
-run_model()
+parser = argparse.ArgumentParser()
+parser.add_argument(
+"--pretrained_model",
+"-M",
+type=str,
+default="gpt2-medium",
+help="pretrained model name or path to local checkpoint",
+)
+parser.add_argument(
+"--bag_of_words",
+"-B",
+type=str,
+default=None,
+help="Bags of words used for PPLM-BoW. "
+"Either a BOW id (see list in code) or a filepath. "
+"Multiple BoWs separated by ;",
+)
+parser.add_argument(
+"--discrim",
+"-D",
+type=str,
+default=None,
+choices=("clickbait", "sentiment", "toxicity", "generic"),
+help="Discriminator to use",
+)
+parser.add_argument('--discrim_weights', type=str, default=None,
+help='Weights for the generic discriminator')
+parser.add_argument('--discrim_meta', type=str, default=None,
+help='Meta information for the generic discriminator')
+parser.add_argument(
+"--class_label",
+type=int,
+default=-1,
+help="Class label used for the discriminator",
+)
+parser.add_argument("--stepsize", type=float, default=0.02)
+parser.add_argument("--length", type=int, default=100)
+parser.add_argument("--seed", type=int, default=0)
+parser.add_argument("--temperature", type=float, default=1.0)
+parser.add_argument("--top_k", type=int, default=10)
+parser.add_argument("--gm_scale", type=float, default=0.9)
+parser.add_argument("--kl_scale", type=float, default=0.01)
+parser.add_argument("--no_cuda", action="store_true", help="no cuda")
+parser.add_argument(
+"--sample", action="store_true",
+help="Sample from the distribution instead of decoding greedily"
+)
+parser.add_argument(
+"--uncond", action="store_true",
+help="Generate from end-of-text as prefix"
+)
+parser.add_argument(
+"--cond_text", type=str, default="The lake",
+help="Prefix texts to condition on"
+)
+parser.add_argument("--num_iterations", type=int, default=3)
+parser.add_argument("--grad_length", type=int, default=10000)
+parser.add_argument(
+"--num_samples",
+type=int,
+default=1,
+help="Number of samples to generate from the modified latents",
+)
+parser.add_argument(
+"--horizon_length",
+type=int,
+default=1,
+help="Length of future to optimize over",
+)
+parser.add_argument(
+"--window_length",
+type=int,
+default=0,
+help="Length of past which is being optimized; "
+"0 corresponds to infinite window length",
+)
+parser.add_argument("--decay", action="store_true",
+help="whether to decay or not")
+parser.add_argument("--gamma", type=float, default=1.5)
+parser.add_argument("--colorama", action="store_true",
+help="colors keywords")
+args = parser.parse_args()
+run_pplm_example(**vars(args))
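
Because run_model() is replaced by the importable run_pplm_example(), the example can now also be driven from Python rather than the CLI; a minimal sketch (the module name run_pplm is an assumption, not confirmed by this page):

    # Sketch: calling the refactored entry point directly.
    from run_pplm import run_pplm_example  # module name assumed

    run_pplm_example(
        cond_text="The lake",    # the CLI's default prefix
        bag_of_words="kitchen",  # a BOW id from BAG_OF_WORDS_ARCHIVE_MAP
        length=50,
        sample=True,             # previous default behavior; omit for greedy decoding
        no_cuda=True,            # run on CPU
    )
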