chenpangpang / transformers · Commits

Commit c164064e (unverified)
Authored Jul 29, 2021 by chutaklee; committed via GitHub on Jul 29, 2021
Parent: 1da782cb

Fix distiller.py (#12910)

* fix distiller
* fix style
Showing 1 changed file with 6 additions and 8 deletions:

examples/research_projects/distillation/distiller.py (+6, -8)
--- a/examples/research_projects/distillation/distiller.py
+++ b/examples/research_projects/distillation/distiller.py
...
@@ -380,21 +380,19 @@ class Distiller:
         lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).
         """
         if self.mlm:
-            s_logits, s_hidden_states = self.student(
+            student_outputs = self.student(
                 input_ids=input_ids, attention_mask=attention_mask
             )  # (bs, seq_length, voc_size)
             with torch.no_grad():
-                t_logits, t_hidden_states = self.teacher(
+                teacher_outputs = self.teacher(
                     input_ids=input_ids, attention_mask=attention_mask
                 )  # (bs, seq_length, voc_size)
         else:
-            s_logits, _, s_hidden_states = self.student(
-                input_ids=input_ids, attention_mask=None
-            )  # (bs, seq_length, voc_size)
+            student_outputs = self.student(input_ids=input_ids, attention_mask=None)  # (bs, seq_length, voc_size)
             with torch.no_grad():
-                t_logits, _, t_hidden_states = self.teacher(
-                    input_ids=input_ids, attention_mask=None
-                )  # (bs, seq_length, voc_size)
+                teacher_outputs = self.teacher(input_ids=input_ids, attention_mask=None)  # (bs, seq_length, voc_size)
+        s_logits, s_hidden_states = student_outputs["logits"], student_outputs["hidden_states"]
+        t_logits, t_hidden_states = teacher_outputs["logits"], teacher_outputs["hidden_states"]
         assert s_logits.size() == t_logits.size()

         # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
...
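
For context on the fix: in the transformers versions this example now targets, models return ModelOutput objects by default (return_dict=True) rather than plain tuples, and iterating such an object yields its keys, so the old tuple unpacking of self.student(...) / self.teacher(...) stopped producing logits and hidden states. Below is a minimal sketch of the failure mode and of the key-based access the commit adopts; the toy config and input ids are illustrative assumptions, not part of the commit:

import torch
from transformers import DistilBertConfig, DistilBertForMaskedLM

# The distillation example enables hidden states on both the student and
# teacher configs; without this, outputs["hidden_states"] would be absent.
config = DistilBertConfig(output_hidden_states=True)
model = DistilBertForMaskedLM(config)  # randomly initialized, illustration only

input_ids = torch.tensor([[101, 2054, 2003, 102]])  # hypothetical toy ids
outputs = model(input_ids=input_ids, attention_mask=None)

# Old pattern (pre-fix): tuple unpacking. A ModelOutput iterates over its
# keys, so this would bind the strings "logits" and "hidden_states" (or
# raise ValueError on a key-count mismatch), not the tensors themselves:
#   s_logits, s_hidden_states = model(input_ids=input_ids, attention_mask=None)

# New pattern (the commit): index by key.
logits = outputs["logits"]                # (bs, seq_length, voc_size)
hidden_states = outputs["hidden_states"]  # tuple: embedding output + one per block
print(logits.shape, len(hidden_states))

Key-based access is also robust to optional fields such as loss or attentions, whose presence depends on the call arguments; that variability is exactly what made positional unpacking fragile here.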