Unverified Commit 7cbf0f72 authored by Stas Bekman, committed by GitHub

examples/seq2seq/__init__.py mutates sys.path (#7194)

parent a4faecea
examples/seq2seq/__init__.py:

import os
import sys

sys.path.insert(1, os.path.dirname(os.path.realpath(__file__)))
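That single path insert is what the rest of the commit builds on: each script in examples/seq2seq previously guarded its sibling imports with a try/except so it worked both as a package module and as a standalone script, and those guards are now dropped in favour of plain flat imports. A minimal before/after sketch (illustrative only; load_json stands in for any of the helpers imported in the hunks below):

# Before this commit: every script carried a defensive dual import so it worked
# both when imported as examples.seq2seq.<module> and when run directly.
try:
    from .utils import load_json  # package-relative import (e.g. under pytest)
except ImportError:
    from utils import load_json  # flat import when run from inside the directory

# After this commit: __init__.py has already inserted examples/seq2seq into
# sys.path, so the flat form resolves in both situations and the fallback goes.
from utils import load_json  # noqa: F811 -- repeated only to show the new form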
...
@@ -10,37 +10,21 @@ import torch
from torch import nn
from torch.nn import functional as F

+from finetune import SummarizationModule, TranslationModule
+from finetune import main as ft_main
+from initialization_utils import copy_layers, init_student
from lightning_base import generic_train
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
-
-
-try:
-    from .finetune import SummarizationModule, TranslationModule
-    from .finetune import main as ft_main
-    from .initialization_utils import copy_layers, init_student
-    from .utils import (
-        any_requires_grad,
-        assert_all_frozen,
-        calculate_bleu,
-        freeze_params,
-        label_smoothed_nll_loss,
-        pickle_load,
-        use_task_specific_params,
-    )
-except ImportError:
-    from finetune import SummarizationModule, TranslationModule
-    from finetune import main as ft_main
-    from initialization_utils import copy_layers, init_student
-    from utils import (
-        any_requires_grad,
-        assert_all_frozen,
-        calculate_bleu,
-        freeze_params,
-        label_smoothed_nll_loss,
-        pickle_load,
-        use_task_specific_params,
-    )
+from utils import (
+    any_requires_grad,
+    assert_all_frozen,
+    calculate_bleu,
+    freeze_params,
+    label_smoothed_nll_loss,
+    pickle_load,
+    use_task_specific_params,
+)


class BartSummarizationDistiller(SummarizationModule):
...
@@ -12,50 +12,29 @@ import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

+from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
-
-
-try:
-    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
-    from .utils import (
-        ROUGE_KEYS,
-        LegacySeq2SeqDataset,
-        Seq2SeqDataset,
-        assert_all_frozen,
-        calculate_bleu,
-        calculate_rouge,
-        flatten_list,
-        freeze_params,
-        get_git_info,
-        label_smoothed_nll_loss,
-        lmap,
-        pickle_save,
-        save_git_info,
-        save_json,
-        use_task_specific_params,
-    )
-except ImportError:
-    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
-    from utils import (
-        ROUGE_KEYS,
-        LegacySeq2SeqDataset,
-        Seq2SeqDataset,
-        assert_all_frozen,
-        calculate_bleu,
-        calculate_rouge,
-        flatten_list,
-        freeze_params,
-        get_git_info,
-        label_smoothed_nll_loss,
-        lmap,
-        pickle_save,
-        save_git_info,
-        save_json,
-        use_task_specific_params,
-    )
+from utils import (
+    ROUGE_KEYS,
+    LegacySeq2SeqDataset,
+    Seq2SeqDataset,
+    assert_all_frozen,
+    calculate_bleu,
+    calculate_rouge,
+    flatten_list,
+    freeze_params,
+    get_git_info,
+    label_smoothed_nll_loss,
+    lmap,
+    pickle_save,
+    save_git_info,
+    save_json,
+    use_task_specific_params,
+)


logger = logging.getLogger(__name__)
...
@@ -11,35 +11,21 @@ from torch.utils.data import DataLoader
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from utils import (
+    Seq2SeqDataset,
+    calculate_bleu,
+    calculate_rouge,
+    lmap,
+    load_json,
+    parse_numeric_n_bool_cl_kwargs,
+    save_json,
+    use_task_specific_params,
+    write_txt_file,
+)


logger = getLogger(__name__)
-
-
-try:
-    from .utils import (
-        Seq2SeqDataset,
-        calculate_bleu,
-        calculate_rouge,
-        lmap,
-        load_json,
-        parse_numeric_n_bool_cl_kwargs,
-        save_json,
-        use_task_specific_params,
-        write_txt_file,
-    )
-except ImportError:
-    from utils import (
-        Seq2SeqDataset,
-        calculate_bleu,
-        calculate_rouge,
-        lmap,
-        load_json,
-        parse_numeric_n_bool_cl_kwargs,
-        save_json,
-        use_task_specific_params,
-        write_txt_file,
-    )


def eval_data_dir(
    data_dir,
...
@@ -11,14 +11,11 @@ import torch
from tqdm import tqdm

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params


logger = getLogger(__name__)
-
-
-try:
-    from .utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
-except ImportError:
-    from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params


DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
...
@@ -4,11 +4,7 @@ import operator
import sys
from collections import OrderedDict

-try:
-    from .run_eval import datetime_now, run_generate
-except ImportError:
-    from run_eval import datetime_now, run_generate
+from run_eval import datetime_now, run_generate


# A table of supported tasks and the list of scores in the order of importance to be sorted by.
...
@@ -10,13 +10,12 @@ import pytorch_lightning as pl
import timeout_decorator
import torch

+from distillation import BartSummarizationDistiller, distill_main
+from finetune import SummarizationModule, main
+from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
from transformers import BartForConditionalGeneration, MarianMTModel
from transformers.testing_utils import slow
-
-from .distillation import BartSummarizationDistiller, distill_main
-from .finetune import SummarizationModule, main
-from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
-from .utils import load_json
+from utils import load_json


MODEL_NAME = MBART_TINY
...
@@ -12,16 +12,15 @@ import pytorch_lightning as pl
import torch

import lightning_base
+from convert_pl_checkpoint_to_hf import convert_pl_to_hf
+from distillation import distill_main, evaluate_checkpoint
+from finetune import SummarizationModule, main
+from run_eval import generate_summaries_or_translations, run_generate
+from run_eval_search import run_search
from transformers import AutoConfig, AutoModelForSeq2SeqLM
from transformers.hf_api import HfApi
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
-
-from .convert_pl_checkpoint_to_hf import convert_pl_to_hf
-from .distillation import distill_main, evaluate_checkpoint
-from .finetune import SummarizationModule, main
-from .run_eval import generate_summaries_or_translations, run_generate
-from .run_eval_search import run_search
-from .utils import label_smoothed_nll_loss, lmap, load_json
+from utils import label_smoothed_nll_loss, lmap, load_json


logging.basicConfig(level=logging.DEBUG)
...
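For context, a hedged sketch of why the mutation lives in __init__.py rather than in each script: when the package is imported from outside the directory (for example while collecting the tests above from the repository root), __init__.py runs first, so the flat imports inside the modules still resolve. The snippet is illustrative and not part of the commit:

# Hypothetical caller at the repository root (illustrative only).
import sys

sys.path.insert(0, ".")  # assume the repository root is the current directory
import examples.seq2seq  # noqa: F401 -- executes the __init__.py shown above

# From here on, the modules' own flat imports, e.g.
#     from finetune import SummarizationModule
# resolve to examples/seq2seq/finetune.py even though the caller lives elsewhere.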