implicitron_config_system.ipynb 40.6 KB
Jeremy Reizenstein committed
{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "f0af2d90-cb21-4ab4-b4cb-0fd00dbfb77b",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "4e15bfa2-5404-40d0-98b6-eb2732c8b72b",
        "showInput": false
      },
      "source": [
        "# Implicitron's config system"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "287be985-423d-42e0-a2af-1e8c585e723c",
        "showInput": false
      },
      "source": [
        "Implicitron's components are all based on a unified hierarchical configuration system.\n",
        "This allows configurable variables and all defaults to be defined separately for each new component.\n",
        "All configs relevant to an experiment are then automatically composed into a single configuration file that fully specifies the experiment.\n",
        "An especially important feature is the provision of extension points, where users can insert their own subclasses of Implicitron's base components.\n",
        "\n",
        "This system is defined in [this file](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/implicitron/tools/config.py) in the PyTorch3D repo.\n",
        "The Implicitron volumes tutorial contains a simple example of using the config system.\n",
        "This tutorial provides detailed hands-on experience in using and modifying Implicitron's configurable components.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "fde300a2-99cb-4d52-9d5b-4464a2083e0b",
        "showInput": false
      },
      "source": [
        "## 0. Install and import modules\n",
        "\n",
        "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "ad6e94a7-e114-43d3-b038-a5210c7d34c9",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import torch\n",
        "need_pytorch3d=False\n",
        "try:\n",
        "    import pytorch3d\n",
        "except ModuleNotFoundError:\n",
        "    need_pytorch3d=True\n",
        "if need_pytorch3d:\n",
        "    if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
        "        # We try to install PyTorch3D via a released wheel.\n",
        "        pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
        "        version_str=\"\".join([\n",
        "            f\"py3{sys.version_info.minor}_cu\",\n",
        "            torch.version.cuda.replace(\".\",\"\"),\n",
        "            f\"_pyt{pyt_version_str}\"\n",
        "        ])\n",
        "        !pip install fvcore iopath\n",
        "        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
        "    else:\n",
        "        # We try to install PyTorch3D from source.\n",
        "        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz\n",
        "        !tar xzf 1.10.0.tar.gz\n",
        "        os.environ[\"CUB_HOME\"] = os.getcwd() + \"/cub-1.10.0\"\n",
        "        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "609896c0-9e2e-4716-b074-b565f0170e32",
        "showInput": false
      },
      "source": [
        "Ensure `omegaconf` is installed. If it is not, run the following cell. (It should not be necessary to restart the runtime.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "d1c1851e-b9f2-4236-93c3-19aa4d63041c",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "!pip install omegaconf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "code_folding": [],
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1659465468717,
        "executionStopTime": 1659465468738,
        "hidden_ranges": [],
        "originalKey": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f",
        "requestMsgId": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f"
      },
      "outputs": [],
      "source": [
        "from dataclasses import dataclass\n",
        "from typing import Optional, Tuple\n",
        "\n",
        "import torch\n",
        "from omegaconf import DictConfig, OmegaConf\n",
        "from pytorch3d.implicitron.tools.config import (\n",
        "    Configurable,\n",
        "    ReplaceableBase,\n",
        "    expand_args_fields,\n",
        "    get_default_args,\n",
        "    registry,\n",
        "    run_auto_creation,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "a638bf90-eb6b-424d-b53d-eae11954a717",
        "showInput": false
      },
      "source": [
        "## 1. Introducing dataclasses\n",
        "\n",
        "[Type hints](https://docs.python.org/3/library/typing.html) let you annotate Python code with the types of variables, arguments and members. [Dataclasses](https://docs.python.org/3/library/dataclasses.html) let you create a class from a list of members which have names, types and possibly default values. The `__init__` function is created automatically and, as its final step, calls a `__post_init__` function if one is defined. For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454972732,
        "executionStopTime": 1659454972739,
        "originalKey": "71eaad5e-e198-492e-8610-24b0da9dd4ae",
        "requestMsgId": "71eaad5e-e198-492e-8610-24b0da9dd4ae",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "@dataclass\n",
        "class MyDataclass:\n",
        "    a: int\n",
        "    b: int = 8\n",
        "    c: Optional[Tuple[int, ...]] = None\n",
        "\n",
        "    def __post_init__(self):\n",
        "        print(f\"created with a = {self.a}\")\n",
        "        self.d = 2 * self.b"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454973051,
        "executionStopTime": 1659454973077,
        "originalKey": "83202a18-a3d3-44ec-a62d-b3360a302645",
        "requestMsgId": "83202a18-a3d3-44ec-a62d-b3360a302645",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "my_dataclass_instance = MyDataclass(a=18)\n",
        "assert my_dataclass_instance.d == 16"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "b67ccb9f-dc6c-4994-9b99-b5a1bcfebd70",
        "showInput": false
      },
      "source": [
        "👷 Note that the `dataclass` decorator here is a function which modifies the definition of the class itself.\n",
        "It runs immediately after the definition.\n",
        "Our config system requires that the modified versions of classes in Implicitron's library code be aware of user-defined implementations, which may not exist yet when the library class is defined.\n",
        "Therefore the modification of the class needs to be delayed, so we don't use a decorator.\n"
      ]
    },
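    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the delay concrete, here is a minimal sketch in plain Python. It does not use Implicitron, the class names are hypothetical, and it is not Implicitron's own mechanism; it only shows that, because a decorator is just a function, `dataclasses.dataclass` can be applied to an ordinary class at some later point instead of at definition time.\n",
        "\n",
        "```python\n",
        "from dataclasses import dataclass, fields, is_dataclass\n",
        "\n",
        "\n",
        "class Plain:\n",
        "    # Annotated members only; no decorator has been applied yet.\n",
        "    a: int\n",
        "    b: int = 8\n",
        "\n",
        "\n",
        "assert not is_dataclass(Plain)\n",
        "\n",
        "# Later, e.g. once all user-defined classes are known:\n",
        "Processed = dataclass(Plain)  # same effect as having written @dataclass\n",
        "assert Processed is Plain and is_dataclass(Plain)\n",
        "print([f.name for f in fields(Plain)])  # ['a', 'b']\n",
        "print(Plain(a=1))  # Plain(a=1, b=8)\n",
        "```"
      ]
    },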
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "3e90f664-99df-4387-9c45-a1ad7939ef3a",
        "showInput": false
      },
      "source": [
        "## 2. Introducing omegaconf and OmegaConf.structured\n",
        "\n",
        "The [omegaconf](https://github.com/omry/omegaconf/) library provides a `DictConfig` class, which is like a `dict` with `str` keys but has extra features that make it convenient as a configuration system."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659451341683,
        "executionStopTime": 1659451341690,
        "originalKey": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174",
        "requestMsgId": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "dc = DictConfig({\"a\": 2, \"b\": True, \"c\": None, \"d\": \"hello\"})\n",
        "assert dc.a == dc[\"a\"] == 2"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "3b5b76a9-4b76-4784-96ff-2a1212e48e48",
        "showInput": false
      },
      "source": [
        "OmegaConf supports serialization to and from YAML. The [Hydra](https://hydra.cc/) library relies on this for its configuration files."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659451411835,
        "executionStopTime": 1659451411936,
        "originalKey": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61",
        "requestMsgId": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "print(OmegaConf.to_yaml(dc))\n",
        "assert OmegaConf.create(OmegaConf.to_yaml(dc)) == dc"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "777fecdd-8bf6-4fd8-827b-cb8af5477fa8",
        "showInput": false
      },
      "source": [
        "`OmegaConf.structured` creates a `DictConfig` from a dataclass or from an instance of a dataclass. Unlike a normal `DictConfig`, it is type-checked, and only keys declared in the dataclass can be added."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455098879,
        "executionStopTime": 1659455098900,
        "originalKey": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162",
        "requestMsgId": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured = OmegaConf.structured(MyDataclass)\n",
        "assert isinstance(structured, DictConfig)\n",
        "print(structured)\n",
        "print()\n",
        "print(OmegaConf.to_yaml(structured))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "be4446da-e536-4139-9ba3-37669a5b5e61",
        "showInput": false
      },
      "source": [
        "`structured` knows it is missing a value for `a`, which has no default: OmegaConf shows such a missing value as `???`."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "864811e8-1a75-4932-a85e-f681b0541ae9",
        "showInput": false
      },
      "source": [
        "Such an object has members compatible with the dataclass, so it can be used to initialise an instance as follows."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455580491,
        "executionStopTime": 1659455580501,
        "originalKey": "eb88aaa0-c22f-4ffb-813a-ca957b490acb",
        "requestMsgId": "eb88aaa0-c22f-4ffb-813a-ca957b490acb",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured.a = 21\n",
        "my_dataclass_instance2 = MyDataclass(**structured)\n",
        "print(my_dataclass_instance2)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "2d08c81c-9d18-4de9-8464-0da2d89f94f3",
        "showInput": false
      },
      "source": [
        "You can also call `OmegaConf.structured` on an instance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455594700,
        "executionStopTime": 1659455594737,
        "originalKey": "5e469bac-32a4-475d-9c09-8b64ba3f2155",
        "requestMsgId": "5e469bac-32a4-475d-9c09-8b64ba3f2155",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured_from_instance = OmegaConf.structured(my_dataclass_instance)\n",
        "my_dataclass_instance3 = MyDataclass(**structured_from_instance)\n",
        "print(my_dataclass_instance3)"
      ]
    },
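    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The checks described in this section can be seen directly. The following is a minimal sketch, assuming OmegaConf 2.x; `Example` is a hypothetical dataclass used so that `MyDataclass` and `structured` are left untouched. A structured config reports the value for `a` as missing, rejects a value of the wrong type, and rejects a key that the dataclass does not declare.\n",
        "\n",
        "```python\n",
        "from dataclasses import dataclass\n",
        "\n",
        "from omegaconf import OmegaConf\n",
        "from omegaconf.errors import OmegaConfBaseException, ValidationError\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class Example:  # hypothetical, for illustration only\n",
        "    a: int\n",
        "    b: int = 8\n",
        "\n",
        "\n",
        "cfg = OmegaConf.structured(Example)\n",
        "assert OmegaConf.is_missing(cfg, \"a\")  # a has no default, so it is shown as ???\n",
        "\n",
        "# Type checking: this string cannot be converted to int.\n",
        "wrong_type_error = None\n",
        "try:\n",
        "    cfg.b = \"not a number\"\n",
        "except ValidationError as e:\n",
        "    wrong_type_error = e\n",
        "print(\"wrong type rejected:\", type(wrong_type_error).__name__)\n",
        "\n",
        "# Only known keys: z is not a field of Example.\n",
        "unknown_key_error = None\n",
        "try:\n",
        "    cfg.z = 1\n",
        "except OmegaConfBaseException as e:\n",
        "    unknown_key_error = e\n",
        "print(\"unknown key rejected:\", type(unknown_key_error).__name__)\n",
        "assert cfg.b == 8  # the rejected assignment left b unchanged\n",
        "```"
      ]
    },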
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659452594203,
        "executionStopTime": 1659452594333,
        "originalKey": "2ed559e3-8552-465a-938f-30c72a321184",
        "requestMsgId": "2ed559e3-8552-465a-938f-30c72a321184",
        "showInput": false
      },
      "source": [
        "## 3. Our approach to OmegaConf.structured\n",
        "\n",
        "We provide functions which play the role of `OmegaConf.structured` but support more features.\n",
        "The following shows how to achieve the above using our functions.\n",
        "Note that we mark configurable classes with a special base class, `Configurable`, rather than a decorator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454053323,
        "executionStopTime": 1659454061629,
        "originalKey": "9888afbd-e617-4596-ab7a-fc1073f58656",
        "requestMsgId": "9888afbd-e617-4596-ab7a-fc1073f58656",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class MyConfigurable(Configurable):\n",
        "    a: int\n",
        "    b: int = 8\n",
        "    c: Optional[Tuple[int, ...]] = None\n",
        "\n",
        "    def __post_init__(self):\n",
        "        print(f\"created with a = {self.a}\")\n",
        "        self.d = 2 * self.b"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454784912,
        "executionStopTime": 1659454784928,
        "originalKey": "e43155b4-3da5-4df1-a2f5-da1d0369eec9",
        "requestMsgId": "e43155b4-3da5-4df1-a2f5-da1d0369eec9",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# The expand_args_fields function modifies the class like @dataclasses.dataclass.\n",
        "# If it has not been called on a Configurable class before the class is first\n",
        "# instantiated, it will be called automatically.\n",
        "expand_args_fields(MyConfigurable)\n",
        "my_configurable_instance = MyConfigurable(a=18)\n",
        "assert my_configurable_instance.d == 16"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460669541,
        "executionStopTime": 1659460669566,
        "originalKey": "96eaae18-dce4-4ee1-b451-1466fea51b9f",
        "requestMsgId": "96eaae18-dce4-4ee1-b451-1466fea51b9f",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# get_default_args also calls expand_args_fields automatically\n",
        "our_structured = get_default_args(MyConfigurable)\n",
        "assert isinstance(our_structured, DictConfig)\n",
        "print(OmegaConf.to_yaml(our_structured))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460454020,
        "executionStopTime": 1659460454032,
        "originalKey": "359f7925-68de-42cd-bd34-79a099b1c210",
        "requestMsgId": "359f7925-68de-42cd-bd34-79a099b1c210",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "our_structured.a = 21\n",
        "print(MyConfigurable(**our_structured))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460599142,
        "executionStopTime": 1659460599149,
        "originalKey": "eac7d385-9365-4098-acf9-4f0a0dbdcb85",
        "requestMsgId": "eac7d385-9365-4098-acf9-4f0a0dbdcb85",
        "showInput": false
      },
      "source": [
        "## 4. First enhancement: nested types 🪺\n",
        "\n",
        "Our system allows `Configurable` classes to contain each other.\n",
        "One thing to remember: the containing class must call `run_auto_creation` in its `__post_init__`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465752418,
        "executionStopTime": 1659465752976,
        "originalKey": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a",
        "requestMsgId": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Inner(Configurable):\n",
        "    a: int = 8\n",
        "    b: bool = True\n",
        "    c: Tuple[int, ...] = (2, 3, 4, 6)\n",
        "\n",
        "\n",
        "class Outer(Configurable):\n",
        "    inner: Inner\n",
        "    x: str = \"hello\"\n",
        "    xx: bool = False\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465762326,
        "executionStopTime": 1659465762339,
        "originalKey": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7",
        "requestMsgId": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "outer_dc = get_default_args(Outer)\n",
        "print(OmegaConf.to_yaml(outer_dc))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465772894,
        "executionStopTime": 1659465772911,
        "originalKey": "0254204b-8c7a-4d40-bba6-5132185f63d7",
        "requestMsgId": "0254204b-8c7a-4d40-bba6-5132185f63d7",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "outer = Outer(**outer_dc)\n",
        "assert isinstance(outer, Outer)\n",
        "assert isinstance(outer.inner, Inner)\n",
        "print(vars(outer))\n",
        "print(outer.inner)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "44a78c13-ec92-4a87-808a-c4674b320c22",
        "showInput": false
      },
      "source": [
        "Note how `inner_args` is an extra member of `outer`. For `Outer`, `run_auto_creation(self)` is equivalent to\n",
        "```\n",
        "    self.inner = Inner(**self.inner_args)\n",
        "```"
      ]
    },
    {
618
      "attachments": {},
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
619
620
621
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
622
        "customInput": null,
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
623
624
        "customOutput": null,
        "executionStartTime": 1659461071129,
625
626
627
628
        "executionStopTime": 1659461071137,
        "originalKey": "af0ec78b-7888-4b0d-9346-63d970d43293",
        "requestMsgId": "af0ec78b-7888-4b0d-9346-63d970d43293",
        "showInput": false
      },
      "source": [
        "## 5. Second enhancement: pluggable/replaceable components 🔌\n",
        "\n",
        "If a class uses `ReplaceableBase` as its base class instead of `Configurable`, we call it a *replaceable*.\n",
        "This signals that it is designed to be replaced by one of its child classes.\n",
        "Methods which subclasses are expected to implement can raise `NotImplementedError` in the base class.\n",
        "The system maintains a global `registry` containing the subclasses of each `ReplaceableBase`.\n",
        "Subclasses register themselves with it using the `@registry.register` decorator.\n",
        "\n",
        "A configurable class (i.e. a class which uses our system: a child of `Configurable` or `ReplaceableBase`) which contains a `ReplaceableBase` member must also\n",
        "contain a corresponding `<member>_class_type` field of type `str`, which indicates which concrete child class to use (for example `inner_class_type` for a member named `inner`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463453457,
        "executionStopTime": 1659463453467,
        "originalKey": "f2898703-d147-4394-978e-fc7f1f559395",
        "requestMsgId": "f2898703-d147-4394-978e-fc7f1f559395",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class InnerBase(ReplaceableBase):\n",
        "    def say_something(self):\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class Inner1(InnerBase):\n",
        "    a: int = 1\n",
        "    b: str = \"h\"\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from an Inner1\")\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class Inner2(InnerBase):\n",
        "    a: int = 2\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from an Inner2\")"
      ]
    },
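    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The registry can also be queried directly. This minimal sketch registers a hypothetical `EnglishGreeter` under a hypothetical `GreeterBase` and looks it up by name with `registry.get(base_class, name)`.\n",
        "We assume this name lookup is what resolves a `<member>_class_type` string; if your version behaves differently, check the docstrings in `pytorch3d/implicitron/tools/config.py`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from pytorch3d.implicitron.tools.config import ReplaceableBase, registry\n",
        "\n",
        "\n",
        "class GreeterBase(ReplaceableBase):  # hypothetical replaceable base\n",
        "    def greet(self) -> str:\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class EnglishGreeter(GreeterBase):  # hypothetical implementation\n",
        "    def greet(self) -> str:\n",
        "        return \"hello\"\n",
        "\n",
        "\n",
        "# Implementations are namespaced by their base class and looked up by class name.\n",
        "greeter_cls = registry.get(GreeterBase, \"EnglishGreeter\")\n",
        "assert greeter_cls is EnglishGreeter\n",
        "print(greeter_cls().greet())"
      ]
    },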
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463453514,
        "executionStopTime": 1659463453592,
        "originalKey": "6f171599-51ee-440f-82d7-a59f84d24624",
        "requestMsgId": "6f171599-51ee-440f-82d7-a59f84d24624",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Out(Configurable):\n",
        "    inner: InnerBase\n",
        "    inner_class_type: str = \"Inner1\"\n",
        "    x: int = 19\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def talk(self):\n",
        "        self.inner.say_something()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463191360,
        "executionStopTime": 1659463191428,
        "originalKey": "7abaecec-96e6-44df-8c8d-69c36a14b913",
        "requestMsgId": "7abaecec-96e6-44df-8c8d-69c36a14b913",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "Out_dc = get_default_args(Out)\n",
        "print(OmegaConf.to_yaml(Out_dc))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463192717,
        "executionStopTime": 1659463192754,
        "originalKey": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c",
        "requestMsgId": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "Out_dc.inner_class_type = \"Inner2\"\n",
        "out = Out(**Out_dc)\n",
        "print(out.inner)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463193751,
        "executionStopTime": 1659463193791,
        "originalKey": "aa0e1b04-963a-4724-81b7-5748b598b541",
        "requestMsgId": "aa0e1b04-963a-4724-81b7-5748b598b541",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "out.talk()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "4f78a56c-39cd-4563-a97e-041e5f360f6b",
        "showInput": false
      },
      "source": [
        "Note that in this case `out` has several `args` members, one for each registered implementation of `InnerBase`. It is usually fine to ignore them in the code; they are needed for the config."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462145294,
        "executionStopTime": 1659462145307,
        "originalKey": "ce7069d5-a813-4286-a7cd-6ff40362105a",
        "requestMsgId": "ce7069d5-a813-4286-a7cd-6ff40362105a",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "print(vars(out))"
      ]
    },
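    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `args` members are what you edit to configure a replaceable component. As a minimal sketch with hypothetical classes (`ShapeBase`, `Square`, `Rectangle` and `Drawing`, none of which are part of Implicitron), the following selects an implementation through `shape_class_type` and then edits only the args node belonging to that implementation.\n",
        "The node name `shape_Rectangle_args` follows the `<member>_<ClassName>_args` pattern that section 8 shows with `subcomponent_SubComponent_args`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from omegaconf import OmegaConf\n",
        "from pytorch3d.implicitron.tools.config import (\n",
        "    Configurable,\n",
        "    ReplaceableBase,\n",
        "    get_default_args,\n",
        "    registry,\n",
        "    run_auto_creation,\n",
        ")\n",
        "\n",
        "\n",
        "class ShapeBase(ReplaceableBase):  # hypothetical replaceable base\n",
        "    def area(self) -> float:\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class Square(ShapeBase):  # hypothetical implementation\n",
        "    side: float = 1.0\n",
        "\n",
        "    def area(self) -> float:\n",
        "        return self.side * self.side\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class Rectangle(ShapeBase):  # hypothetical implementation\n",
        "    width: float = 1.0\n",
        "    height: float = 2.0\n",
        "\n",
        "    def area(self) -> float:\n",
        "        return self.width * self.height\n",
        "\n",
        "\n",
        "class Drawing(Configurable):  # hypothetical outer class\n",
        "    shape: ShapeBase\n",
        "    shape_class_type: str = \"Square\"\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "\n",
        "drawing_cfg = get_default_args(Drawing)\n",
        "# Pick the implementation, then edit the args node that belongs to it.\n",
        "drawing_cfg.shape_class_type = \"Rectangle\"\n",
        "drawing_cfg.shape_Rectangle_args.width = 3.0\n",
        "drawing = Drawing(**drawing_cfg)\n",
        "assert isinstance(drawing.shape, Rectangle)\n",
        "assert drawing.shape.area() == 6.0\n",
        "print(OmegaConf.to_yaml(drawing_cfg))"
      ]
    },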
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462231114,
        "executionStopTime": 1659462231130,
        "originalKey": "c7f051ff-c264-4b89-80dc-36cf179aafaf",
        "requestMsgId": "c7f051ff-c264-4b89-80dc-36cf179aafaf",
        "showInput": false
      },
      "source": [
        "## 6. Example with torch.nn.Module 🔥\n",
        "In Implicitron, this system is typically used in combination with [`Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)s.\n",
        "Note that in this case it is necessary to call `Module.__init__` explicitly in `__post_init__`, which the class below does with `super().__init__()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462645018,
        "executionStopTime": 1659462645037,
        "originalKey": "42d210d6-09e0-4daf-8ccb-411d30f268f4",
        "requestMsgId": "42d210d6-09e0-4daf-8ccb-411d30f268f4",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class MyLinear(torch.nn.Module, Configurable):\n",
        "    d_in: int = 2\n",
        "    d_out: int = 200\n",
        "\n",
        "    def __post_init__(self):\n",
        "        super().__init__()\n",
        "        self.linear = torch.nn.Linear(in_features=self.d_in, out_features=self.d_out)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.linear(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462692309,
        "executionStopTime": 1659462692346,
        "originalKey": "546781fe-5b95-4e48-9cb5-34a634a31313",
        "requestMsgId": "546781fe-5b95-4e48-9cb5-34a634a31313",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "my_linear = MyLinear()\n",
        "x = torch.zeros(2)  # avoid shadowing the builtin `input`\n",
        "output = my_linear(x)\n",
        "print(\"output shape:\", output.shape)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462738302,
        "executionStopTime": 1659462738419,
        "originalKey": "b6cb71e1-1d54-4e89-a422-0a70772c5c03",
        "requestMsgId": "b6cb71e1-1d54-4e89-a422-0a70772c5c03",
        "showInput": false
      },
      "source": [
        "`my_linear` has all the usual features of a `Module`.\n",
        "For example, it can be saved and loaded with `torch.save` and `torch.load`,\n",
        "and it has parameters:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462821485,
        "executionStopTime": 1659462821501,
        "originalKey": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3",
        "requestMsgId": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "for name, value in my_linear.named_parameters():\n",
        "    print(name, value.shape)"
      ]
    },
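    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of saving and loading, the following round-trips the weights of a hypothetical `TinyLinear` (built like `MyLinear` above) through an in-memory buffer.\n",
        "It uses the common `state_dict` / `load_state_dict` pattern rather than pickling the whole module, and it also shows that config fields such as `d_out` are ordinary constructor arguments."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import io\n",
        "\n",
        "import torch\n",
        "from pytorch3d.implicitron.tools.config import Configurable\n",
        "\n",
        "\n",
        "class TinyLinear(torch.nn.Module, Configurable):  # hypothetical module\n",
        "    d_in: int = 2\n",
        "    d_out: int = 3\n",
        "\n",
        "    def __post_init__(self):\n",
        "        super().__init__()  # initialize the Module machinery first\n",
        "        self.linear = torch.nn.Linear(in_features=self.d_in, out_features=self.d_out)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.linear(x)\n",
        "\n",
        "\n",
        "source = TinyLinear(d_out=4)\n",
        "\n",
        "# Save the weights to an in-memory buffer, then rewind it for reading.\n",
        "buffer = io.BytesIO()\n",
        "torch.save(source.state_dict(), buffer)\n",
        "buffer.seek(0)\n",
        "\n",
        "# A freshly constructed module with the same config accepts the saved weights.\n",
        "target = TinyLinear(d_out=4)\n",
        "target.load_state_dict(torch.load(buffer))\n",
        "\n",
        "x = torch.ones(2)\n",
        "assert tuple(source(x).shape) == (4,)\n",
        "assert torch.equal(source(x), target(x))"
      ]
    },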
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463222379,
        "executionStopTime": 1659463222409,
        "originalKey": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b",
        "requestMsgId": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b",
        "showInput": false
      },
      "source": [
        "## 7. Example of implementing your own pluggable component\n",
        "Let's say I am using a library which provides `Out`, as in section **5**, but I want to implement my own child of `InnerBase`.\n",
        "All I need to do is register its definition. However, I must do this before `expand_args_fields` is called on `Out`, either explicitly or implicitly (for example by `get_default_args(Out)` or by instantiating `Out`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463694644,
        "executionStopTime": 1659463694653,
        "originalKey": "d9635511-a52b-43d5-8dae-d5c1a3dd9157",
        "requestMsgId": "d9635511-a52b-43d5-8dae-d5c1a3dd9157",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "@registry.register\n",
        "class UserImplementedInner(InnerBase):\n",
        "    a: int = 200\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from the user\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "f1511aa2-56b8-4ed0-a453-17e2bbfeefe7",
        "showInput": false
      },
      "source": [
        "At this point, we need to redefine the class `Out`.\n",
        "The existing `Out` has already been expanded without `UserImplementedInner`, so the following would not work with it,\n",
        "because the implementations known to a class are fixed when it is expanded.\n",
        "\n",
        "If you are running experiments from a script, the thing to remember is that you must import your own modules, which register your own implementations,\n",
        "before you *use* the library classes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463745967,
        "executionStopTime": 1659463745986,
        "originalKey": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406",
        "requestMsgId": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Out(Configurable):\n",
        "    inner: InnerBase\n",
        "    inner_class_type: str = \"Inner1\"\n",
        "    x: int = 19\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def talk(self):\n",
        "        self.inner.say_something()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463747398,
        "executionStopTime": 1659463747431,
        "originalKey": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979",
        "requestMsgId": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "out2 = Out(inner_class_type=\"UserImplementedInner\")\n",
        "print(out2.inner)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464033633,
        "executionStopTime": 1659464033643,
        "originalKey": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9",
        "requestMsgId": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9",
        "showInput": false
      },
      "source": [
        "## 8. Example of making a subcomponent pluggable\n",
        "\n",
        "Let's look at what needs to happen if we take an existing subcomponent and make it pluggable, so that users can supply their own implementation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464709922,
        "executionStopTime": 1659464709933,
        "originalKey": "e37227b2-6897-4033-8560-9f2040abdeeb",
        "requestMsgId": "e37227b2-6897-4033-8560-9f2040abdeeb",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class SubComponent(Configurable):\n",
        "    x: float = 0.25\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        return a + self.x\n",
        "\n",
        "\n",
        "class LargeComponent(Configurable):\n",
        "    repeats: int = 4\n",
        "    subcomponent: SubComponent\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        for _ in range(self.repeats):\n",
        "            a = self.subcomponent.apply(a)\n",
        "        return a"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464710339,
        "executionStopTime": 1659464710459,
        "originalKey": "cab4c121-350e-443f-9a49-bd542a9735a2",
        "requestMsgId": "cab4c121-350e-443f-9a49-bd542a9735a2",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "large_component = LargeComponent()\n",
        "assert large_component.apply(3) == 4\n",
        "print(OmegaConf.to_yaml(LargeComponent))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "be60323a-badf-46e4-a259-72cae1391028",
        "showInput": false
      },
      "source": [
        "Here is the same code with the subcomponent made pluggable:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464717226,
        "executionStopTime": 1659464717261,
        "originalKey": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8",
        "requestMsgId": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class SubComponentBase(ReplaceableBase):\n",
        "    def apply(self, a: float) -> float:\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class SubComponent(SubComponentBase):\n",
        "    x: float = 0.25\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        return a + self.x\n",
        "\n",
        "\n",
        "class LargeComponent(Configurable):\n",
        "    repeats: int = 4\n",
        "    subcomponent: SubComponentBase\n",
        "    subcomponent_class_type: str = \"SubComponent\"\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        for _ in range(self.repeats):\n",
        "            a = self.subcomponent.apply(a)\n",
        "        return a"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464725473,
        "executionStopTime": 1659464725587,
        "originalKey": "bbc3d321-6b49-4356-be75-1a173b1fc3a5",
        "requestMsgId": "bbc3d321-6b49-4356-be75-1a173b1fc3a5",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "large_component = LargeComponent()\n",
        "assert large_component.apply(3) == 4\n",
        "print(OmegaConf.to_yaml(LargeComponent))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464672680,
        "executionStopTime": 1659464673231,
        "originalKey": "5115453a-1d96-4022-97e7-46433e6dcf60",
        "requestMsgId": "5115453a-1d96-4022-97e7-46433e6dcf60",
        "showInput": false
      },
      "source": [
        "The following things had to change:\n",
        "* The base class `SubComponentBase` was defined.\n",
        "* `SubComponent` gained a `@registry.register` decoration and had its base class changed to `SubComponentBase`.\n",
        "* In `LargeComponent`, the type annotation of `subcomponent` was changed to `SubComponentBase`, and a `subcomponent_class_type` member was added.\n",
        "* In any saved configuration YAML files, the key `subcomponent_args` had to be changed to `subcomponent_SubComponent_args`."
      ]
    },
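    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Once a subcomponent is pluggable, a user can supply a different implementation without touching the outer class. Here is a minimal sketch with hypothetical names (`StepBase`, `AddStep`, `ScaleStep` and `Pipeline`) which mirror `SubComponentBase`, `SubComponent` and `LargeComponent`.\n",
        "Following the rule from section **7**, the user-supplied `ScaleStep` is registered before `Pipeline` is first used."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from pytorch3d.implicitron.tools.config import (\n",
        "    Configurable,\n",
        "    ReplaceableBase,\n",
        "    get_default_args,\n",
        "    registry,\n",
        "    run_auto_creation,\n",
        ")\n",
        "\n",
        "\n",
        "class StepBase(ReplaceableBase):  # hypothetical replaceable base\n",
        "    def apply(self, a: float) -> float:\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class AddStep(StepBase):  # hypothetical default implementation\n",
        "    x: float = 0.25\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        return a + self.x\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class ScaleStep(StepBase):  # hypothetical user-supplied implementation\n",
        "    factor: float = 2.0\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        return a * self.factor\n",
        "\n",
        "\n",
        "class Pipeline(Configurable):  # hypothetical outer class\n",
        "    repeats: int = 4\n",
        "    step: StepBase\n",
        "    step_class_type: str = \"AddStep\"\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        for _ in range(self.repeats):\n",
        "            a = self.step.apply(a)\n",
        "        return a\n",
        "\n",
        "\n",
        "# Default behaviour: 3 + 4 * 0.25\n",
        "assert Pipeline().apply(3) == 4\n",
        "\n",
        "# Swap in the user implementation purely through the config: 1.0 * 2.0 ** 3\n",
        "pipeline_cfg = get_default_args(Pipeline)\n",
        "pipeline_cfg.step_class_type = \"ScaleStep\"\n",
        "pipeline_cfg.repeats = 3\n",
        "assert Pipeline(**pipeline_cfg).apply(1.0) == 8.0"
      ]
    },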
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462041307,
        "executionStopTime": 1659462041637,
        "originalKey": "0739269e-5c0e-4551-b06f-f4aab386ba54",
        "requestMsgId": "0739269e-5c0e-4551-b06f-f4aab386ba54",
        "showInput": false
      },
      "source": [
        "## Appendix: gotchas ⚠️\n",
        "\n",
        "* Forgetting to define `__post_init__`, or forgetting to call `run_auto_creation` in it.\n",
        "* Omitting a type annotation on a field. Without the annotation, the attribute is a plain class attribute rather than a dataclass field, so the config system does not treat it as part of the config. For example, writing\n",
        "```\n",
        "    subcomponent_class_type = \"SubComponent\"\n",
        "```\n",
        "instead of\n",
        "```\n",
        "    subcomponent_class_type: str = \"SubComponent\"\n",
        "```\n",
        "\n"
      ]
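    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The annotation gotcha comes from how Python dataclasses collect fields: only annotated class attributes become fields, and Implicitron's configurable classes are turned into dataclasses when they are expanded.\n",
        "This pure-Python sketch with two hypothetical classes shows the difference using only the standard `dataclasses` module."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import dataclasses\n",
        "\n",
        "\n",
        "@dataclasses.dataclass\n",
        "class WithoutAnnotation:  # hypothetical\n",
        "    subcomponent_class_type = \"SubComponent\"  # plain class attribute, not a field\n",
        "\n",
        "\n",
        "@dataclasses.dataclass\n",
        "class WithAnnotation:  # hypothetical\n",
        "    subcomponent_class_type: str = \"SubComponent\"  # a real dataclass field\n",
        "\n",
        "\n",
        "without_fields = [f.name for f in dataclasses.fields(WithoutAnnotation)]\n",
        "with_fields = [f.name for f in dataclasses.fields(WithAnnotation)]\n",
        "print(without_fields)  # []\n",
        "print(with_fields)  # ['subcomponent_class_type']"
      ]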
    }
  ],
  "metadata": {
    "bento_stylesheets": {
      "bento/extensions/flow/main.css": true,
      "bento/extensions/kernel_selector/main.css": true,
      "bento/extensions/kernel_ui/main.css": true,
      "bento/extensions/new_kernel/main.css": true,
      "bento/extensions/system_usage/main.css": true,
      "bento/extensions/theme/main.css": true
    },
    "captumWidgetMessage": {},
    "dataExplorerConfig": {},
    "kernelspec": {
      "display_name": "pytorch3d",
      "language": "python",
      "metadata": {
        "cinder_runtime": false,
        "fbpkg_supported": true,
        "is_prebuilt": true,
        "kernel_name": "bento_kernel_pytorch3d",
        "nightly_builds": true
      },
      "name": "bento_kernel_pytorch3d"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3"
    },
    "last_base_url": "https://9177.od.fbinfra.net:443/",
    "last_kernel_id": "90755407-3729-46f4-ab67-ff2cb1daa5cb",
    "last_msg_id": "f61034eb-826226915ad9548ffbe495ba_6317",
    "last_server_session_id": "d6b46f14-cee7-44c1-8c51-39a38a4ea4c2",
    "outputWidgetContext": {}
  },
  "nbformat": 4,
  "nbformat_minor": 2
}