Commit 7edb51f3 authored by Julien Chaumond

[pplm] split classif head into its own file

parent 8101924a
examples/pplm_classification_head.py (new file)

import torch


class ClassificationHead(torch.nn.Module):
    """Classification Head for transformer encoders"""

    def __init__(self, class_size, embed_size):
        super(ClassificationHead, self).__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        # hidden_state = F.relu(self.mlp1(hidden_state))
        # hidden_state = self.mlp2(hidden_state)
        logits = self.mlp(hidden_state)
        return logits
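For context, a minimal usage sketch of the extracted head (not part of the commit; the embedding size and class count below are assumptions for illustration):

import torch

from pplm_classification_head import ClassificationHead

# Hypothetical setup: GPT-2 medium has hidden size 1024; assume 5 target classes.
head = ClassificationHead(class_size=5, embed_size=1024)

# Fake batch of pooled transformer hidden states: (batch_size, embed_size).
hidden_state = torch.randn(2, 1024)
logits = head(hidden_state)  # -> shape (2, 5)
print(logits.shape)          # torch.Size([2, 5])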
examples/run_pplm.py
@@ -33,10 +33,10 @@ import torch.nn.functional as F
 from torch.autograd import Variable
 from tqdm import trange

-from examples.run_pplm_discrim_train import ClassificationHead
 from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
 from transformers.modeling_gpt2 import GPT2LMHeadModel
+from pplm_classification_head import ClassificationHead

 PPLM_BOW = 1
 PPLM_DISCRIM = 2
...
examples/run_pplm_discrim_train.py
@@ -21,6 +21,7 @@ from torchtext import datasets
 from tqdm import tqdm, trange

 from transformers import GPT2Tokenizer, GPT2LMHeadModel
+from pplm_classification_head import ClassificationHead

 torch.manual_seed(0)
 np.random.seed(0)
@@ -29,22 +30,6 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
 max_length_seq = 100

-class ClassificationHead(torch.nn.Module):
-    """Classification Head for transformer encoders"""
-
-    def __init__(self, class_size, embed_size):
-        super(ClassificationHead, self).__init__()
-        self.class_size = class_size
-        self.embed_size = embed_size
-        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
-        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
-        self.mlp = torch.nn.Linear(embed_size, class_size)
-
-    def forward(self, hidden_state):
-        # hidden_state = F.relu(self.mlp1(hidden_state))
-        # hidden_state = self.mlp2(hidden_state)
-        logits = self.mlp(hidden_state)
-        return logits
-
 class Discriminator(torch.nn.Module):
...
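The body of Discriminator is elided above. As a rough sketch of how a module like it could consume the relocated head (the pooling strategy and constructor arguments here are assumptions, not the commit's own code):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from pplm_classification_head import ClassificationHead


class Discriminator(torch.nn.Module):
    """Sketch: a GPT-2 encoder followed by the shared ClassificationHead."""

    def __init__(self, class_size, pretrained_model="gpt2-medium"):
        super().__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
        self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
        embed_size = self.encoder.transformer.config.hidden_size
        self.classifier_head = ClassificationHead(
            class_size=class_size, embed_size=embed_size
        )

    def forward(self, input_ids):
        # Mean-pool the final hidden states over the sequence dimension,
        # then classify the pooled representation.
        hidden = self.encoder.transformer(input_ids)[0]  # (batch, seq, hidden)
        pooled = hidden.mean(dim=1)
        return self.classifier_head(pooled)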