Commit f10b9250 authored by w4nderlust, committed by Julien Chaumond

Improvements: model_path renamed to pretrained_model; tokenizer now loaded from pretrained_model; pretrained_model set to the discriminator's backbone when discrim is specified; sample = False by default, with a --sample CLI flag added. To reproduce the previously sampled outputs, call the CLI with --sample.
parent 75904dae
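
In short, greedy decoding is now the default. A minimal usage sketch of the renamed entry point (the import path, assuming the script is importable as run_pplm, is illustrative and not part of this commit):

    from run_pplm import run_pplm_example

    # sample=False is now the default, so decoding is greedy and
    # deterministic; pass sample=True (or --sample on the CLI) to
    # reproduce the previous sampled behavior.
    run_pplm_example(
        pretrained_model="gpt2-medium",  # was --model_path/-M before this commit
        cond_text="The lake",
        bag_of_words="kitchen",  # a BOW id from BAG_OF_WORDS_ARCHIVE_MAP
        length=50,
        sample=True,
    )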
@@ -43,7 +43,6 @@ PPLM_DISCRIM = 2
 PPLM_BOW_DISCRIM = 3
 SMALL_CONST = 1e-15
 BIG_CONST = 1e10
-TOKENIZER = GPT2Tokenizer.from_pretrained("gpt2-medium")
 
 BAG_OF_WORDS_ARCHIVE_MAP = {
     'kitchen': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/kitchen.txt",
@@ -65,6 +64,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
         "embed_size": 1024,
         "class_vocab": {"non_clickbait": 0, "clickbait": 1},
         "default_class": 1,
+        "pretrained_model": "gpt2-medium",
     },
     "sentiment": {
         "url": "http://s.yosinski.com/SST_classifier_head.pt",
@@ -72,6 +72,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
         "embed_size": 1024,
         "class_vocab": {"very_positive": 2, "very_negative": 3},
         "default_class": 3,
+        "pretrained_model": "gpt2-medium",
     },
     "toxicity": {
         "url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/toxicity_classifierhead.pt",
@@ -79,6 +80,7 @@ DISCRIMINATOR_MODELS_PARAMS = {
         "embed_size": 1024,
         "class_vocab": {"non_toxic": 0, "toxic": 1},
         "default_class": 0,
+        "pretrained_model": "gpt2-medium",
     },
 }
@@ -345,8 +347,9 @@ def get_classifier(
     return classifier, label_id
 
 
-def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
-    List[List[int]]]:
+def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
+        List[
+            List[List[int]]]:
     bow_indices = []
     for id_or_path in bag_of_words_ids_or_paths:
         if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
@@ -356,12 +359,12 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
         with open(filepath, "r") as f:
             words = f.read().strip().split("\n")
         bow_indices.append(
-            [TOKENIZER.encode(word.strip(), add_prefix_space=True) for word in
+            [tokenizer.encode(word.strip(), add_prefix_space=True) for word in
              words])
     return bow_indices
 
 
-def build_bows_one_hot_vectors(bow_indices, device='cuda'):
+def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
     if bow_indices is None:
         return None
@@ -370,7 +373,7 @@ def build_bows_one_hot_vectors(bow_indices, device='cuda'):
         single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
         single_bow = torch.tensor(single_bow).to(device)
         num_words = single_bow.shape[0]
-        one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).to(device)
+        one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
         one_hot_bow.scatter_(1, single_bow, 1)
         one_hot_bows_vectors.append(one_hot_bow)
     return one_hot_bows_vectors
@@ -378,10 +381,11 @@ def build_bows_one_hot_vectors(bow_indices, device='cuda'):
 def full_text_generation(
         model,
+        tokenizer,
         context=None,
         num_samples=1,
         device="cuda",
-        sample=True,
+        sample=False,
         discrim=None,
         class_label=None,
         bag_of_words=None,
@@ -407,7 +411,8 @@ def full_text_generation(
     bow_indices = []
     if bag_of_words:
-        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
+        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
+                                               tokenizer)
 
     if bag_of_words and classifier:
         print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
@@ -426,9 +431,11 @@ def full_text_generation(
     unpert_gen_tok_text, _, _ = generate_text_pplm(
         model=model,
+        tokenizer=tokenizer,
         context=context,
         device=device,
         length=length,
+        sample=sample,
         perturb=False
     )
     if device == 'cuda':
@@ -441,6 +448,7 @@ def full_text_generation(
     for i in range(num_samples):
         pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
             model=model,
+            tokenizer=tokenizer,
             context=context,
             device=device,
             sample=sample,
@@ -475,10 +483,11 @@ def full_text_generation(
 def generate_text_pplm(
         model,
+        tokenizer,
         context=None,
         past=None,
         device="cuda",
-        sample=True,
+        sample=False,
         perturb=True,
         classifier=None,
         class_label=None,
@@ -504,7 +513,8 @@ def generate_text_pplm(
     )
 
     # collect one hot vectors for bags of words
-    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, device)
+    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
+                                                      device)
 
     grad_norms = None
     last = None
@@ -612,7 +622,7 @@ def generate_text_pplm(
             else torch.cat((output_so_far, last), dim=1)
         )
 
-    print(TOKENIZER.decode(output_so_far.tolist()[0]))
+    print(tokenizer.decode(output_so_far.tolist()[0]))
 
     return output_so_far, unpert_discrim_loss, loss_in_time
@@ -631,123 +641,79 @@ def set_generic_model_params(discrim_weights, discrim_meta):
     DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
 
 
-def run_model():
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--model_path",
-        "-M",
-        type=str,
-        default="gpt2-medium",
-        help="pretrained model name or path to local checkpoint",
-    )
-    parser.add_argument(
-        "--bag_of_words",
-        "-B",
-        type=str,
-        default=None,
-        help="Bags of words used for PPLM-BoW. "
-             "Either a BOW id (see list in code) or a filepath. "
-             "Multiple BoWs separated by ;",
-    )
-    parser.add_argument(
-        "--discrim",
-        "-D",
-        type=str,
-        default=None,
-        choices=("clickbait", "sentiment", "toxicity", "generic"),
-        help="Discriminator to use",
-    )
-    parser.add_argument('--discrim_weights', type=str, default=None,
-                        help='Weights for the generic discriminator')
-    parser.add_argument('--discrim_meta', type=str, default=None,
-                        help='Meta information for the generic discriminator')
-    parser.add_argument(
-        "--class_label",
-        type=int,
-        default=-1,
-        help="Class label used for the discriminator",
-    )
-    parser.add_argument("--stepsize", type=float, default=0.02)
-    parser.add_argument("--length", type=int, default=100)
-    parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument("--temperature", type=float, default=1.0)
-    parser.add_argument("--top_k", type=int, default=10)
-    parser.add_argument("--gm_scale", type=float, default=0.9)
-    parser.add_argument("--kl_scale", type=float, default=0.01)
-    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
-    parser.add_argument(
-        "--uncond", action="store_true",
-        help="Generate from end-of-text as prefix"
-    )
-    parser.add_argument(
-        "--cond_text", type=str, default="The lake",
-        help="Prefix texts to condition on"
-    )
-    parser.add_argument("--num_iterations", type=int, default=3)
-    parser.add_argument("--grad_length", type=int, default=10000)
-    parser.add_argument(
-        "--num_samples",
-        type=int,
-        default=1,
-        help="Number of samples to generate from the modified latents",
-    )
-    parser.add_argument(
-        "--horizon_length",
-        type=int,
-        default=1,
-        help="Length of future to optimize over",
-    )
-    parser.add_argument(
-        "--window_length",
-        type=int,
-        default=0,
-        help="Length of past which is being optimized; "
-             "0 corresponds to infinite window length",
-    )
-    parser.add_argument("--decay", action="store_true",
-                        help="whether to decay or not")
-    parser.add_argument("--gamma", type=float, default=1.5)
-    parser.add_argument("--colorama", action="store_true",
-                        help="colors keywords")
-    args = parser.parse_args()
+def run_pplm_example(
+        pretrained_model="gpt2-medium",
+        cond_text="",
+        uncond=False,
+        num_samples=1,
+        bag_of_words=None,
+        discrim=None,
+        discrim_weights=None,
+        discrim_meta=None,
+        class_label=-1,
+        length=100,
+        stepsize=0.02,
+        temperature=1.0,
+        top_k=10,
+        sample=False,
+        num_iterations=3,
+        grad_length=10000,
+        horizon_length=1,
+        window_length=0,
+        decay=False,
+        gamma=1.5,
+        gm_scale=0.9,
+        kl_scale=0.01,
+        seed=0,
+        no_cuda=False,
+        colorama=False
+):
     # set Random seed
-    torch.manual_seed(args.seed)
-    np.random.seed(args.seed)
+    torch.manual_seed(seed)
+    np.random.seed(seed)
 
     # set the device
-    device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
+    device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
 
-    if args.discrim == 'generic':
-        set_generic_model_params(args.discrim_weights, args.discrim_meta)
+    if discrim == 'generic':
+        set_generic_model_params(discrim_weights, discrim_meta)
+
+    if discrim is not None:
+        pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
+            "pretrained_model"
+        ]
+        print("discrim = {}, setting pretrained_model "
+              "to discriminator's = {}".format(discrim, pretrained_model))
 
     # load pretrained model
     model = GPT2LMHeadModel.from_pretrained(
-        args.model_path,
+        pretrained_model,
         output_hidden_states=True
     )
     model.to(device)
     model.eval()
 
+    # load tokenizer
+    tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
+
     # Freeze GPT-2 weights
     for param in model.parameters():
         param.requires_grad = False
 
     # figure out conditioning text
-    if args.uncond:
-        tokenized_cond_text = TOKENIZER.encode(
-            [TOKENIZER.bos_token]
+    if uncond:
+        tokenized_cond_text = tokenizer.encode(
+            [tokenizer.bos_token]
         )
     else:
-        raw_text = args.cond_text
+        raw_text = cond_text
         while not raw_text:
             print("Did you forget to add `--cond_text`? ")
             raw_text = input("Model prompt >>> ")
-        tokenized_cond_text = TOKENIZER.encode(TOKENIZER.bos_token + raw_text)
+        tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)
 
     print("= Prefix of sentence =")
-    print(TOKENIZER.decode(tokenized_cond_text))
+    print(tokenizer.decode(tokenized_cond_text))
     print()
 
     # generate unperturbed and perturbed texts
@@ -755,11 +721,31 @@ def run_model():
     # full_text_generation returns:
     # unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
     unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
-        model=model, context=tokenized_cond_text, device=device, **vars(args)
+        model=model,
+        tokenizer=tokenizer,
+        context=tokenized_cond_text,
+        device=device,
+        num_samples=num_samples,
+        bag_of_words=bag_of_words,
+        discrim=discrim,
+        class_label=class_label,
+        length=length,
+        stepsize=stepsize,
+        temperature=temperature,
+        top_k=top_k,
+        sample=sample,
+        num_iterations=num_iterations,
+        grad_length=grad_length,
+        horizon_length=horizon_length,
+        window_length=window_length,
+        decay=decay,
+        gamma=gamma,
+        gm_scale=gm_scale,
+        kl_scale=kl_scale,
    )
 
     # untokenize unperturbed text
-    unpert_gen_text = TOKENIZER.decode(unpert_gen_tok_text.tolist()[0])
+    unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])
 
     print("=" * 80)
     print("= Unperturbed generated text =")
@@ -769,8 +755,9 @@ def run_model():
     generated_texts = []
 
     bow_word_ids = set()
-    if args.bag_of_words and args.colorama:
-        bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
+    if bag_of_words and colorama:
+        bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
+                                               tokenizer)
         for single_bow_list in bow_indices:
             # filtering all words in the list composed of more than 1 token
             filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
@@ -781,7 +768,7 @@ def run_model():
     for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
         try:
             # untokenize unperturbed text
-            if args.colorama:
+            if colorama:
                 import colorama
 
                 pert_gen_text = ''
@@ -789,13 +776,13 @@ def run_model():
                 if word_id in bow_word_ids:
                     pert_gen_text += '{}{}{}'.format(
                         colorama.Fore.RED,
-                        TOKENIZER.decode([word_id]),
+                        tokenizer.decode([word_id]),
                         colorama.Style.RESET_ALL
                     )
                 else:
-                    pert_gen_text += TOKENIZER.decode([word_id])
+                    pert_gen_text += tokenizer.decode([word_id])
             else:
-                pert_gen_text = TOKENIZER.decode(pert_gen_tok_text.tolist()[0])
+                pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
 
             print("= Perturbed generated text {} =".format(i + 1))
             print(pert_gen_text)
@@ -812,4 +799,87 @@ def run_model():
 
 
 if __name__ == '__main__':
-    run_model()
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--pretrained_model",
+        "-M",
+        type=str,
+        default="gpt2-medium",
+        help="pretrained model name or path to local checkpoint",
+    )
+    parser.add_argument(
+        "--bag_of_words",
+        "-B",
+        type=str,
+        default=None,
+        help="Bags of words used for PPLM-BoW. "
+             "Either a BOW id (see list in code) or a filepath. "
+             "Multiple BoWs separated by ;",
+    )
+    parser.add_argument(
+        "--discrim",
+        "-D",
+        type=str,
+        default=None,
+        choices=("clickbait", "sentiment", "toxicity", "generic"),
+        help="Discriminator to use",
+    )
+    parser.add_argument('--discrim_weights', type=str, default=None,
+                        help='Weights for the generic discriminator')
+    parser.add_argument('--discrim_meta', type=str, default=None,
+                        help='Meta information for the generic discriminator')
+    parser.add_argument(
+        "--class_label",
+        type=int,
+        default=-1,
+        help="Class label used for the discriminator",
+    )
+    parser.add_argument("--stepsize", type=float, default=0.02)
+    parser.add_argument("--length", type=int, default=100)
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--temperature", type=float, default=1.0)
+    parser.add_argument("--top_k", type=int, default=10)
+    parser.add_argument("--gm_scale", type=float, default=0.9)
+    parser.add_argument("--kl_scale", type=float, default=0.01)
+    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
+    parser.add_argument(
+        "--sample", action="store_true",
+        help="Sample from the distribution instead of decoding greedily"
+    )
+    parser.add_argument(
+        "--uncond", action="store_true",
+        help="Generate from end-of-text as prefix"
+    )
+    parser.add_argument(
+        "--cond_text", type=str, default="The lake",
+        help="Prefix texts to condition on"
+    )
+    parser.add_argument("--num_iterations", type=int, default=3)
+    parser.add_argument("--grad_length", type=int, default=10000)
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=1,
+        help="Number of samples to generate from the modified latents",
+    )
+    parser.add_argument(
+        "--horizon_length",
+        type=int,
+        default=1,
+        help="Length of future to optimize over",
+    )
+    parser.add_argument(
+        "--window_length",
+        type=int,
+        default=0,
+        help="Length of past which is being optimized; "
+             "0 corresponds to infinite window length",
+    )
+    parser.add_argument("--decay", action="store_true",
+                        help="whether to decay or not")
+    parser.add_argument("--gamma", type=float, default=1.5)
+    parser.add_argument("--colorama", action="store_true",
+                        help="colors keywords")
+    args = parser.parse_args()
+    run_pplm_example(**vars(args))
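
When a discriminator is selected, the commit overrides pretrained_model with the discriminator's own "pretrained_model" entry from DISCRIMINATOR_MODELS_PARAMS and loads the tokenizer from that same checkpoint, so model and tokenizer always match. A companion sketch under the same import assumption (the script name in the CLI lines is likewise assumed):

    from run_pplm import run_pplm_example

    # Any pretrained_model argument passed here would be overridden:
    # discrim="sentiment" resets it to
    # DISCRIMINATOR_MODELS_PARAMS["sentiment"]["pretrained_model"],
    # i.e. "gpt2-medium", before the model and tokenizer are loaded.
    run_pplm_example(
        discrim="sentiment",
        class_label=2,  # "very_positive" in the sentiment class_vocab
        cond_text="The lake",
        length=50,
    )

    # Assumed CLI equivalent:
    #   python run_pplm.py -D sentiment --class_label 2 \
    #       --cond_text "The lake" --length 50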