Commit 96c49901 authored by Gunnlaugur Thor Briem's avatar Gunnlaugur Thor Briem
Browse files

fix unused imports and style

parent 470753bc
......@@ -24,7 +24,7 @@ import tensorflow as tf
from .configuration_distilbert import DistilBertConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, keras_serializable, shape_list
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, get_initializer, shape_list
logger = logging.getLogger(__name__)
......
......@@ -29,7 +29,6 @@ from .modeling_tf_utils import (
TFSequenceSummary,
TFSharedEmbeddings,
get_initializer,
keras_serializable,
shape_list,
)
......
......@@ -20,11 +20,10 @@ import logging
import tensorflow as tf
from . import PretrainedConfig
from .configuration_roberta import RobertaConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_bert import TFBertEmbeddings, TFBertMainLayer, gelu
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
logger = logging.getLogger(__name__)
......
......@@ -25,7 +25,7 @@ import tensorflow as tf
from .configuration_t5 import T5Config
from .file_utils import DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list
logger = logging.getLogger(__name__)
......
......@@ -25,14 +25,7 @@ import tensorflow as tf
from .configuration_xlm import XLMConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
from .modeling_tf_utils import (
TFPreTrainedModel,
TFSequenceSummary,
TFSharedEmbeddings,
get_initializer,
keras_serializable,
shape_list,
)
from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, get_initializer, shape_list
logger = logging.getLogger(__name__)
......
......@@ -102,8 +102,9 @@ class TFModelTesterMixin:
for module_member_name in dir(module)
if module_member_name.endswith("MainLayer")
for module_member in (getattr(module, module_member_name),)
if isinstance(module_member, type) and tf.keras.layers.Layer in module_member.__bases__
and getattr(module_member, '_keras_serializable', False)
if isinstance(module_member, type)
and tf.keras.layers.Layer in module_member.__bases__
and getattr(module_member, "_keras_serializable", False)
)
for main_layer_class in tf_main_layer_classes:
main_layer = main_layer_class(config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment