Commit 7fd54b55 authored by piero, committed by Julien Chaumond

Added support for generic discriminators

parent b0eaff36
@@ -14,17 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# TODO: add code for training a custom discriminator
 """
 Example command with bag of words:
 python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95

 Example command with discriminator:
-python examples/run_pplm.py -D sentiment --label_class 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
+python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
 """

 import argparse
+import json
 from operator import add
 from typing import List, Optional, Tuple, Union
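
Usage note (not part of the diff): with this commit a generic discriminator can also be selected from the command line. The command below is a sketch that mirrors the two docstring examples above; the weights path, meta path, and class label value are placeholders, and only the flags themselves come from the code added further down.

python examples/run_pplm.py -D generic --discrim_weights path/to/generic_classifier_head.pt --discrim_meta path/to/generic_classifier_meta.json --class_label 0 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95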
@@ -121,7 +120,7 @@ def perturb_past(
         grad_norms=None,
         stepsize=0.01,
         classifier=None,
-        label_class=None,
+        class_label=None,
         one_hot_bows_vectors=None,
         loss_type=0,
         num_iterations=3,
@@ -230,7 +229,7 @@ def perturb_past(
             prediction = classifier(new_accumulated_hidden /
                                     (curr_length + 1 + horizon_length))
-            label = torch.tensor([label_class], device=device,
+            label = torch.tensor([class_label], device=device,
                                  dtype=torch.long)
             discrim_loss = ce_loss(prediction, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
@@ -244,7 +243,8 @@ def perturb_past(
                 unpert_probs + SMALL_CONST *
                 (unpert_probs <= SMALL_CONST).float().to(device).detach()
             )
-            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
+            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
+                device).detach()
             corrected_probs = probs + correction.detach()
             kl_loss = kl_scale * (
                 (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
@@ -273,7 +273,8 @@ def perturb_past(
         # normalize gradients
         grad = [
             -stepsize *
-            (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
+            (p_.grad * window_mask / grad_norms[
+                index] ** gamma).data.cpu().numpy()
             for index, p_ in enumerate(curr_perturbation)
         ]
@@ -301,7 +302,7 @@ def perturb_past(
 def get_classifier(
-        name: Optional[str], label_class: Union[str, int],
+        name: Optional[str], class_label: Union[str, int],
         device: str
 ) -> Tuple[Optional[ClassificationHead], Optional[int]]:
     if name is None:
@@ -312,26 +313,29 @@ def get_classifier(
         class_size=params['class_size'],
         embed_size=params['embed_size']
     ).to(device)
+    if "url" in params:
         resolved_archive_file = cached_path(params["url"])
+    else:
+        resolved_archive_file = params["path"]
     classifier.load_state_dict(
         torch.load(resolved_archive_file, map_location=device))
     classifier.eval()
-    if isinstance(label_class, str):
-        if label_class in params["class_vocab"]:
-            label_id = params["class_vocab"][label_class]
+    if isinstance(class_label, str):
+        if class_label in params["class_vocab"]:
+            label_id = params["class_vocab"][class_label]
         else:
             label_id = params["default_class"]
-            print("label_class {} not in class_vocab".format(label_class))
+            print("class_label {} not in class_vocab".format(class_label))
             print("available values are: {}".format(params["class_vocab"]))
             print("using default class {}".format(label_id))
-    elif isinstance(label_class, int):
-        if label_class in set(params["class_vocab"].values()):
-            label_id = label_class
+    elif isinstance(class_label, int):
+        if class_label in set(params["class_vocab"].values()):
+            label_id = class_label
         else:
             label_id = params["default_class"]
-            print("label_class {} not in class_vocab".format(label_class))
+            print("class_label {} not in class_vocab".format(class_label))
             print("available values are: {}".format(params["class_vocab"]))
             print("using default class {}".format(label_id))
@@ -379,7 +383,7 @@ def full_text_generation(
         device="cuda",
         sample=True,
         discrim=None,
-        label_class=None,
+        class_label=None,
         bag_of_words=None,
         length=100,
         grad_length=10000,
@@ -397,7 +401,7 @@ def full_text_generation(
 ):
     classifier, class_id = get_classifier(
         discrim,
-        label_class,
+        class_label,
         device
     )
@@ -443,7 +447,7 @@ def full_text_generation(
             perturb=True,
             bow_indices=bow_indices,
             classifier=classifier,
-            label_class=class_id,
+            class_label=class_id,
             loss_type=loss_type,
             length=length,
             grad_length=grad_length,
@@ -477,7 +481,7 @@ def generate_text_pplm(
         sample=True,
         perturb=True,
         classifier=None,
-        label_class=None,
+        class_label=None,
         bow_indices=None,
         loss_type=0,
         length=100,
@@ -545,7 +549,7 @@ def generate_text_pplm(
                 grad_norms=grad_norms,
                 stepsize=current_stepsize,
                 classifier=classifier,
-                label_class=label_class,
+                class_label=class_label,
                 one_hot_bows_vectors=one_hot_bows_vectors,
                 loss_type=loss_type,
                 num_iterations=num_iterations,
@@ -567,7 +571,7 @@ def generate_text_pplm(
     if classifier is not None:
         ce_loss = torch.nn.CrossEntropyLoss()
         prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
-        label = torch.tensor([label_class], device=device,
+        label = torch.tensor([class_label], device=device,
                              dtype=torch.long)
         unpert_discrim_loss = ce_loss(prediction, label)
         print(
@@ -613,6 +617,20 @@ def generate_text_pplm(
     return output_so_far, unpert_discrim_loss, loss_in_time


+def set_generic_model_params(discrim_weights, discrim_meta):
+    if discrim_weights is None:
+        raise ValueError('When using a generic discriminator, '
+                         'discrim_weights need to be specified')
+    if discrim_meta is None:
+        raise ValueError('When using a generic discriminator, '
+                         'discrim_meta need to be specified')
+
+    with open(discrim_meta, 'r') as discrim_meta_file:
+        meta = json.load(discrim_meta_file)
+    meta['path'] = discrim_weights
+    DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
+
+
 def run_model():
     parser = argparse.ArgumentParser()
     parser.add_argument(
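
For illustration (not part of the commit): set_generic_model_params above expects --discrim_meta to point at a JSON file whose keys match what get_classifier reads from DISCRIMINATOR_MODELS_PARAMS ('class_size', 'embed_size', 'class_vocab', 'default_class'; 'path' is filled in from --discrim_weights). A minimal sketch of such a file, with assumed values, could be generated like this:

    # Hypothetical sketch: write a meta file for the 'generic' discriminator.
    # All values below are assumptions for illustration, not shipped defaults.
    import json

    example_meta = {
        "class_size": 2,                                  # number of classifier output classes (assumed)
        "embed_size": 1024,                               # hidden size of the underlying LM (assumed)
        "class_vocab": {"negative": 0, "positive": 1},    # label name -> class id mapping (assumed)
        "default_class": 0,                               # fallback class id for unknown labels
    }

    with open("example_discrim_meta.json", "w") as f:
        json.dump(example_meta, f, indent=2)

With such a file and a matching weights checkpoint, passing -D generic --discrim_meta example_discrim_meta.json --discrim_weights <checkpoint> would register the entry under DISCRIMINATOR_MODELS_PARAMS['generic'].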
@@ -636,11 +654,15 @@ def run_model():
         "-D",
         type=str,
         default=None,
-        choices=("clickbait", "sentiment", "toxicity"),
-        help="Discriminator to use for loss-type 2",
+        choices=("clickbait", "sentiment", "toxicity", "generic"),
+        help="Discriminator to use",
     )
+    parser.add_argument('--discrim_weights', type=str, default=None,
+                        help='Weights for the generic discriminator')
+    parser.add_argument('--discrim_meta', type=str, default=None,
+                        help='Meta information for the generic discriminator')
     parser.add_argument(
-        "--label_class",
+        "--class_label",
         type=int,
         default=-1,
         help="Class label used for the discriminator",
@@ -697,6 +719,9 @@ def run_model():
     # set the device
     device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"

+    if args.discrim == 'generic':
+        set_generic_model_params(args.discrim_weights, args.discrim_meta)
+
     # load pretrained model
     model = GPT2LMHeadModel.from_pretrained(
         args.model_path,
......