Commit 821de121 authored by piero's avatar piero Committed by Julien Chaumond
Browse files

Minor changes

parent 7469d03b
......@@ -72,7 +72,6 @@ class Discriminator(torch.nn.Module):
def train_custom(self):
for param in self.encoder.parameters():
param.requires_grad = False
pass
self.classifier_head.train()
def avg_representation(self, x):
......@@ -122,7 +121,7 @@ def collate_fn(data):
padded_sequences = torch.zeros(
len(sequences),
max(lengths)
).long() # padding index 0
).long() # padding value = 0
for i, seq in enumerate(sequences):
end = lengths[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment