Commit 75904dae authored by w4nderlust, committed by Julien Chaumond
Browse files

Removed global variable device

parent 7fd54b55
...@@ -25,7 +25,6 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel ...@@ -25,7 +25,6 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
torch.manual_seed(0) torch.manual_seed(0)
np.random.seed(0) np.random.seed(0)
EPSILON = 1e-10 EPSILON = 1e-10
device = "cpu"
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had." example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
max_length_seq = 100 max_length_seq = 100
...@@ -55,7 +54,8 @@ class Discriminator(torch.nn.Module): ...@@ -55,7 +54,8 @@ class Discriminator(torch.nn.Module):
self, self,
class_size, class_size,
pretrained_model="gpt2-medium", pretrained_model="gpt2-medium",
cached_mode=False cached_mode=False,
device='cpu'
): ):
super(Discriminator, self).__init__() super(Discriminator, self).__init__()
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model) self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
...@@ -66,6 +66,7 @@ class Discriminator(torch.nn.Module): ...@@ -66,6 +66,7 @@ class Discriminator(torch.nn.Module):
embed_size=self.embed_size embed_size=self.embed_size
) )
self.cached_mode = cached_mode self.cached_mode = cached_mode
self.device = device
def get_classifier(self): def get_classifier(self):
return self.classifier_head return self.classifier_head
...@@ -78,7 +79,7 @@ class Discriminator(torch.nn.Module): ...@@ -78,7 +79,7 @@ class Discriminator(torch.nn.Module):
def avg_representation(self, x): def avg_representation(self, x):
mask = x.ne(0).unsqueeze(2).repeat( mask = x.ne(0).unsqueeze(2).repeat(
1, 1, self.embed_size 1, 1, self.embed_size
).float().to(device).detach() ).float().to(self.device).detach()
hidden, _ = self.encoder.transformer(x) hidden, _ = self.encoder.transformer(x)
masked_hidden = hidden * mask masked_hidden = hidden * mask
avg_hidden = torch.sum(masked_hidden, dim=1) / ( avg_hidden = torch.sum(masked_hidden, dim=1) / (
...@@ -88,9 +89,9 @@ class Discriminator(torch.nn.Module): ...@@ -88,9 +89,9 @@ class Discriminator(torch.nn.Module):
def forward(self, x): def forward(self, x):
if self.cached_mode: if self.cached_mode:
avg_hidden = x.to(device) avg_hidden = x.to(self.device)
else: else:
avg_hidden = self.avg_representation(x.to(device)) avg_hidden = self.avg_representation(x.to(self.device))
logits = self.classifier_head(avg_hidden) logits = self.classifier_head(avg_hidden)
probs = F.log_softmax(logits, dim=-1) probs = F.log_softmax(logits, dim=-1)
...@@ -152,7 +153,7 @@ def cached_collate_fn(data): ...@@ -152,7 +153,7 @@ def cached_collate_fn(data):
def train_epoch(data_loader, discriminator, optimizer, def train_epoch(data_loader, discriminator, optimizer,
epoch=0, log_interval=10): epoch=0, log_interval=10, device='cpu'):
samples_so_far = 0 samples_so_far = 0
discriminator.train_custom() discriminator.train_custom()
for batch_idx, (input_t, target_t) in enumerate(data_loader): for batch_idx, (input_t, target_t) in enumerate(data_loader):
...@@ -177,7 +178,7 @@ def train_epoch(data_loader, discriminator, optimizer, ...@@ -177,7 +178,7 @@ def train_epoch(data_loader, discriminator, optimizer,
) )
def evaluate_performance(data_loader, discriminator): def evaluate_performance(data_loader, discriminator, device='cpu'):
discriminator.eval() discriminator.eval()
test_loss = 0 test_loss = 0
correct = 0 correct = 0
...@@ -202,7 +203,7 @@ def evaluate_performance(data_loader, discriminator): ...@@ -202,7 +203,7 @@ def evaluate_performance(data_loader, discriminator):
) )
def predict(input_sentence, model, classes, cached=False): def predict(input_sentence, model, classes, cached=False, device='cpu'):
input_t = model.tokenizer.encode(input_sentence) input_t = model.tokenizer.encode(input_sentence)
input_t = torch.tensor([input_t], dtype=torch.long, device=device) input_t = torch.tensor([input_t], dtype=torch.long, device=device)
if cached: if cached:
...@@ -216,7 +217,8 @@ def predict(input_sentence, model, classes, cached=False): ...@@ -216,7 +217,8 @@ def predict(input_sentence, model, classes, cached=False):
)) ))
def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False): def get_cached_data_loader(dataset, batch_size, discriminator,
shuffle=False, device='cpu'):
data_loader = torch.utils.data.DataLoader(dataset=dataset, data_loader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=batch_size, batch_size=batch_size,
collate_fn=collate_fn) collate_fn=collate_fn)
...@@ -244,7 +246,6 @@ def train_discriminator( ...@@ -244,7 +246,6 @@ def train_discriminator(
dataset, dataset_fp=None, pretrained_model="gpt2-medium", dataset, dataset_fp=None, pretrained_model="gpt2-medium",
epochs=10, batch_size=64, log_interval=10, epochs=10, batch_size=64, log_interval=10,
save_model=False, cached=False, no_cuda=False): save_model=False, cached=False, no_cuda=False):
global device
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu" device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
print("Preprocessing {} dataset...".format(dataset)) print("Preprocessing {} dataset...".format(dataset))
...@@ -258,7 +259,8 @@ def train_discriminator( ...@@ -258,7 +259,8 @@ def train_discriminator(
discriminator = Discriminator( discriminator = Discriminator(
class_size=len(idx2class), class_size=len(idx2class),
pretrained_model=pretrained_model, pretrained_model=pretrained_model,
cached_mode=cached cached_mode=cached,
device=device
).to(device) ).to(device)
text = torchtext_data.Field() text = torchtext_data.Field()
...@@ -309,7 +311,8 @@ def train_discriminator( ...@@ -309,7 +311,8 @@ def train_discriminator(
discriminator = Discriminator( discriminator = Discriminator(
class_size=len(idx2class), class_size=len(idx2class),
pretrained_model=pretrained_model, pretrained_model=pretrained_model,
cached_mode=cached cached_mode=cached,
device=device
).to(device) ).to(device)
with open("datasets/clickbait/clickbait_train_prefix.txt") as f: with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
...@@ -368,7 +371,8 @@ def train_discriminator( ...@@ -368,7 +371,8 @@ def train_discriminator(
discriminator = Discriminator( discriminator = Discriminator(
class_size=len(idx2class), class_size=len(idx2class),
pretrained_model=pretrained_model, pretrained_model=pretrained_model,
cached_mode=cached cached_mode=cached,
device=device
).to(device) ).to(device)
x = [] x = []
...@@ -431,7 +435,8 @@ def train_discriminator( ...@@ -431,7 +435,8 @@ def train_discriminator(
discriminator = Discriminator( discriminator = Discriminator(
class_size=len(idx2class), class_size=len(idx2class),
pretrained_model=pretrained_model, pretrained_model=pretrained_model,
cached_mode=cached cached_mode=cached,
device=device
).to(device) ).to(device)
x = [] x = []
...@@ -494,11 +499,12 @@ def train_discriminator( ...@@ -494,11 +499,12 @@ def train_discriminator(
start = time.time() start = time.time()
train_loader = get_cached_data_loader( train_loader = get_cached_data_loader(
train_dataset, batch_size, discriminator, shuffle=True train_dataset, batch_size, discriminator,
shuffle=True, device=device
) )
test_loader = get_cached_data_loader( test_loader = get_cached_data_loader(
test_dataset, batch_size, discriminator test_dataset, batch_size, discriminator, device=device
) )
end = time.time() end = time.time()
...@@ -529,18 +535,21 @@ def train_discriminator( ...@@ -529,18 +535,21 @@ def train_discriminator(
data_loader=train_loader, data_loader=train_loader,
optimizer=optimizer, optimizer=optimizer,
epoch=epoch, epoch=epoch,
log_interval=log_interval log_interval=log_interval,
device=device
) )
evaluate_performance( evaluate_performance(
data_loader=test_loader, data_loader=test_loader,
discriminator=discriminator discriminator=discriminator,
device=device
) )
end = time.time() end = time.time()
print("Epoch took: {:.3f}s".format(end - start)) print("Epoch took: {:.3f}s".format(end - start))
print("\nExample prediction") print("\nExample prediction")
predict(example_sentence, discriminator, idx2class, cached) predict(example_sentence, discriminator, idx2class,
cached=cached, device=device)
if save_model: if save_model:
# torch.save(discriminator.state_dict(), # torch.save(discriminator.state_dict(),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment