Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7cbf0f72
Unverified
Commit
7cbf0f72
authored
Sep 20, 2020
by
Stas Bekman
Committed by
GitHub
Sep 20, 2020
Browse files
examples/seq2seq/__init__.py mutates sys.path (#7194)
parent
a4faecea
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
58 additions
and
113 deletions
+58
-113
examples/seq2seq/__init__.py
examples/seq2seq/__init__.py
+5
-0
examples/seq2seq/distillation.py
examples/seq2seq/distillation.py
+12
-28
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+18
-39
examples/seq2seq/run_distributed_eval.py
examples/seq2seq/run_distributed_eval.py
+11
-25
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+1
-4
examples/seq2seq/run_eval_search.py
examples/seq2seq/run_eval_search.py
+1
-5
examples/seq2seq/test_bash_script.py
examples/seq2seq/test_bash_script.py
+4
-5
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+6
-7
No files found.
examples/seq2seq/__init__.py
View file @
7cbf0f72
import
os
import
sys
sys
.
path
.
insert
(
1
,
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)))
examples/seq2seq/distillation.py
View file @
7cbf0f72
...
...
@@ -10,16 +10,13 @@ import torch
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
finetune
import
SummarizationModule
,
TranslationModule
from
finetune
import
main
as
ft_main
from
initialization_utils
import
copy_layers
,
init_student
from
lightning_base
import
generic_train
from
transformers
import
AutoModelForSeq2SeqLM
,
MBartTokenizer
,
T5Config
,
T5ForConditionalGeneration
from
transformers.modeling_bart
import
shift_tokens_right
try
:
from
.finetune
import
SummarizationModule
,
TranslationModule
from
.finetune
import
main
as
ft_main
from
.initialization_utils
import
copy_layers
,
init_student
from
.utils
import
(
from
utils
import
(
any_requires_grad
,
assert_all_frozen
,
calculate_bleu
,
...
...
@@ -27,20 +24,7 @@ try:
label_smoothed_nll_loss
,
pickle_load
,
use_task_specific_params
,
)
except
ImportError
:
from
finetune
import
SummarizationModule
,
TranslationModule
from
finetune
import
main
as
ft_main
from
initialization_utils
import
copy_layers
,
init_student
from
utils
import
(
any_requires_grad
,
assert_all_frozen
,
calculate_bleu
,
freeze_params
,
label_smoothed_nll_loss
,
pickle_load
,
use_task_specific_params
,
)
)
class
BartSummarizationDistiller
(
SummarizationModule
):
...
...
examples/seq2seq/finetune.py
View file @
7cbf0f72
...
...
@@ -12,33 +12,11 @@ import pytorch_lightning as pl
import
torch
from
torch.utils.data
import
DataLoader
from
callbacks
import
Seq2SeqLoggingCallback
,
get_checkpoint_callback
,
get_early_stopping_callback
from
lightning_base
import
BaseTransformer
,
add_generic_args
,
generic_train
from
transformers
import
MBartTokenizer
,
T5ForConditionalGeneration
from
transformers.modeling_bart
import
shift_tokens_right
try
:
from
.callbacks
import
Seq2SeqLoggingCallback
,
get_checkpoint_callback
,
get_early_stopping_callback
from
.utils
import
(
ROUGE_KEYS
,
LegacySeq2SeqDataset
,
Seq2SeqDataset
,
assert_all_frozen
,
calculate_bleu
,
calculate_rouge
,
flatten_list
,
freeze_params
,
get_git_info
,
label_smoothed_nll_loss
,
lmap
,
pickle_save
,
save_git_info
,
save_json
,
use_task_specific_params
,
)
except
ImportError
:
from
callbacks
import
Seq2SeqLoggingCallback
,
get_checkpoint_callback
,
get_early_stopping_callback
from
utils
import
(
from
utils
import
(
ROUGE_KEYS
,
LegacySeq2SeqDataset
,
Seq2SeqDataset
,
...
...
@@ -54,7 +32,8 @@ except ImportError:
save_git_info
,
save_json
,
use_task_specific_params
,
)
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
examples/seq2seq/run_distributed_eval.py
View file @
7cbf0f72
...
...
@@ -11,24 +11,7 @@ from torch.utils.data import DataLoader
from
tqdm
import
tqdm
from
transformers
import
AutoModelForSeq2SeqLM
,
AutoTokenizer
logger
=
getLogger
(
__name__
)
try
:
from
.utils
import
(
Seq2SeqDataset
,
calculate_bleu
,
calculate_rouge
,
lmap
,
load_json
,
parse_numeric_n_bool_cl_kwargs
,
save_json
,
use_task_specific_params
,
write_txt_file
,
)
except
ImportError
:
from
utils
import
(
from
utils
import
(
Seq2SeqDataset
,
calculate_bleu
,
calculate_rouge
,
...
...
@@ -38,7 +21,10 @@ except ImportError:
save_json
,
use_task_specific_params
,
write_txt_file
,
)
)
logger
=
getLogger
(
__name__
)
def
eval_data_dir
(
...
...
examples/seq2seq/run_eval.py
View file @
7cbf0f72
...
...
@@ -11,14 +11,11 @@ import torch
from
tqdm
import
tqdm
from
transformers
import
AutoModelForSeq2SeqLM
,
AutoTokenizer
from
utils
import
calculate_bleu
,
calculate_rouge
,
parse_numeric_n_bool_cl_kwargs
,
use_task_specific_params
logger
=
getLogger
(
__name__
)
try
:
from
.utils
import
calculate_bleu
,
calculate_rouge
,
parse_numeric_n_bool_cl_kwargs
,
use_task_specific_params
except
ImportError
:
from
utils
import
calculate_bleu
,
calculate_rouge
,
parse_numeric_n_bool_cl_kwargs
,
use_task_specific_params
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
...
examples/seq2seq/run_eval_search.py
View file @
7cbf0f72
...
...
@@ -4,11 +4,7 @@ import operator
import
sys
from
collections
import
OrderedDict
try
:
from
.run_eval
import
datetime_now
,
run_generate
except
ImportError
:
from
run_eval
import
datetime_now
,
run_generate
from
run_eval
import
datetime_now
,
run_generate
# A table of supported tasks and the list of scores in the order of importance to be sorted by.
...
...
examples/seq2seq/test_bash_script.py
View file @
7cbf0f72
...
...
@@ -10,13 +10,12 @@ import pytorch_lightning as pl
import
timeout_decorator
import
torch
from
distillation
import
BartSummarizationDistiller
,
distill_main
from
finetune
import
SummarizationModule
,
main
from
test_seq2seq_examples
import
CUDA_AVAILABLE
,
MBART_TINY
from
transformers
import
BartForConditionalGeneration
,
MarianMTModel
from
transformers.testing_utils
import
slow
from
.distillation
import
BartSummarizationDistiller
,
distill_main
from
.finetune
import
SummarizationModule
,
main
from
.test_seq2seq_examples
import
CUDA_AVAILABLE
,
MBART_TINY
from
.utils
import
load_json
from
utils
import
load_json
MODEL_NAME
=
MBART_TINY
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
7cbf0f72
...
...
@@ -12,16 +12,15 @@ import pytorch_lightning as pl
import
torch
import
lightning_base
from
convert_pl_checkpoint_to_hf
import
convert_pl_to_hf
from
distillation
import
distill_main
,
evaluate_checkpoint
from
finetune
import
SummarizationModule
,
main
from
run_eval
import
generate_summaries_or_translations
,
run_generate
from
run_eval_search
import
run_search
from
transformers
import
AutoConfig
,
AutoModelForSeq2SeqLM
from
transformers.hf_api
import
HfApi
from
transformers.testing_utils
import
CaptureStderr
,
CaptureStdout
,
require_multigpu
,
require_torch_and_cuda
,
slow
from
.convert_pl_checkpoint_to_hf
import
convert_pl_to_hf
from
.distillation
import
distill_main
,
evaluate_checkpoint
from
.finetune
import
SummarizationModule
,
main
from
.run_eval
import
generate_summaries_or_translations
,
run_generate
from
.run_eval_search
import
run_search
from
.utils
import
label_smoothed_nll_loss
,
lmap
,
load_json
from
utils
import
label_smoothed_nll_loss
,
lmap
,
load_json
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment