Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
01b1a053
Commit
01b1a053
authored
Nov 03, 2018
by
thomwolf
Browse files
Merge branch 'master' of
https://github.com/huggingface/pytorch-pretrained-BERT
parents
8aa22af0
72ab1039
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
61 additions
and
3 deletions
+61
-3
modeling_pytorch.py
modeling_pytorch.py
+16
-3
optimization_test_pytorch.py
optimization_test_pytorch.py
+45
-0
No files found.
modeling_pytorch.py
View file @
01b1a053
...
@@ -494,9 +494,22 @@ class BertForQuestionAnswering(nn.Module):
...
@@ -494,9 +494,22 @@ class BertForQuestionAnswering(nn.Module):
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
start_logits
,
end_logits
=
logits
.
split
(
1
,
dim
=-
1
)
if
start_positions
is
not
None
and
end_positions
is
not
None
:
if
start_positions
is
not
None
and
end_positions
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
#loss_fct = CrossEntropyLoss()
start_loss
=
loss_fct
(
start_logits
,
start_positions
)
#start_loss = loss_fct(start_logits, start_positions)
end_loss
=
loss_fct
(
end_logits
,
end_positions
)
#end_loss = loss_fct(end_logits, end_positions)
batch_size
,
seq_length
=
input_ids
.
size
()
def
compute_loss
(
logits
,
positions
):
max_position
=
positions
.
max
().
item
()
one_hot
=
torch
.
FloatTensor
(
batch_size
,
max
(
max_position
,
seq_length
)
+
1
).
zero_
()
one_hot
=
one_hot
.
scatter
(
1
,
positions
.
cpu
(),
1
)
# Second argument need to be LongTensor and not cuda.LongTensor
one_hot
=
one_hot
[:,
:
seq_length
].
to
(
input_ids
.
device
)
log_probs
=
nn
.
functional
.
log_softmax
(
logits
,
dim
=
-
1
).
view
(
batch_size
,
seq_length
)
loss
=
-
torch
.
mean
(
torch
.
sum
(
one_hot
*
log_probs
),
dim
=
-
1
)
return
loss
start_loss
=
compute_loss
(
start_logits
,
start_positions
)
end_loss
=
compute_loss
(
end_logits
,
end_positions
)
total_loss
=
(
start_loss
+
end_loss
)
/
2
total_loss
=
(
start_loss
+
end_loss
)
/
2
return
total_loss
,
(
start_logits
,
end_logits
)
return
total_loss
,
(
start_logits
,
end_logits
)
else
:
else
:
...
...
optimization_test_pytorch.py
0 → 100644
View file @
01b1a053
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
optimization_pytorch
as
optimization
import
torch
import
unittest
class
OptimizationTest
(
unittest
.
TestCase
):
def
assertListAlmostEqual
(
self
,
list1
,
list2
,
tol
):
self
.
assertEqual
(
len
(
list1
),
len
(
list2
))
for
a
,
b
in
zip
(
list1
,
list2
):
self
.
assertAlmostEqual
(
a
,
b
,
delta
=
tol
)
def
test_adam
(
self
):
w
=
torch
.
tensor
([
0.1
,
-
0.2
,
-
0.1
],
requires_grad
=
True
)
x
=
torch
.
tensor
([
0.4
,
0.2
,
-
0.5
])
criterion
=
torch
.
nn
.
MSELoss
(
reduction
=
'elementwise_mean'
)
optimizer
=
optimization
.
BERTAdam
(
params
=
{
w
},
lr
=
0.2
,
schedule
=
'warmup_linear'
,
warmup
=
0.1
,
t_total
=
100
)
for
_
in
range
(
100
):
# TODO Solve: reduction='elementwise_mean'=True not taken into account so division by x.size(0) is necessary
loss
=
criterion
(
x
,
w
)
/
x
.
size
(
0
)
loss
.
backward
()
optimizer
.
step
()
self
.
assertListAlmostEqual
(
w
.
tolist
(),
[
0.4
,
0.2
,
-
0.5
],
tol
=
1e-2
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment