Commit 5c8e5b37 authored by Victor SANH's avatar Victor SANH
Browse files

complying with isort

parent db2a3b2e
...@@ -17,13 +17,13 @@ For instance, once the a model from the :class:`~emmental.MaskedBertForSequenceC ...@@ -17,13 +17,13 @@ For instance, once the a model from the :class:`~emmental.MaskedBertForSequenceC
as a standard :class:`~transformers.BertForSequenceClassification`. as a standard :class:`~transformers.BertForSequenceClassification`.
""" """
import argparse
import os import os
import shutil import shutil
import argparse
import torch import torch
from emmental.modules import MagnitudeBinarizer, TopKBinarizer, ThresholdBinarizer from emmental.modules import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
def main(args): def main(args):
...@@ -40,13 +40,13 @@ def main(args): ...@@ -40,13 +40,13 @@ def main(args):
for name, tensor in model.items(): for name, tensor in model.items():
if "embeddings" in name or "LayerNorm" in name or "pooler" in name: if "embeddings" in name or "LayerNorm" in name or "pooler" in name:
pruned_model[name] = tensor pruned_model[name] = tensor
print(f"Pruned layer {name}") print(f"Copied layer {name}")
elif "classifier" in name or "qa_output" in name: elif "classifier" in name or "qa_output" in name:
pruned_model[name] = tensor pruned_model[name] = tensor
print(f"Pruned layer {name}") print(f"Copied layer {name}")
elif "bias" in name: elif "bias" in name:
pruned_model[name] = tensor pruned_model[name] = tensor
print(f"Pruned layer {name}") print(f"Copied layer {name}")
else: else:
if pruning_method == "magnitude": if pruning_method == "magnitude":
mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold) mask = MagnitudeBinarizer.apply(inputs=tensor, threshold=threshold)
......
...@@ -15,12 +15,12 @@ ...@@ -15,12 +15,12 @@
Count remaining (non-zero) weights in the encoder (i.e. the transformer layers). Count remaining (non-zero) weights in the encoder (i.e. the transformer layers).
Sparsity and remaining weights levels are equivalent: sparsity % = 100 - remaining weights %. Sparsity and remaining weights levels are equivalent: sparsity % = 100 - remaining weights %.
""" """
import os
import argparse import argparse
import os
import torch import torch
from emmental.modules import TopKBinarizer, ThresholdBinarizer from emmental.modules import ThresholdBinarizer, TopKBinarizer
def main(args): def main(args):
......
from .modules import *
from .configuration_bert_masked import MaskedBertConfig from .configuration_bert_masked import MaskedBertConfig
from .modeling_bert_masked import ( from .modeling_bert_masked import (
MaskedBertModel, MaskedBertForMultipleChoice,
MaskedBertForQuestionAnswering, MaskedBertForQuestionAnswering,
MaskedBertForSequenceClassification, MaskedBertForSequenceClassification,
MaskedBertForTokenClassification, MaskedBertForTokenClassification,
MaskedBertForMultipleChoice, MaskedBertModel,
) )
from .modules import *
...@@ -19,8 +19,9 @@ and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init` ...@@ -19,8 +19,9 @@ and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init`
import logging import logging
from transformers.configuration_utils import PretrainedConfig
from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -26,13 +26,16 @@ import torch ...@@ -26,13 +26,16 @@ import torch
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss from torch.nn import CrossEntropyLoss, MSELoss
from emmental import MaskedBertConfig, MaskedLinear
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
from transformers.modeling_bert import (
ACT2FN,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BertLayerNorm,
load_tf_weights_in_bert,
)
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.modeling_bert import load_tf_weights_in_bert, ACT2FN, BertLayerNorm
from emmental import MaskedLinear
from emmental import MaskedBertConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
from .masked_nn import MaskedLinear from .masked_nn import MaskedLinear
...@@ -19,14 +19,14 @@ the weight matrix to prune a portion of the weights. ...@@ -19,14 +19,14 @@ the weight matrix to prune a portion of the weights.
The pruned weight matrix is then multiplied against the inputs (and if necessary, the bias is added). The pruned weight matrix is then multiplied against the inputs (and if necessary, the bias is added).
""" """
import math
import torch import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from torch.nn import init from torch.nn import init
import math from .binarizer import MagnitudeBinarizer, ThresholdBinarizer, TopKBinarizer
from .binarizer import ThresholdBinarizer, TopKBinarizer, MagnitudeBinarizer
class MaskedLinear(nn.Linear): class MaskedLinear(nn.Linear):
......
...@@ -24,12 +24,13 @@ import random ...@@ -24,12 +24,13 @@ import random
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
from emmental import MaskedBertConfig, MaskedBertForSequenceClassification
from transformers import ( from transformers import (
WEIGHTS_NAME, WEIGHTS_NAME,
AdamW, AdamW,
...@@ -43,7 +44,6 @@ from transformers import glue_convert_examples_to_features as convert_examples_t ...@@ -43,7 +44,6 @@ from transformers import glue_convert_examples_to_features as convert_examples_t
from transformers import glue_output_modes as output_modes from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors from transformers import glue_processors as processors
from emmental import MaskedBertConfig, MaskedBertForSequenceClassification
try: try:
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
......
...@@ -25,12 +25,13 @@ import timeit ...@@ -25,12 +25,13 @@ import timeit
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
from emmental import MaskedBertConfig, MaskedBertForQuestionAnswering
from transformers import ( from transformers import (
WEIGHTS_NAME, WEIGHTS_NAME,
AdamW, AdamW,
...@@ -48,8 +49,6 @@ from transformers.data.metrics.squad_metrics import ( ...@@ -48,8 +49,6 @@ from transformers.data.metrics.squad_metrics import (
from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor from transformers.data.processors.squad import SquadResult, SquadV1Processor, SquadV2Processor
from emmental import MaskedBertConfig, MaskedBertForQuestionAnswering
try: try:
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
except ImportError: except ImportError:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment