Commit 75904dae authored by w4nderlust, committed by Julien Chaumond

Removed global variable device

parent 7fd54b55
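
In short: instead of reading a module-level `device` global, `Discriminator` now receives the device as a constructor argument and stores it on the instance, and each helper function takes an explicit `device` parameter. A minimal runnable sketch of the same pattern, with a placeholder `Linear` classifier head standing in for the script's GPT-2 based model:

import torch
import torch.nn.functional as F


class Discriminator(torch.nn.Module):
    # The device arrives as a constructor argument and is stored on the
    # instance, instead of being looked up in a module-level global.
    def __init__(self, class_size, device="cpu"):
        super(Discriminator, self).__init__()
        self.device = device
        self.classifier_head = torch.nn.Linear(16, class_size)  # placeholder

    def forward(self, x):
        # Tensors are moved with self.device rather than the old global.
        logits = self.classifier_head(x.to(self.device))
        return F.log_softmax(logits, dim=-1)


# The caller resolves the device once and passes it down explicitly,
# mirroring the change in train_discriminator():
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Discriminator(class_size=2, device=device).to(device)
log_probs = model(torch.randn(4, 16))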
@@ -25,7 +25,6 @@ from transformers import GPT2Tokenizer, GPT2LMHeadModel
 torch.manual_seed(0)
 np.random.seed(0)
 
 EPSILON = 1e-10
-device = "cpu"
 example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
 max_length_seq = 100
@@ -55,7 +54,8 @@ class Discriminator(torch.nn.Module):
             self,
             class_size,
             pretrained_model="gpt2-medium",
-            cached_mode=False
+            cached_mode=False,
+            device='cpu'
     ):
         super(Discriminator, self).__init__()
         self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
@@ -66,6 +66,7 @@ class Discriminator(torch.nn.Module):
             embed_size=self.embed_size
         )
         self.cached_mode = cached_mode
+        self.device = device
 
     def get_classifier(self):
         return self.classifier_head
@@ -78,7 +79,7 @@ class Discriminator(torch.nn.Module):
     def avg_representation(self, x):
         mask = x.ne(0).unsqueeze(2).repeat(
             1, 1, self.embed_size
-        ).float().to(device).detach()
+        ).float().to(self.device).detach()
         hidden, _ = self.encoder.transformer(x)
         masked_hidden = hidden * mask
         avg_hidden = torch.sum(masked_hidden, dim=1) / (
@@ -88,9 +89,9 @@ class Discriminator(torch.nn.Module):
 
     def forward(self, x):
         if self.cached_mode:
-            avg_hidden = x.to(device)
+            avg_hidden = x.to(self.device)
         else:
-            avg_hidden = self.avg_representation(x.to(device))
+            avg_hidden = self.avg_representation(x.to(self.device))
 
         logits = self.classifier_head(avg_hidden)
         probs = F.log_softmax(logits, dim=-1)
@@ -152,7 +153,7 @@ def cached_collate_fn(data):
 
 
 def train_epoch(data_loader, discriminator, optimizer,
-                epoch=0, log_interval=10):
+                epoch=0, log_interval=10, device='cpu'):
     samples_so_far = 0
     discriminator.train_custom()
     for batch_idx, (input_t, target_t) in enumerate(data_loader):
@@ -177,7 +178,7 @@ def train_epoch(data_loader, discriminator, optimizer,
     )
 
 
-def evaluate_performance(data_loader, discriminator):
+def evaluate_performance(data_loader, discriminator, device='cpu'):
     discriminator.eval()
     test_loss = 0
     correct = 0
@@ -202,7 +203,7 @@ def evaluate_performance(data_loader, discriminator):
     )
 
 
-def predict(input_sentence, model, classes, cached=False):
+def predict(input_sentence, model, classes, cached=False, device='cpu'):
     input_t = model.tokenizer.encode(input_sentence)
     input_t = torch.tensor([input_t], dtype=torch.long, device=device)
     if cached:
@@ -216,7 +217,8 @@ def predict(input_sentence, model, classes, cached=False):
     ))
 
 
-def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
+def get_cached_data_loader(dataset, batch_size, discriminator,
+                           shuffle=False, device='cpu'):
     data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=batch_size,
                                               collate_fn=collate_fn)
@@ -244,7 +246,6 @@ def train_discriminator(
         dataset, dataset_fp=None, pretrained_model="gpt2-medium",
         epochs=10, batch_size=64, log_interval=10,
         save_model=False, cached=False, no_cuda=False):
-    global device
     device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
 
     print("Preprocessing {} dataset...".format(dataset))
@@ -258,7 +259,8 @@ def train_discriminator(
         discriminator = Discriminator(
             class_size=len(idx2class),
             pretrained_model=pretrained_model,
-            cached_mode=cached
+            cached_mode=cached,
+            device=device
         ).to(device)
 
         text = torchtext_data.Field()
@@ -309,7 +311,8 @@ def train_discriminator(
         discriminator = Discriminator(
             class_size=len(idx2class),
             pretrained_model=pretrained_model,
-            cached_mode=cached
+            cached_mode=cached,
+            device=device
         ).to(device)
 
         with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
@@ -368,7 +371,8 @@ def train_discriminator(
         discriminator = Discriminator(
             class_size=len(idx2class),
             pretrained_model=pretrained_model,
-            cached_mode=cached
+            cached_mode=cached,
+            device=device
         ).to(device)
 
         x = []
@@ -431,7 +435,8 @@ def train_discriminator(
         discriminator = Discriminator(
             class_size=len(idx2class),
             pretrained_model=pretrained_model,
-            cached_mode=cached
+            cached_mode=cached,
+            device=device
         ).to(device)
 
         x = []
@@ -494,11 +499,12 @@ def train_discriminator(
 
         start = time.time()
         train_loader = get_cached_data_loader(
-            train_dataset, batch_size, discriminator, shuffle=True
+            train_dataset, batch_size, discriminator,
+            shuffle=True, device=device
         )
 
         test_loader = get_cached_data_loader(
-            test_dataset, batch_size, discriminator
+            test_dataset, batch_size, discriminator, device=device
         )
 
         end = time.time()
@@ -529,18 +535,21 @@ def train_discriminator(
             data_loader=train_loader,
             optimizer=optimizer,
             epoch=epoch,
-            log_interval=log_interval
+            log_interval=log_interval,
+            device=device
         )
         evaluate_performance(
             data_loader=test_loader,
-            discriminator=discriminator
+            discriminator=discriminator,
+            device=device
         )
 
         end = time.time()
         print("Epoch took: {:.3f}s".format(end - start))
 
         print("\nExample prediction")
-        predict(example_sentence, discriminator, idx2class, cached)
+        predict(example_sentence, discriminator, idx2class,
+                cached=cached, device=device)
 
     if save_model:
         # torch.save(discriminator.state_dict(),
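After this change a caller no longer sets any global: `train_discriminator` resolves the device once from `no_cuda` and `torch.cuda.is_available()`, then threads it through `Discriminator`, `train_epoch`, `evaluate_performance`, and `predict`. A hypothetical invocation, using the clickbait dataset option that appears in the diff (argument values are illustrative):

# Hypothetical call; no device global needs to be set beforehand.
train_discriminator(
    dataset="clickbait",
    pretrained_model="gpt2-medium",
    epochs=10,
    batch_size=64,
    cached=True,
    no_cuda=False,  # set True to force CPU even when CUDA is available
)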