Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7cbf0f72
Unverified
Commit
7cbf0f72
authored
Sep 20, 2020
by
Stas Bekman
Committed by
GitHub
Sep 20, 2020
Browse files
examples/seq2seq/__init__.py mutates sys.path (#7194)
parent
a4faecea
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
58 additions
and
113 deletions
+58
-113
examples/seq2seq/__init__.py
examples/seq2seq/__init__.py
+5
-0
examples/seq2seq/distillation.py
examples/seq2seq/distillation.py
+12
-28
examples/seq2seq/finetune.py
examples/seq2seq/finetune.py
+18
-39
examples/seq2seq/run_distributed_eval.py
examples/seq2seq/run_distributed_eval.py
+11
-25
examples/seq2seq/run_eval.py
examples/seq2seq/run_eval.py
+1
-4
examples/seq2seq/run_eval_search.py
examples/seq2seq/run_eval_search.py
+1
-5
examples/seq2seq/test_bash_script.py
examples/seq2seq/test_bash_script.py
+4
-5
examples/seq2seq/test_seq2seq_examples.py
examples/seq2seq/test_seq2seq_examples.py
+6
-7
No files found.
examples/seq2seq/__init__.py
View file @
7cbf0f72
import
os
import
sys
sys
.
path
.
insert
(
1
,
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)))
examples/seq2seq/distillation.py
View file @
7cbf0f72
...
@@ -10,37 +10,21 @@ import torch
...
@@ -10,37 +10,21 @@ import torch
from
torch
import
nn
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
finetune
import
SummarizationModule
,
TranslationModule
from
finetune
import
main
as
ft_main
from
initialization_utils
import
copy_layers
,
init_student
from
lightning_base
import
generic_train
from
lightning_base
import
generic_train
from
transformers
import
AutoModelForSeq2SeqLM
,
MBartTokenizer
,
T5Config
,
T5ForConditionalGeneration
from
transformers
import
AutoModelForSeq2SeqLM
,
MBartTokenizer
,
T5Config
,
T5ForConditionalGeneration
from
transformers.modeling_bart
import
shift_tokens_right
from
transformers.modeling_bart
import
shift_tokens_right
from
utils
import
(
any_requires_grad
,
try
:
assert_all_frozen
,
from
.finetune
import
SummarizationModule
,
TranslationModule
calculate_bleu
,
from
.finetune
import
main
as
ft_main
freeze_params
,
from
.initialization_utils
import
copy_layers
,
init_student
label_smoothed_nll_loss
,
from
.utils
import
(
pickle_load
,
any_requires_grad
,
use_task_specific_params
,
assert_all_frozen
,
)
calculate_bleu
,
freeze_params
,
label_smoothed_nll_loss
,
pickle_load
,
use_task_specific_params
,
)
except
ImportError
:
from
finetune
import
SummarizationModule
,
TranslationModule
from
finetune
import
main
as
ft_main
from
initialization_utils
import
copy_layers
,
init_student
from
utils
import
(
any_requires_grad
,
assert_all_frozen
,
calculate_bleu
,
freeze_params
,
label_smoothed_nll_loss
,
pickle_load
,
use_task_specific_params
,
)
class
BartSummarizationDistiller
(
SummarizationModule
):
class
BartSummarizationDistiller
(
SummarizationModule
):
...
...
examples/seq2seq/finetune.py
View file @
7cbf0f72
...
@@ -12,50 +12,29 @@ import pytorch_lightning as pl
...
@@ -12,50 +12,29 @@ import pytorch_lightning as pl
import
torch
import
torch
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
callbacks
import
Seq2SeqLoggingCallback
,
get_checkpoint_callback
,
get_early_stopping_callback
from
lightning_base
import
BaseTransformer
,
add_generic_args
,
generic_train
from
lightning_base
import
BaseTransformer
,
add_generic_args
,
generic_train
from
transformers
import
MBartTokenizer
,
T5ForConditionalGeneration
from
transformers
import
MBartTokenizer
,
T5ForConditionalGeneration
from
transformers.modeling_bart
import
shift_tokens_right
from
transformers.modeling_bart
import
shift_tokens_right
from
utils
import
(
ROUGE_KEYS
,
LegacySeq2SeqDataset
,
Seq2SeqDataset
,
assert_all_frozen
,
calculate_bleu
,
calculate_rouge
,
flatten_list
,
freeze_params
,
get_git_info
,
label_smoothed_nll_loss
,
lmap
,
pickle_save
,
save_git_info
,
save_json
,
use_task_specific_params
,
)
try
:
from
.callbacks
import
Seq2SeqLoggingCallback
,
get_checkpoint_callback
,
get_early_stopping_callback
from
.utils
import
(
ROUGE_KEYS
,
LegacySeq2SeqDataset
,
Seq2SeqDataset
,
assert_all_frozen
,
calculate_bleu
,
calculate_rouge
,
flatten_list
,
freeze_params
,
get_git_info
,
label_smoothed_nll_loss
,
lmap
,
pickle_save
,
save_git_info
,
save_json
,
use_task_specific_params
,
)
except
ImportError
:
from
callbacks
import
Seq2SeqLoggingCallback
,
get_checkpoint_callback
,
get_early_stopping_callback
from
utils
import
(
ROUGE_KEYS
,
LegacySeq2SeqDataset
,
Seq2SeqDataset
,
assert_all_frozen
,
calculate_bleu
,
calculate_rouge
,
flatten_list
,
freeze_params
,
get_git_info
,
label_smoothed_nll_loss
,
lmap
,
pickle_save
,
save_git_info
,
save_json
,
use_task_specific_params
,
)
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
...
examples/seq2seq/run_distributed_eval.py
View file @
7cbf0f72
...
@@ -11,35 +11,21 @@ from torch.utils.data import DataLoader
...
@@ -11,35 +11,21 @@ from torch.utils.data import DataLoader
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
transformers
import
AutoModelForSeq2SeqLM
,
AutoTokenizer
from
transformers
import
AutoModelForSeq2SeqLM
,
AutoTokenizer
from
utils
import
(
Seq2SeqDataset
,
calculate_bleu
,
calculate_rouge
,
lmap
,
load_json
,
parse_numeric_n_bool_cl_kwargs
,
save_json
,
use_task_specific_params
,
write_txt_file
,
)
logger
=
getLogger
(
__name__
)
logger
=
getLogger
(
__name__
)
try
:
from
.utils
import
(
Seq2SeqDataset
,
calculate_bleu
,
calculate_rouge
,
lmap
,
load_json
,
parse_numeric_n_bool_cl_kwargs
,
save_json
,
use_task_specific_params
,
write_txt_file
,
)
except
ImportError
:
from
utils
import
(
Seq2SeqDataset
,
calculate_bleu
,
calculate_rouge
,
lmap
,
load_json
,
parse_numeric_n_bool_cl_kwargs
,
save_json
,
use_task_specific_params
,
write_txt_file
,
)
def
eval_data_dir
(
def
eval_data_dir
(
data_dir
,
data_dir
,
...
...
examples/seq2seq/run_eval.py
View file @
7cbf0f72
...
@@ -11,14 +11,11 @@ import torch
...
@@ -11,14 +11,11 @@ import torch
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
transformers
import
AutoModelForSeq2SeqLM
,
AutoTokenizer
from
transformers
import
AutoModelForSeq2SeqLM
,
AutoTokenizer
from
utils
import
calculate_bleu
,
calculate_rouge
,
parse_numeric_n_bool_cl_kwargs
,
use_task_specific_params
logger
=
getLogger
(
__name__
)
logger
=
getLogger
(
__name__
)
try
:
from
.utils
import
calculate_bleu
,
calculate_rouge
,
parse_numeric_n_bool_cl_kwargs
,
use_task_specific_params
except
ImportError
:
from
utils
import
calculate_bleu
,
calculate_rouge
,
parse_numeric_n_bool_cl_kwargs
,
use_task_specific_params
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
DEFAULT_DEVICE
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
...
...
examples/seq2seq/run_eval_search.py
View file @
7cbf0f72
...
@@ -4,11 +4,7 @@ import operator
...
@@ -4,11 +4,7 @@ import operator
import
sys
import
sys
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
run_eval
import
datetime_now
,
run_generate
try
:
from
.run_eval
import
datetime_now
,
run_generate
except
ImportError
:
from
run_eval
import
datetime_now
,
run_generate
# A table of supported tasks and the list of scores in the order of importance to be sorted by.
# A table of supported tasks and the list of scores in the order of importance to be sorted by.
...
...
examples/seq2seq/test_bash_script.py
View file @
7cbf0f72
...
@@ -10,13 +10,12 @@ import pytorch_lightning as pl
...
@@ -10,13 +10,12 @@ import pytorch_lightning as pl
import
timeout_decorator
import
timeout_decorator
import
torch
import
torch
from
distillation
import
BartSummarizationDistiller
,
distill_main
from
finetune
import
SummarizationModule
,
main
from
test_seq2seq_examples
import
CUDA_AVAILABLE
,
MBART_TINY
from
transformers
import
BartForConditionalGeneration
,
MarianMTModel
from
transformers
import
BartForConditionalGeneration
,
MarianMTModel
from
transformers.testing_utils
import
slow
from
transformers.testing_utils
import
slow
from
utils
import
load_json
from
.distillation
import
BartSummarizationDistiller
,
distill_main
from
.finetune
import
SummarizationModule
,
main
from
.test_seq2seq_examples
import
CUDA_AVAILABLE
,
MBART_TINY
from
.utils
import
load_json
MODEL_NAME
=
MBART_TINY
MODEL_NAME
=
MBART_TINY
...
...
examples/seq2seq/test_seq2seq_examples.py
View file @
7cbf0f72
...
@@ -12,16 +12,15 @@ import pytorch_lightning as pl
...
@@ -12,16 +12,15 @@ import pytorch_lightning as pl
import
torch
import
torch
import
lightning_base
import
lightning_base
from
convert_pl_checkpoint_to_hf
import
convert_pl_to_hf
from
distillation
import
distill_main
,
evaluate_checkpoint
from
finetune
import
SummarizationModule
,
main
from
run_eval
import
generate_summaries_or_translations
,
run_generate
from
run_eval_search
import
run_search
from
transformers
import
AutoConfig
,
AutoModelForSeq2SeqLM
from
transformers
import
AutoConfig
,
AutoModelForSeq2SeqLM
from
transformers.hf_api
import
HfApi
from
transformers.hf_api
import
HfApi
from
transformers.testing_utils
import
CaptureStderr
,
CaptureStdout
,
require_multigpu
,
require_torch_and_cuda
,
slow
from
transformers.testing_utils
import
CaptureStderr
,
CaptureStdout
,
require_multigpu
,
require_torch_and_cuda
,
slow
from
utils
import
label_smoothed_nll_loss
,
lmap
,
load_json
from
.convert_pl_checkpoint_to_hf
import
convert_pl_to_hf
from
.distillation
import
distill_main
,
evaluate_checkpoint
from
.finetune
import
SummarizationModule
,
main
from
.run_eval
import
generate_summaries_or_translations
,
run_generate
from
.run_eval_search
import
run_search
from
.utils
import
label_smoothed_nll_loss
,
lmap
,
load_json
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment