Unverified commit 470f51cd, authored by Saurav Maheshkar, committed by GitHub
Browse files

feat: add `act_fn` param to `OutValueFunctionBlock` (#3994)



* feat: add act_fn param to OutValueFunctionBlock

* feat: update unet1d tests to not use mish

* feat: add `mish` as the default activation function
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* feat: drop mish tests from unet1d

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent b7e35dc7
...@@ -235,12 +235,12 @@ class OutConv1DBlock(nn.Module): ...@@ -235,12 +235,12 @@ class OutConv1DBlock(nn.Module):
class OutValueFunctionBlock(nn.Module): class OutValueFunctionBlock(nn.Module):
def __init__(self, fc_dim, embed_dim): def __init__(self, fc_dim, embed_dim, act_fn="mish"):
super().__init__() super().__init__()
self.final_block = nn.ModuleList( self.final_block = nn.ModuleList(
[ [
nn.Linear(fc_dim + embed_dim, fc_dim // 2), nn.Linear(fc_dim + embed_dim, fc_dim // 2),
nn.Mish(), get_activation(act_fn),
nn.Linear(fc_dim // 2, 1), nn.Linear(fc_dim // 2, 1),
] ]
) )
...@@ -652,5 +652,5 @@ def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, ac ...@@ -652,5 +652,5 @@ def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, ac
if out_block_type == "OutConv1DBlock": if out_block_type == "OutConv1DBlock":
return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn)
elif out_block_type == "ValueFunction": elif out_block_type == "ValueFunction":
return OutValueFunctionBlock(fc_dim, embed_dim) return OutValueFunctionBlock(fc_dim, embed_dim, act_fn)
return None return None
...@@ -52,27 +52,21 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -52,27 +52,21 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
def test_training(self): def test_training(self):
pass pass
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_determinism(self): def test_determinism(self):
super().test_determinism() super().test_determinism()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_outputs_equivalence(self): def test_outputs_equivalence(self):
super().test_outputs_equivalence() super().test_outputs_equivalence()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
super().test_from_save_pretrained() super().test_from_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained_variant(self): def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant() super().test_from_save_pretrained_variant()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
super().test_model_from_pretrained() super().test_model_from_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self): def test_output(self):
super().test_output() super().test_output()
...@@ -89,12 +83,11 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -89,12 +83,11 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
"mid_block_type": "MidResTemporalBlock1D", "mid_block_type": "MidResTemporalBlock1D",
"down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"), "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
"act_fn": "mish", "act_fn": "swish",
} }
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
model, loading_info = UNet1DModel.from_pretrained( model, loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
...@@ -107,7 +100,6 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -107,7 +100,6 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None" assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output_pretrained(self): def test_output_pretrained(self):
model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet") model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet")
torch.manual_seed(0) torch.manual_seed(0)
...@@ -177,27 +169,21 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -177,27 +169,21 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
def output_shape(self): def output_shape(self):
return (4, 14, 1) return (4, 14, 1)
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_determinism(self): def test_determinism(self):
super().test_determinism() super().test_determinism()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_outputs_equivalence(self): def test_outputs_equivalence(self):
super().test_outputs_equivalence() super().test_outputs_equivalence()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained(self): def test_from_save_pretrained(self):
super().test_from_save_pretrained() super().test_from_save_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_save_pretrained_variant(self): def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant() super().test_from_save_pretrained_variant()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
super().test_model_from_pretrained() super().test_model_from_pretrained()
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output(self): def test_output(self):
# UNetRL is a value-function is different output shape # UNetRL is a value-function is different output shape
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
...@@ -241,7 +227,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -241,7 +227,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input inputs_dict = self.dummy_input
return init_dict, inputs_dict return init_dict, inputs_dict
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_from_pretrained_hub(self): def test_from_pretrained_hub(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained( value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
...@@ -254,7 +239,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): ...@@ -254,7 +239,6 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
assert image is not None, "Make sure output is not None" assert image is not None, "Make sure output is not None"
@unittest.skipIf(torch_device == "mps", "mish op not supported in MPS")
def test_output_pretrained(self): def test_output_pretrained(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained( value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment