"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "a5d58187a2eef6616b6527d2473570c43b6547c9"
Commit afc7dcd9 authored by piero, committed by Julien Chaumond

Now run_pplm works on CPU. Identical output as before (when using GPU).
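
The hunks below thread a `device` string through the helpers so the same code paths run on CPU. As a minimal sketch of how such a string is typically chosen at the top of a script (the --no_cuda flag here is illustrative, not shown in this diff):

import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--no_cuda", action="store_true",
                    help="run on CPU even if CUDA is available")
args = parser.parse_args()

# 'cuda' only when a GPU is present and not explicitly disabled
device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"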

parent 61399e5a
@@ -84,9 +84,11 @@ DISCRIMINATOR_MODELS_PARAMS = {
 }
 
-def to_var(x, requires_grad=False, volatile=False):
-    if torch.cuda.is_available():
+def to_var(x, requires_grad=False, volatile=False, device='cuda'):
+    if torch.cuda.is_available() and device == 'cuda':
         x = x.cuda()
+    elif device != 'cuda':
+        x = x.to(device)
     return Variable(x, requires_grad=requires_grad, volatile=volatile)
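
For illustration, a usage sketch of the updated helper, assuming the imports run_pplm.py already has (numpy, torch, and Variable from torch.autograd); the default device='cuda' keeps the old behavior, while any other value routes through .to(device):

import numpy as np
import torch
from torch.autograd import Variable

def to_var(x, requires_grad=False, volatile=False, device='cuda'):
    if torch.cuda.is_available() and device == 'cuda':
        x = x.cuda()
    elif device != 'cuda':
        x = x.to(device)
    return Variable(x, requires_grad=requires_grad, volatile=volatile)

p_ = np.zeros((1, 4), dtype=np.float32)
v_default = to_var(torch.from_numpy(p_), requires_grad=True)            # CUDA if available, else left on CPU
v_cpu = to_var(torch.from_numpy(p_), requires_grad=True, device='cpu')  # always CPU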
@@ -182,7 +184,7 @@ def perturb_past(
     for i in range(num_iterations):
         print("Iteration ", i + 1)
         curr_perturbation = [
-            to_var(torch.from_numpy(p_), requires_grad=True)
+            to_var(torch.from_numpy(p_), requires_grad=True, device=device)
             for p_ in grad_accumulator
         ]
@@ -290,7 +292,7 @@ def perturb_past(
     # apply the accumulated perturbations to the past
     grad_accumulator = [
-        to_var(torch.from_numpy(p_), requires_grad=True)
+        to_var(torch.from_numpy(p_), requires_grad=True, device=device)
         for p_ in grad_accumulator
     ]
     pert_past = list(map(add, past, grad_accumulator))
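
For context, `add` here is the two-argument add from Python's operator module, so the map is an element-wise sum of two equal-length lists of tensors. A standalone illustration with made-up shapes:

import torch
from operator import add

past = [torch.ones(2, 2), 2 * torch.ones(2, 2)]
grad_accumulator = [0.5 * torch.ones(2, 2), 0.5 * torch.ones(2, 2)]

# map(add, xs, ys) pairs the lists up and adds tensor-wise
pert_past = list(map(add, past, grad_accumulator))  # tensors filled with 1.5 and 2.5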
@@ -300,7 +302,7 @@ def perturb_past(
 def get_classifier(
         name: Optional[str], label_class: Union[str, int],
-        device: Union[str, torch.device]
+        device: str
 ) -> Tuple[Optional[ClassificationHead], Optional[int]]:
     if name is None:
         return None, None
@@ -355,16 +357,16 @@ def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str]) -> List[
     return bow_indices
 
-def build_bows_one_hot_vectors(bow_indices):
+def build_bows_one_hot_vectors(bow_indices, device='cuda'):
     if bow_indices is None:
         return None
 
     one_hot_bows_vectors = []
     for single_bow in bow_indices:
         single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
-        single_bow = torch.tensor(single_bow).cuda()
+        single_bow = torch.tensor(single_bow).to(device)
         num_words = single_bow.shape[0]
-        one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).cuda()
+        one_hot_bow = torch.zeros(num_words, TOKENIZER.vocab_size).to(device)
         one_hot_bow.scatter_(1, single_bow, 1)
         one_hot_bows_vectors.append(one_hot_bow)
     return one_hot_bows_vectors
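
The scatter_ call above writes a 1 into row i at the column given by single_bow[i], producing one one-hot row per bag-of-words entry. A standalone sketch with a made-up vocabulary size in place of TOKENIZER.vocab_size:

import torch

device = 'cpu'
vocab_size = 10  # stand-in for TOKENIZER.vocab_size

single_bow = torch.tensor([[3], [7], [1]]).to(device)  # three single-token entries
num_words = single_bow.shape[0]

one_hot_bow = torch.zeros(num_words, vocab_size).to(device)
one_hot_bow.scatter_(1, single_bow, 1)  # sets one_hot_bow[i, single_bow[i]] = 1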
@@ -425,6 +427,7 @@ def full_text_generation(
         length=length,
         perturb=False
     )
-    torch.cuda.empty_cache()
+    if device == 'cuda':
+        torch.cuda.empty_cache()
 
     pert_gen_tok_texts = []
@@ -460,6 +463,7 @@ def full_text_generation(
         discrim_losses.append(discrim_loss.data.cpu().numpy())
         losses_in_time.append(loss_in_time)
 
-    torch.cuda.empty_cache()
+    if device == 'cuda':
+        torch.cuda.empty_cache()
 
     return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
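
Both empty_cache() calls get the same guard: on a CPU run there is no CUDA allocator to flush, and calling into torch.cuda on a machine without a working GPU setup may fail outright. The pattern in isolation:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# only touch the CUDA caching allocator when the run actually used the GPU
if device == 'cuda':
    torch.cuda.empty_cache()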
@@ -496,7 +500,7 @@ def generate_text_pplm(
     )
 
     # collect one hot vectors for bags of words
-    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices)
+    one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, device)
 
     grad_norms = None
     last = None
@@ -563,7 +567,7 @@ def generate_text_pplm(
         if classifier is not None:
             ce_loss = torch.nn.CrossEntropyLoss()
             prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
-            label = torch.tensor([label_class], device='cuda',
+            label = torch.tensor([label_class], device=device,
                                  dtype=torch.long)
             unpert_discrim_loss = ce_loss(prediction, label)
             print(
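
The last hunk drops the hardcoded device='cuda': torch.tensor accepts a device argument and places the result there directly, so no follow-up .cuda() or .to() call is needed. A minimal illustration with a hypothetical class index:

import torch

device = 'cpu'   # or 'cuda'
label_class = 1  # hypothetical class index

# created directly on `device`, with the integer dtype CrossEntropyLoss expects
label = torch.tensor([label_class], device=device, dtype=torch.long)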