Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
821de121
Commit
821de121
authored
Nov 27, 2019
by
piero
Committed by
Julien Chaumond
Dec 03, 2019
Browse files
Minor changes
parent
7469d03b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
1 addition
and
2 deletions
+1
-2
examples/run_pplm_discrim_train.py
examples/run_pplm_discrim_train.py
+1
-2
No files found.
examples/run_pplm_discrim_train.py
View file @
821de121
...
@@ -72,7 +72,6 @@ class Discriminator(torch.nn.Module):
...
@@ -72,7 +72,6 @@ class Discriminator(torch.nn.Module):
def
train_custom
(
self
):
def
train_custom
(
self
):
for
param
in
self
.
encoder
.
parameters
():
for
param
in
self
.
encoder
.
parameters
():
param
.
requires_grad
=
False
param
.
requires_grad
=
False
pass
self
.
classifier_head
.
train
()
self
.
classifier_head
.
train
()
def
avg_representation
(
self
,
x
):
def
avg_representation
(
self
,
x
):
...
@@ -122,7 +121,7 @@ def collate_fn(data):
...
@@ -122,7 +121,7 @@ def collate_fn(data):
padded_sequences
=
torch
.
zeros
(
padded_sequences
=
torch
.
zeros
(
len
(
sequences
),
len
(
sequences
),
max
(
lengths
)
max
(
lengths
)
).
long
()
# padding
index
0
).
long
()
# padding
value =
0
for
i
,
seq
in
enumerate
(
sequences
):
for
i
,
seq
in
enumerate
(
sequences
):
end
=
lengths
[
i
]
end
=
lengths
[
i
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment