Unverified Commit 285a4801 authored by Younes Belkada, committed by GitHub

Fix gradient checkpointing + fp16 autocast for most models (#24247)



* fix gc bug

* continue PoC on OPT

* fixes

* :exploding_head:

* fix tests

* remove pytest.mark

* fixup

* forward contrib credits from discussions

* forward contrib credits from discussions

* reverting changes on untouched files.

---------
Co-authored-by: zhaoqf123 <zhaoqf123@users.noreply.github.com>
Co-authored-by: 7eu7d7 <7eu7d7@users.noreply.github.com>
parent 1815d186
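
For reference, the failing combination this commit targets is ordinary fp16 mixed-precision training with activation checkpointing enabled. A minimal sketch of how a user reaches that code path (the model choice and device below are illustrative, not taken from the PR):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda()
model.gradient_checkpointing_enable()  # route layer forwards through checkpointing
model.train()

batch = tokenizer("hello world", return_tensors="pt").to("cuda")
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(**batch, labels=batch["input_ids"]).loss
loss.backward()  # the GC + autocast + fp16 combination this PR fixes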
@@ -43,6 +43,7 @@ from ...modeling_outputs import (
     TokenClassifierOutput,
 )
 from ...modeling_utils import PreTrainedModel, SequenceSummary
+from ...pytorch_utils import torch_custom_checkpointing
 from ...pytorch_utils import (
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
@@ -550,7 +551,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(layer_module),
                     hidden_states,
                     attention_mask,
@@ -1585,6 +1586,7 @@ from ...modeling_outputs import (
     CausalLMOutputWithCrossAttentions
 )
 from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import torch_custom_checkpointing
 from ...utils import logging
 from .configuration_{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}Config
@@ -2318,7 +2320,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(encoder_layer),
                     hidden_states,
                     attention_mask,
@@ -2557,7 +2559,7 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
                     return custom_forward

-                layer_outputs = torch.utils.checkpoint.checkpoint(
+                layer_outputs = torch_custom_checkpointing(
                     create_custom_forward(decoder_layer),
                     hidden_states,
                     attention_mask,
...
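
The hunks above swap the direct `torch.utils.checkpoint.checkpoint` call for a `torch_custom_checkpointing` helper imported from `pytorch_utils.py`; the helper's body is not part of this excerpt. A minimal sketch of one plausible implementation, assuming it simply forwards to PyTorch's non-reentrant checkpointing (a common remedy for checkpoint-under-autocast problems); the `use_reentrant=False` default is an assumption, not the PR's confirmed code:

import torch.utils.checkpoint


def torch_custom_checkpointing(*args, **kwargs):
    # Assumed behavior: delegate to torch.utils.checkpoint.checkpoint with
    # use_reentrant=False (torch >= 1.11), whose recomputation replays the
    # surrounding autograd and autocast state during the backward pass.
    kwargs.setdefault("use_reentrant", False)
    return torch.utils.checkpoint.checkpoint(*args, **kwargs)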
@@ -352,6 +352,12 @@ class AlignTextModelTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @unittest.skip(reason="ALIGN does not use inputs_embeds")
     def test_inputs_embeds(self):
         pass
...
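
The same skipped override recurs in every test file below: these models opt out of the new common test `test_training_gradient_checkpointing_autocast`. As context for what that test family exercises, here is a self-contained sketch using a plain `nn.Linear` as a hypothetical stand-in for a transformers model (the real common test instead calls `model.gradient_checkpointing_enable()` on each architecture):

import unittest

import torch
import torch.utils.checkpoint


class GradientCheckpointingAutocastSketch(unittest.TestCase):
    @unittest.skipUnless(torch.cuda.is_available(), "fp16 autocast needs a CUDA device")
    def test_training_gradient_checkpointing_autocast(self):
        model = torch.nn.Linear(8, 8).cuda()
        inputs = torch.randn(2, 8, device="cuda")
        # Forward under fp16 autocast with activation checkpointing, then
        # check that backward actually reaches the parameters.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            out = torch.utils.checkpoint.checkpoint(model, inputs, use_reentrant=False)
            loss = out.float().mean()
        loss.backward()
        self.assertIsNotNone(model.weight.grad)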
@@ -186,6 +186,12 @@ class AltCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @unittest.skip(reason="AltCLIPVisionModel has no base class and is not available in MODEL_MAPPING")
     def test_save_load_fast_init_from_base(self):
         pass
...
@@ -238,6 +238,12 @@ class AutoformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
     def test_resize_tokens_embeddings(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     # # Input is 'static_categorical_features' not 'input_ids'
     def test_model_main_input_name(self):
         model_signature = inspect.signature(getattr(AutoformerModel, "forward"))
...
@@ -227,6 +227,12 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_multi_gpu_data_parallel_forward(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     def test_model_common_attributes(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -609,6 +609,12 @@ class BigBirdModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs)

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     # overwrite from common in order to skip the check on `attentions`
     def check_pt_flax_outputs(self, fx_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
         # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version,
...
@@ -789,6 +789,12 @@ class BlipTextRetrievalModelTest(ModelTesterMixin, unittest.TestCase):
     def test_model_common_attributes(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     def test_forward_signature(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -499,6 +499,12 @@ class CanineModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         # ViT does not use inputs_embeds
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @unittest.skip("CANINE does not have a get_input_embeddings() method.")
     def test_model_common_attributes(self):
         pass
...
@@ -395,6 +395,12 @@ class ChineseCLIPTextModelTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @unittest.skip(reason="ChineseCLIPTextModel has no base class and is not available in MODEL_MAPPING")
     def test_save_load_fast_init_from_base(self):
         pass
@@ -469,6 +475,12 @@ class ChineseCLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
     def test_save_load_fast_init_to_base(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in CHINESE_CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
@@ -227,6 +227,12 @@ class CLIPVisionModelTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @unittest.skip(reason="CLIPVisionModel has no base class and is not available in MODEL_MAPPING")
     def test_save_load_fast_init_from_base(self):
         pass
...
@@ -202,6 +202,12 @@ class CLIPSegVisionModelTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @unittest.skip(reason="CLIPSegVisionModel has no base class and is not available in MODEL_MAPPING")
     def test_save_load_fast_init_from_base(self):
         pass
@@ -448,6 +454,12 @@ class CLIPSegModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
     def test_hidden_states_output(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @unittest.skip(reason="Inputs_embeds is tested in individual model tests")
     def test_inputs_embeds(self):
         pass
...
@@ -310,6 +310,12 @@ class Data2VecVisionModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_image_classification(*config_and_inputs)

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in DATA2VEC_VISION_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
@@ -182,6 +182,12 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_inputs_embeds(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     def test_model_common_attributes(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -196,6 +196,12 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_inputs_embeds(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     def test_model_common_attributes(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
...
@@ -185,6 +185,12 @@ class FlavaImageModelTest(ModelTesterMixin, unittest.TestCase):
         # FLAVA does not use inputs_embeds
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     def test_model_common_attributes(self):
         config, _ = self.model_tester.prepare_config_and_inputs_for_common()
@@ -462,6 +468,12 @@ class FlavaTextModelTest(ModelTesterMixin, unittest.TestCase):
         # FLAVA does not use inputs_embeds
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     # skip this test as FlavaTextModel has no base class and is
     # not available in MODEL_MAPPING
     def test_save_load_fast_init_from_base(self):
@@ -624,6 +636,12 @@ class FlavaMultimodalModelTest(ModelTesterMixin, unittest.TestCase):
     def test_save_load_fast_init_to_base(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in FLAVA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -731,6 +749,12 @@ class FlavaImageCodebookTest(ModelTesterMixin, unittest.TestCase):
     def test_save_load_fast_init_to_base(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in FLAVA_CODEBOOK_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
@@ -1156,6 +1180,12 @@ class FlavaForPreTrainingTest(FlavaModelTest):
     class_for_tester = FlavaForPreTrainingTester
     test_torchscript = False

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
 # We will verify our results on an image of cute cats
 def prepare_img():
...
@@ -444,6 +444,12 @@ class FNetModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in FNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
...
@@ -562,6 +562,12 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt2_weight_initialization(*config_and_inputs)

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @slow
     def test_batch_generation(self):
         model = GPT2LMHeadModel.from_pretrained("gpt2")
...
@@ -356,6 +356,12 @@ class GraphormerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
     def test_feed_forward_chunking(self):
         pass

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     @unittest.skip(reason="Graphormer does not share input and output embeddings")
     def test_model_common_attributes(self):
         pass
...
@@ -304,6 +304,12 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
     def test_config(self):
         self.config_tester.run_common_tests()

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     def test_imagegpt_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_imagegpt_model(*config_and_inputs)
...
@@ -216,6 +216,12 @@ class InformerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
         config_and_inputs = self.model_tester.prepare_config_and_inputs_for_common()
         self.model_tester.check_encoder_decoder_model_standalone(*config_and_inputs)

+    @unittest.skip(
+        reason="The model does not support GC + autocast + fp16: https://github.com/huggingface/transformers/pull/24247"
+    )
+    def test_training_gradient_checkpointing_autocast(self):
+        pass
+
     def test_hidden_states_output(self):
         def check_hidden_states_output(inputs_dict, config, model_class):
             model = model_class(config)
...