Unverified Commit de142611 authored by Prathik Rao's avatar Prathik Rao Committed by GitHub
Browse files

Make `UNet2DConditionOutput` pickle-able (#3857)



* add default to unet output to prevent it from being a required arg

* add unit test

* make style

* adjust unit test

* mark as fast test

* adjust assert statement in test

---------

Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: default avatarroot <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
parent 41ea88f3
...@@ -57,7 +57,7 @@ class UNet2DConditionOutput(BaseOutput): ...@@ -57,7 +57,7 @@ class UNet2DConditionOutput(BaseOutput):
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
""" """
sample: torch.FloatTensor sample: torch.FloatTensor = None
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import gc import gc
import os import os
import tempfile import tempfile
...@@ -782,6 +783,22 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -782,6 +783,22 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert (sample - on_sample).abs().max() < 1e-4 assert (sample - on_sample).abs().max() < 1e-4
assert (sample - off_sample).abs().max() < 1e-4 assert (sample - off_sample).abs().max() < 1e-4
def test_pickle(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
sample = model(**inputs_dict).sample
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
@slow @slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase): class UNet2DConditionModelIntegrationTests(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment