Commit 4bd7be9a (unverified), authored by Sam Shleifer on Aug 26, 2020 and committed via GitHub on Aug 26, 2020.

s2s distillation uses AutoModelForSeqToSeqLM (#6761)

Parent: 05e7150a
Showing 2 changed files with 6 additions and 6 deletions.

examples/seq2seq/distillation.py (+5, -6)
examples/seq2seq/test_seq2seq_examples.py (+1, -0)
examples/seq2seq/distillation.py
```diff
@@ -10,7 +10,7 @@ from torch import nn
 from torch.nn import functional as F
 
 from lightning_base import generic_train
-from transformers import BartConfig, BartForConditionalGeneration, MBartTokenizer, T5Config, T5ForConditionalGeneration
+from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
 
 
 try:
```
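The heart of the change is already visible in the imports: the hard-coded BART classes give way to AutoModelForSeq2SeqLM, which resolves the concrete model class from a checkpoint's config. A minimal sketch of what that buys (checkpoint names are illustrative, not from this commit):

```python
from transformers import AutoModelForSeq2SeqLM

# The Auto class reads each checkpoint's config and instantiates the matching
# architecture, so one distillation code path can serve several model families.
bart = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")  # -> BartForConditionalGeneration
t5 = AutoModelForSeq2SeqLM.from_pretrained("t5-small")                   # -> T5ForConditionalGeneration
```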
```diff
@@ -74,22 +74,22 @@ class BartSummarizationDistiller(SummarizationModule):
     def pre_init(self, hparams):
         self.output_dir = Path(hparams.output_dir)
         self.output_dir.mkdir(exist_ok=True)
-        teacher = BartForConditionalGeneration.from_pretrained(hparams.teacher).eval()
+        teacher = AutoModelForSeq2SeqLM.from_pretrained(hparams.teacher).eval()
         student_updates = {
             "decoder_layers": hparams.student_decoder_layers,
             "encoder_layers": hparams.student_encoder_layers,
         }
         if hparams.length_penalty != -1:
             student_updates["length_penalty"] = hparams.length_penalty
-        d_layers_to_copy = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
+        d_layers_to_copy: List = get_layers_to_copy(student_updates["decoder_layers"], teacher.config.decoder_layers)
         e_layers_to_copy: List = get_layers_to_copy(student_updates["encoder_layers"], teacher.config.encoder_layers)
         hparams.d_layer_to_copy = d_layers_to_copy
         hparams.e_layer_to_copy = e_layers_to_copy
         kw = teacher.config.to_diff_dict()
         kw.update(student_updates)
         # Copy weights
-        student_cfg = BartConfig(**kw)
-        student = BartForConditionalGeneration(student_cfg)
+        student_cfg = teacher.config_class(**kw)
+        student = type(teacher)(student_cfg)
         student, _ = init_student(student, teacher)
         save_dir = self.output_dir.joinpath("student")
         self.copy_to_student(d_layers_to_copy, e_layers_to_copy, hparams, student, teacher)
```
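The same idea generalizes student construction in pre_init: teacher.config_class and type(teacher) replace the literal BartConfig and BartForConditionalGeneration, so the student is always built with the teacher's own classes. A self-contained sketch under assumed inputs (the checkpoint and layer counts are illustrative):

```python
from transformers import AutoModelForSeq2SeqLM

# Build a shallower student from whatever architecture the teacher happens to be.
teacher = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").eval()
kw = teacher.config.to_diff_dict()                      # teacher config as a plain dict
kw.update({"encoder_layers": 12, "decoder_layers": 3})  # shrink the student's decoder
student_cfg = teacher.config_class(**kw)                # e.g. BartConfig for a BART teacher
student = type(teacher)(student_cfg)                    # same model class, fresh random weights
```

Note that the override keys (encoder_layers, decoder_layers) are still BART-family config names; T5 configs call these num_layers and num_decoder_layers, so the generalization mostly covers BART-like teachers.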
```diff
@@ -252,7 +252,6 @@ class BartTranslationDistiller(BartSummarizationDistiller):
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        assert isinstance(self.tokenizer, MBartTokenizer)
         assert hparams.src_lang is not None
         assert hparams.tgt_lang is not None
         self.dataset_kwargs["src_lang"] = hparams.src_lang
```
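Dropping the MBartTokenizer isinstance check follows from the same generalization: once the teacher is Auto-resolved, its tokenizer is an MBartTokenizer only for mBART checkpoints. A hypothetical illustration (checkpoint name assumed):

```python
from transformers import AutoTokenizer, MBartTokenizer

# Only mBART checkpoints resolve to MBartTokenizer; a hard isinstance check
# would have rejected any other translation teacher.
tok = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
assert isinstance(tok, MBartTokenizer)
```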
examples/seq2seq/test_seq2seq_examples.py
```diff
@@ -186,6 +186,7 @@ class TestSummarizationDistiller(unittest.TestCase):
             tgt_lang="ro_RO",
         )
         model = self._test_distiller_cli(updates, check_contents=False)
+        assert model.model.config.model_type == "mbart"
 
         ckpts = list(Path(model.output_dir).glob("*.ckpt"))
         self.assertEqual(1, len(ckpts))
```
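The new test assertion pins down which architecture the Auto class actually resolved: every transformers config carries a model_type string naming its family, which is what model.model.config.model_type exposes here. A small sketch (checkpoint name illustrative):

```python
from transformers import AutoConfig

# model_type identifies the architecture family recorded in a checkpoint's config.
cfg = AutoConfig.from_pretrained("facebook/mbart-large-cc25")
assert cfg.model_type == "mbart"
```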