Commit afc7dcd9 authored by piero's avatar piero Committed by Julien Chaumond
Browse files

Now run_pplm works on cpu. Identical output as before (when using gpu).

parent 61399e5a
...@@ -84,9 +84,11 @@ DISCRIMINATOR_MODELS_PARAMS = { ...@@ -84,9 +84,11 @@ DISCRIMINATOR_MODELS_PARAMS = {
} }
def to_var(x, requires_grad=False, volatile=False, device='cuda'):
    """Move tensor *x* to the requested device and wrap it in a Variable.

    When ``device`` is ``'cuda'`` the move only happens if CUDA is actually
    available (silent CPU fallback otherwise); any other device string is
    honored via ``.to(device)``.

    :param x: input tensor
    :param requires_grad: forwarded to ``Variable``
    :param volatile: forwarded to ``Variable`` (legacy autograd flag)
    :param device: target device string, default ``'cuda'``
    :return: a ``Variable`` wrapping the (possibly moved) tensor
    """
    wants_cuda = device == 'cuda'
    if wants_cuda and torch.cuda.is_available():
        x = x.cuda()
    elif not wants_cuda:
        x = x.to(device)
    return Variable(x, requires_grad=requires_grad, volatile=volatile)
...@@ -182,7 +184,7 @@ def perturb_past( ...@@ -182,7 +184,7 @@ def perturb_past(
for i in range(num_iterations): for i in range(num_iterations):
print("Iteration ", i + 1) print("Iteration ", i + 1)
curr_perturbation = [ curr_perturbation = [
to_var(torch.from_numpy(p_), requires_grad=True) to_var(torch.from_numpy(p_), requires_grad=True, device=device)
for p_ in grad_accumulator for p_ in grad_accumulator
] ]
...@@ -290,7 +292,7 @@ def perturb_past( ...@@ -290,7 +292,7 @@ def perturb_past(
# apply the accumulated perturbations to the past # apply the accumulated perturbations to the past
grad_accumulator = [ grad_accumulator = [
to_var(torch.from_numpy(p_), requires_grad=True) to_var(torch.from_numpy(p_), requires_grad=True, device=device)
for p_ in grad_accumulator for p_ in grad_accumulator
] ]
pert_past = list(map(add, past, grad_accumulator)) pert_past = list(map(add, past, grad_accumulator))
...@@ -300,7 +302,7 @@ def perturb_past( ...@@ -300,7 +302,7 @@ def perturb_past(
def get_classifier( def get_classifier(
name: Optional[str], label_class: Union[str, int], name: Optional[str], label_class: Union[str, int],
device: Union[str, torch.device] device: str
) -> Tuple[Optional[ClassificationHead], Optional[int]]: ) -> Tuple[Optional[ClassificationHead], Optional[int]]:
if name is None: if name is None:
return None, None return None, None
...@@ -355,16 +357,16 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[ ...@@ -355,16 +357,16 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
return bow_indices return bow_indices
def build_bows_one_hot_vectors(bow_indices, device='cuda'):
    """Turn bag-of-words token-id lists into one-hot matrices on *device*.

    Words that tokenize to more than one token are dropped, since a single
    one-hot row can only represent a single vocabulary entry.

    :param bow_indices: list of bags, each a list of token-id lists, or None
    :param device: device string the resulting tensors are placed on
    :return: list of ``(num_words, vocab_size)`` one-hot tensors, or None
        when ``bow_indices`` is None
    """
    if bow_indices is None:
        return None
    one_hot_vectors = []
    for word_ids in bow_indices:
        # keep only words that map to at most one token id
        kept_ids = [ids for ids in word_ids if len(ids) <= 1]
        id_tensor = torch.tensor(kept_ids).to(device)
        one_hot = torch.zeros(id_tensor.shape[0], TOKENIZER.vocab_size).to(device)
        one_hot.scatter_(1, id_tensor, 1)
        one_hot_vectors.append(one_hot)
    return one_hot_vectors
...@@ -425,7 +427,8 @@ def full_text_generation( ...@@ -425,7 +427,8 @@ def full_text_generation(
length=length, length=length,
perturb=False perturb=False
) )
torch.cuda.empty_cache() if device == 'cuda':
torch.cuda.empty_cache()
pert_gen_tok_texts = [] pert_gen_tok_texts = []
discrim_losses = [] discrim_losses = []
...@@ -460,7 +463,8 @@ def full_text_generation( ...@@ -460,7 +463,8 @@ def full_text_generation(
discrim_losses.append(discrim_loss.data.cpu().numpy()) discrim_losses.append(discrim_loss.data.cpu().numpy())
losses_in_time.append(loss_in_time) losses_in_time.append(loss_in_time)
torch.cuda.empty_cache() if device == 'cuda':
torch.cuda.empty_cache()
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
...@@ -496,7 +500,7 @@ def generate_text_pplm( ...@@ -496,7 +500,7 @@ def generate_text_pplm(
) )
# collect one hot vectors for bags of words # collect one hot vectors for bags of words
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices) one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, device)
grad_norms = None grad_norms = None
last = None last = None
...@@ -563,7 +567,7 @@ def generate_text_pplm( ...@@ -563,7 +567,7 @@ def generate_text_pplm(
if classifier is not None: if classifier is not None:
ce_loss = torch.nn.CrossEntropyLoss() ce_loss = torch.nn.CrossEntropyLoss()
prediction = classifier(torch.mean(unpert_last_hidden, dim=1)) prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
label = torch.tensor([label_class], device='cuda', label = torch.tensor([label_class], device=device,
dtype=torch.long) dtype=torch.long)
unpert_discrim_loss = ce_loss(prediction, label) unpert_discrim_loss = ce_loss(prediction, label)
print( print(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment