Unverified commit 7cbf0f72, authored by Stas Bekman, committed by GitHub

examples/seq2seq/__init__.py mutates sys.path (#7194)

parent a4faecea
examples/seq2seq/__init__.py (new file)
+import os
+import sys
+sys.path.insert(1, os.path.dirname(os.path.realpath(__file__)))
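Note: this new `__init__.py` is the core of the fix. It puts `examples/seq2seq` itself on `sys.path`, so every sibling module touched below can be imported without a package prefix, and the per-file `try/except ImportError` fallbacks can be deleted. A minimal annotated sketch of the same pattern (the surrounding layout and the `pkg_dir` name are illustrative, not part of the commit):

```python
import os
import sys

# Directory containing this __init__.py, with symlinks resolved by realpath().
pkg_dir = os.path.dirname(os.path.realpath(__file__))

# Insert at position 1, not 0: index 0 is conventionally reserved for the
# running script's own directory, which keeps its highest priority.
sys.path.insert(1, pkg_dir)

# From here on, a sibling file such as utils.py resolves as a top-level
# module, so `from utils import ...` works both when the package is imported
# and when a script in this directory is executed directly.
```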
examples/seq2seq/distillation.py
@@ -10,16 +10,13 @@ import torch
 from torch import nn
 from torch.nn import functional as F

+from finetune import SummarizationModule, TranslationModule
+from finetune import main as ft_main
+from initialization_utils import copy_layers, init_student
 from lightning_base import generic_train
 from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
 from transformers.modeling_bart import shift_tokens_right
-
-
-try:
-    from .finetune import SummarizationModule, TranslationModule
-    from .finetune import main as ft_main
-    from .initialization_utils import copy_layers, init_student
-    from .utils import (
+from utils import (
     any_requires_grad,
     assert_all_frozen,
     calculate_bleu,
@@ -27,20 +24,7 @@ try:
     label_smoothed_nll_loss,
     pickle_load,
     use_task_specific_params,
-    )
-except ImportError:
-    from finetune import SummarizationModule, TranslationModule
-    from finetune import main as ft_main
-    from initialization_utils import copy_layers, init_student
-    from utils import (
-        any_requires_grad,
-        assert_all_frozen,
-        calculate_bleu,
-        freeze_params,
-        label_smoothed_nll_loss,
-        pickle_load,
-        use_task_specific_params,
-    )
+)


 class BartSummarizationDistiller(SummarizationModule):
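The deleted `try: from .finetune import ... except ImportError: from finetune import ...` blocks were not just clutter: when both import paths can succeed, the same source file may be loaded twice under two names, producing two distinct module (and class) objects. A self-contained illustration of that failure mode (hypothetical `mod.py` and `pkg.mod` names, not from this commit):

```python
import importlib.util
import os
import sys
import tempfile

# Write a tiny module, then load it under two names, the way one file could be
# reached as both `utils` and `<package>.utils` under the old dual imports.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "mod.py")
    with open(path, "w") as f:
        f.write("class Marker:\n    pass\n")

    def load_as(name):
        spec = importlib.util.spec_from_file_location(name, path)
        module = importlib.util.module_from_spec(spec)
        sys.modules[name] = module
        spec.loader.exec_module(module)
        return module

    a = load_as("mod")      # what a bare `from mod import Marker` would bind
    b = load_as("pkg.mod")  # what a package-relative import could bind instead

# Same source, two module objects: cross-copy isinstance checks fail.
assert a.Marker is not b.Marker
assert not isinstance(a.Marker(), b.Marker)
```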
examples/seq2seq/finetune.py
@@ -12,33 +12,11 @@ import pytorch_lightning as pl
 import torch
 from torch.utils.data import DataLoader

+from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
 from lightning_base import BaseTransformer, add_generic_args, generic_train
 from transformers import MBartTokenizer, T5ForConditionalGeneration
 from transformers.modeling_bart import shift_tokens_right
-
-
-try:
-    from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
-    from .utils import (
-        ROUGE_KEYS,
-        LegacySeq2SeqDataset,
-        Seq2SeqDataset,
-        assert_all_frozen,
-        calculate_bleu,
-        calculate_rouge,
-        flatten_list,
-        freeze_params,
-        get_git_info,
-        label_smoothed_nll_loss,
-        lmap,
-        pickle_save,
-        save_git_info,
-        save_json,
-        use_task_specific_params,
-    )
-except ImportError:
-    from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
-    from utils import (
+from utils import (
     ROUGE_KEYS,
     LegacySeq2SeqDataset,
     Seq2SeqDataset,
@@ -54,7 +32,8 @@ except ImportError:
     save_git_info,
     save_json,
     use_task_specific_params,
-    )
+)


 logger = logging.getLogger(__name__)
examples/seq2seq/run_distributed_eval.py
@@ -11,24 +11,7 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

-
-logger = getLogger(__name__)
-
-
-try:
-    from .utils import (
-        Seq2SeqDataset,
-        calculate_bleu,
-        calculate_rouge,
-        lmap,
-        load_json,
-        parse_numeric_n_bool_cl_kwargs,
-        save_json,
-        use_task_specific_params,
-        write_txt_file,
-    )
-except ImportError:
-    from utils import (
+from utils import (
     Seq2SeqDataset,
     calculate_bleu,
     calculate_rouge,
@@ -38,7 +21,10 @@ except ImportError:
     save_json,
     use_task_specific_params,
     write_txt_file,
-    )
+)


+logger = getLogger(__name__)
+
+
 def eval_data_dir(
examples/seq2seq/run_eval.py
@@ -11,14 +11,11 @@ import torch
 from tqdm import tqdm
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

+from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params


 logger = getLogger(__name__)

-try:
-    from .utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
-except ImportError:
-    from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params


 DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
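With the fallback removed, `run_eval.py` can only resolve `utils` through `sys.path`: the new `__init__.py` guarantees that on package import, and Python's usual script-directory rule covers direct execution from the script's own folder. A hypothetical smoke check of the latter (the working directory is assumed relative to a transformers checkout):

```python
import subprocess

# `--help` exercises the module-level imports without needing a model or data;
# running from the script's own directory keeps the bare `utils` import resolvable.
subprocess.run(
    ["python", "run_eval.py", "--help"],
    cwd="examples/seq2seq",  # assumed checkout-relative location
    check=True,
)
```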
examples/seq2seq/run_eval_search.py
@@ -4,11 +4,7 @@ import operator
 import sys
 from collections import OrderedDict

-
-try:
-    from .run_eval import datetime_now, run_generate
-except ImportError:
-    from run_eval import datetime_now, run_generate
+from run_eval import datetime_now, run_generate


 # A table of supported tasks and the list of scores in the order of importance to be sorted by.
examples/seq2seq/test_seq2seq_examples_multi_gpu.py
@@ -10,13 +10,12 @@ import pytorch_lightning as pl
 import timeout_decorator
 import torch

+from distillation import BartSummarizationDistiller, distill_main
+from finetune import SummarizationModule, main
+from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from transformers import BartForConditionalGeneration, MarianMTModel
 from transformers.testing_utils import slow
-
-from .distillation import BartSummarizationDistiller, distill_main
-from .finetune import SummarizationModule, main
-from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
-from .utils import load_json
+from utils import load_json


 MODEL_NAME = MBART_TINY
examples/seq2seq/test_seq2seq_examples.py
@@ -12,16 +12,15 @@ import pytorch_lightning as pl
 import torch

+import lightning_base
+from convert_pl_checkpoint_to_hf import convert_pl_to_hf
+from distillation import distill_main, evaluate_checkpoint
+from finetune import SummarizationModule, main
+from run_eval import generate_summaries_or_translations, run_generate
+from run_eval_search import run_search
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.hf_api import HfApi
 from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
-
-
-from .convert_pl_checkpoint_to_hf import convert_pl_to_hf
-from .distillation import distill_main, evaluate_checkpoint
-from .finetune import SummarizationModule, main
-from .run_eval import generate_summaries_or_translations, run_generate
-from .run_eval_search import run_search
-from .utils import label_smoothed_nll_loss, lmap, load_json
+from utils import label_smoothed_nll_loss, lmap, load_json


 logging.basicConfig(level=logging.DEBUG)
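Finally, the tests above now import their siblings bare (`from distillation import ...`), so they depend on the `__init__.py` hook having run. A quick sanity check, assuming the checkout root is on `sys.path` and `examples` is itself importable as a package (that layout is an assumption, not asserted by the commit):

```python
import os
import sys

import examples.seq2seq  # noqa: F401 -- importing the package executes the new __init__.py

# The package directory should now be on sys.path, making bare sibling
# imports such as `import utils` resolvable from anywhere.
assert any(p.replace(os.sep, "/").endswith("examples/seq2seq") for p in sys.path)
```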