Commit 3fd71c44 authored Dec 12, 2019 by LysandreJik
Update example scripts
parent b72f9d34
Showing 2 changed files with 4 additions and 4 deletions
examples/distillation/distiller.py  +3 -3
examples/utils_ner.py  +1 -1
examples/distillation/distiller.py
@@ -112,7 +112,7 @@ class Distiller:
         self.last_log = 0
 
         self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
-        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
+        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
         if self.alpha_mse > 0.:
             self.mse_loss_fct = nn.MSELoss(reduction='sum')
         if self.alpha_cos > 0.:
@@ -224,7 +224,7 @@ class Distiller:
         _token_ids = _token_ids_mask * (probs == 0).long() + _token_ids_real * (probs == 1).long() + _token_ids_rand * (probs == 2).long()
         token_ids = token_ids.masked_scatter(pred_mask, _token_ids)
 
-        mlm_labels[~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
+        mlm_labels[~pred_mask] = -100  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility
 
         # sanity checks
         assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
@@ -254,7 +254,7 @@ class Distiller:
 
         attn_mask = (torch.arange(token_ids.size(1), dtype=torch.long, device=lengths.device) < lengths[:, None])
         clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
-        clm_labels[~attn_mask] = -1  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
+        clm_labels[~attn_mask] = -100  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility
 
         # sanity checks
         assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size
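Note on the distiller.py changes above: -100 is the default ignore_index of PyTorch's nn.CrossEntropyLoss, so labels set to -100 are excluded from both the loss and its averaging denominator. The snippet below is only an illustration of that behaviour, not code from this commit; the tensor shapes and values are made up.

import torch
import torch.nn as nn

# -100 is nn.CrossEntropyLoss's default ignore_index: positions labelled
# -100 contribute neither to the summed loss nor to the averaging count.
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

logits = torch.randn(4, 10)                # 4 positions over a made-up 10-token vocab
labels = torch.tensor([3, -100, 7, -100])  # two positions masked out of the loss

print(loss_fct(logits, labels))            # averaged over the 2 unmasked positions only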
examples/utils_ner.py
@@ -94,7 +94,7 @@ def convert_examples_to_features(examples,
                                  pad_on_left=False,
                                  pad_token=0,
                                  pad_token_segment_id=0,
-                                 pad_token_label_id=-1,
+                                 pad_token_label_id=-100,
                                  sequence_a_segment_id=0,
                                  mask_padding_with_zero=True):
     """ Loads a data file into a list of `InputBatch`s
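For the utils_ner.py change, the padding label written into the features has to match the ignore_index of the token-classification loss, so moving pad_token_label_id to -100 aligns it with nn.CrossEntropyLoss's default. A minimal sketch of that interplay, with a hypothetical tag sequence and sequence length not taken from the example script:

import torch
import torch.nn as nn

pad_token_label_id = -100            # must equal the loss's ignore_index
max_seq_length = 6

# hypothetical 4-token sentence, one tag id per token, padded to max_seq_length
label_ids = [2, 0, 0, 1] + [pad_token_label_id] * (max_seq_length - 4)

logits = torch.randn(max_seq_length, 3)                        # 3 made-up NER classes
loss = nn.CrossEntropyLoss()(logits, torch.tensor(label_ids))  # padding ignored by default
print(loss)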