Commit 08c6e456 authored by piero's avatar piero Committed by Julien Chaumond
Browse files

Cleaned full_text_generation. Identical output as before.

parent 6c9c1317
...@@ -401,74 +401,6 @@ def full_text_generation( ...@@ -401,74 +401,6 @@ def full_text_generation(
device device
) )
# if args.discrim == 'clickbait':
# classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
# classifier.load_state_dict(torch.load("discrim_models/clickbait_classifierhead.pt"))
# classifier.eval()
# args.label_class = 1 # clickbaity
#
# elif args.discrim == 'sentiment':
# classifier = ClassificationHead(class_size=5, embed_size=1024).to(device)
# #classifier.load_state_dict(torch.load("discrim_models/sentiment_classifierhead.pt"))
# classifier.load_state_dict(torch.load("discrim_models/SST_classifier_head_epoch_16.pt"))
# classifier.eval()
# if args.label_class < 0:
# raise Exception('Wrong class for sentiment, use --label-class 2 for *very positive*, 3 for *very negative*')
# #args.label_class = 2 # very pos
# #args.label_class = 3 # very neg
#
# elif args.discrim == 'toxicity':
# classifier = ClassificationHead(class_size=2, embed_size=1024).to(device)
# classifier.load_state_dict(torch.load("discrim_models/toxicity_classifierhead.pt"))
# classifier.eval()
# args.label_class = 0 # not toxic
#
# elif args.discrim == 'generic':
# if args.discrim_weights is None:
# raise ValueError('When using a generic discriminator, '
# 'discrim_weights need to be specified')
# if args.discrim_meta is None:
# raise ValueError('When using a generic discriminator, '
# 'discrim_meta need to be specified')
#
# with open(args.discrim_meta, 'r') as discrim_meta_file:
# meta = json.load(discrim_meta_file)
#
# classifier = ClassificationHead(
# class_size=meta['class_size'],
# embed_size=meta['embed_size'],
# # todo add tokenizer from meta
# ).to(device)
# classifier.load_state_dict(torch.load(args.discrim_weights))
# classifier.eval()
# if args.label_class == -1:
# args.label_class = meta['default_class']
#
# else:
# classifier = None
# Get tokens for the list of positive words
def list_tokens(word_list):
token_list = [TOKENIZER.encode(word, add_prefix_space=True) for word in
word_list]
# token_list = []
# for word in word_list:
# token_list.append(TOKENIZER.encode(" " + word))
return token_list
# good_index = []
# if args.bag_of_words:
# bags_of_words = args.bag_of_words.split(";")
# for wordlist in bags_of_words:
# with open(wordlist, "r") as f:
# words = f.read().strip()
# words = words.split('\n')
# good_index.append(list_tokens(words))
#
# for good_list in good_index:
# good_list = list(filter(lambda x: len(x) <= 1, good_list))
# actual_words = [(TOKENIZER.decode(ww).strip(),ww) for ww in good_list]
bow_indices = [] bow_indices = []
if bag_of_words: if bag_of_words:
bow_indices = get_bag_of_words_indices(bag_of_words.split(";")) bow_indices = get_bag_of_words_indices(bag_of_words.split(";"))
...@@ -486,9 +418,9 @@ def full_text_generation( ...@@ -486,9 +418,9 @@ def full_text_generation(
print("Using PPLM-Discrim") print("Using PPLM-Discrim")
else: else:
raise Exception("Specify either --bag_of_words (-B) or --discrim (-D)") raise Exception("Specify either a bag of words or a discriminator")
original, _, _ = generate_text_pplm( unpert_gen_tok_text, _, _ = generate_text_pplm(
model=model, model=model,
context=context, context=context,
device=device, device=device,
...@@ -497,12 +429,12 @@ def full_text_generation( ...@@ -497,12 +429,12 @@ def full_text_generation(
) )
torch.cuda.empty_cache() torch.cuda.empty_cache()
perturbed_list = [] pert_gen_tok_texts = []
discrim_loss_list = [] discrim_losses = []
loss_in_time_list = [] losses_in_time = []
for i in range(num_samples): for i in range(num_samples):
perturbed, discrim_loss, loss_in_time = generate_text_pplm( pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
model=model, model=model,
context=context, context=context,
device=device, device=device,
...@@ -525,14 +457,14 @@ def full_text_generation( ...@@ -525,14 +457,14 @@ def full_text_generation(
decay=decay, decay=decay,
gamma=gamma, gamma=gamma,
) )
perturbed_list.append(perturbed) pert_gen_tok_texts.append(pert_gen_tok_text)
if classifier is not None: if classifier is not None:
discrim_loss_list.append(discrim_loss.data.cpu().numpy()) discrim_losses.append(discrim_loss.data.cpu().numpy())
loss_in_time_list.append(loss_in_time) losses_in_time.append(loss_in_time)
torch.cuda.empty_cache() torch.cuda.empty_cache()
return original, perturbed_list, discrim_loss_list, loss_in_time_list return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
def generate_text_pplm( def generate_text_pplm(
...@@ -821,11 +753,14 @@ def run_model(): ...@@ -821,11 +753,14 @@ def run_model():
generated_texts = [] generated_texts = []
bow_words = set() bow_word_ids = set()
bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";")) if args.bag_of_words and args.colorama:
for bow_list in bow_indices: bow_indices = get_bag_of_words_indices(args.bag_of_words.split(";"))
filtered = list(filter(lambda x: len(x) <= 1, bow_list)) for single_bow_list in bow_indices:
bow_words.update(w[0] for w in filtered) # filtering all words in the list composed of more than 1 token
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
# w[0] because we are sure w has only 1 item because previous fitler
bow_word_ids.update(w[0] for w in filtered)
# iterate through the perturbed texts # iterate through the perturbed texts
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts): for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
...@@ -836,7 +771,7 @@ def run_model(): ...@@ -836,7 +771,7 @@ def run_model():
pert_gen_text = '' pert_gen_text = ''
for word_id in pert_gen_tok_text.tolist()[0]: for word_id in pert_gen_tok_text.tolist()[0]:
if word_id in bow_words: if word_id in bow_word_ids:
pert_gen_text += '{}{}{}'.format( pert_gen_text += '{}{}{}'.format(
colorama.Fore.RED, colorama.Fore.RED,
TOKENIZER.decode([word_id]), TOKENIZER.decode([word_id]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment