chenpangpang / transformers / Commits / 7edb51f3

Commit 7edb51f3, authored Dec 03, 2019 by Julien Chaumond

    [pplm] split classif head into its own file

parent 8101924a

Showing 3 changed files with 20 additions and 17 deletions:

    examples/pplm/pplm_classification_head.py    +18  -0
    examples/pplm/run_pplm.py                     +1  -1
    examples/pplm/run_pplm_discrim_train.py       +1  -16
examples/pplm/pplm_classification_head.py (new file, 0 → 100644)

import torch


class ClassificationHead(torch.nn.Module):
    """Classification Head for transformer encoders"""

    def __init__(self, class_size, embed_size):
        super(ClassificationHead, self).__init__()
        self.class_size = class_size
        self.embed_size = embed_size
        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        # hidden_state = F.relu(self.mlp1(hidden_state))
        # hidden_state = self.mlp2(hidden_state)
        logits = self.mlp(hidden_state)
        return logits
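
For context on how this new module is meant to be used, here is a minimal usage sketch. It is not part of the commit: the class_size/embed_size values and the pooled_hidden tensor are illustrative assumptions. The head is a single linear layer that projects an embed_size-dimensional pooled representation to class_size logits.

# Hedged usage sketch (not in the commit): project a pooled transformer
# representation to per-class logits with the new ClassificationHead.
import torch
from pplm_classification_head import ClassificationHead

head = ClassificationHead(class_size=5, embed_size=1024)      # 1024 = GPT-2 medium hidden size (assumption)
pooled_hidden = torch.randn(8, 1024)                          # e.g. mean-pooled hidden states for a batch of 8
logits = head(pooled_hidden)                                  # shape: (8, 5)
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)   # log-probabilities over the 5 classes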
examples/pplm/run_pplm.py

@@ -33,10 +33,10 @@ import torch.nn.functional as F
 from torch.autograd import Variable
 from tqdm import trange
-from examples.run_pplm_discrim_train import ClassificationHead
 from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
 from transformers.modeling_gpt2 import GPT2LMHeadModel
+from pplm_classification_head import ClassificationHead


 PPLM_BOW = 1
 PPLM_DISCRIM = 2
...
examples/pplm/run_pplm_discrim_train.py

@@ -21,6 +21,7 @@ from torchtext import datasets
 from tqdm import tqdm, trange
 from transformers import GPT2Tokenizer, GPT2LMHeadModel

+from pplm_classification_head import ClassificationHead

 torch.manual_seed(0)
 np.random.seed(0)
...
@@ -29,22 +30,6 @@ example_sentence = "This is incredible! I love it, this is the best chicken I ha
 max_length_seq = 100


-class ClassificationHead(torch.nn.Module):
-    """Classification Head for transformer encoders"""
-
-    def __init__(self, class_size, embed_size):
-        super(ClassificationHead, self).__init__()
-        self.class_size = class_size
-        self.embed_size = embed_size
-        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
-        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
-        self.mlp = torch.nn.Linear(embed_size, class_size)
-
-    def forward(self, hidden_state):
-        # hidden_state = F.relu(self.mlp1(hidden_state))
-        # hidden_state = self.mlp2(hidden_state)
-        logits = self.mlp(hidden_state)
-        return logits
-
 class Discriminator(torch.nn.Module):
...
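
The second hunk only removes the now-duplicated head; the body of Discriminator itself is truncated in this diff. As a hedged sketch only (the constructor signature, attribute names, and mean-pooling below are assumptions, not code from the commit), a discriminator of this shape wraps a GPT-2 encoder and puts the imported ClassificationHead on top of averaged hidden states:

# Hedged sketch (not the commit's code): GPT-2 encoder followed by the
# standalone ClassificationHead, applied to mean-pooled hidden states.
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from pplm_classification_head import ClassificationHead


class DiscriminatorSketch(torch.nn.Module):
    def __init__(self, class_size, pretrained_model="gpt2-medium"):
        super(DiscriminatorSketch, self).__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
        self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
        embed_size = self.encoder.transformer.config.hidden_size
        self.classifier_head = ClassificationHead(class_size=class_size,
                                                  embed_size=embed_size)

    def forward(self, input_ids):
        # Run the transformer, average hidden states over the sequence
        # dimension, then project to class logits with the head.
        hidden = self.encoder.transformer(input_ids)[0]
        avg_hidden = hidden.mean(dim=1)
        return self.classifier_head(avg_hidden)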