{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "f0af2d90-cb21-4ab4-b4cb-0fd00dbfb77b",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "4e15bfa2-5404-40d0-98b6-eb2732c8b72b",
        "showInput": false
      },
      "source": [
        "# Implicitron's config system"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "287be985-423d-42e0-a2af-1e8c585e723c",
        "showInput": false
      },
      "source": [
        "Implicitron's components are all based on a unified hierarchical configuration system. \n",
        "This allows configurable variables and all defaults to be defined separately for each new component.\n",
        "All configs relevant to an experiment are then automatically composed into a single configuration file that fully specifies the experiment.\n",
        "An especially important feature is extension points, where users can insert their own subclasses of Implicitron's base components.\n",
        "\n",
        "The file which defines this system is [here](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/implicitron/tools/config.py) in the PyTorch3D repo.\n",
        "The Implicitron volumes tutorial contains a simple example of using the config system.\n",
        "This tutorial provides detailed hands-on experience in using and modifying Implicitron's configurable components.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "fde300a2-99cb-4d52-9d5b-4464a2083e0b",
        "showInput": false
      },
      "source": [
        "## 0. Install and import modules\n",
        "\n",
        "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "ad6e94a7-e114-43d3-b038-a5210c7d34c9",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import torch\n",
        "need_pytorch3d=False\n",
        "try:\n",
        "    import pytorch3d\n",
        "except ModuleNotFoundError:\n",
        "    need_pytorch3d=True\n",
        "if need_pytorch3d:\n",
        "    if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n",
        "        # We try to install PyTorch3D via a released wheel.\n",
        "        pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
        "        version_str=\"\".join([\n",
        "            f\"py3{sys.version_info.minor}_cu\",\n",
        "            torch.version.cuda.replace(\".\",\"\"),\n",
        "            f\"_pyt{pyt_version_str}\"\n",
        "        ])\n",
        "        !pip install fvcore iopath\n",
        "        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
        "    else:\n",
        "        # We try to install PyTorch3D from source.\n",
        "        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "609896c0-9e2e-4716-b074-b565f0170e32",
        "showInput": false
      },
      "source": [
        "Ensure omegaconf is installed. If not, run this cell. (It should not be necessary to restart the runtime.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "d1c1851e-b9f2-4236-93c3-19aa4d63041c",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "!pip install omegaconf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "code_folding": [],
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1659465468717,
        "executionStopTime": 1659465468738,
        "hidden_ranges": [],
        "originalKey": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f",
        "requestMsgId": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f"
      },
      "outputs": [],
      "source": [
        "from dataclasses import dataclass\n",
        "from typing import Optional, Tuple\n",
        "\n",
        "import torch\n",
        "from omegaconf import DictConfig, OmegaConf\n",
        "from pytorch3d.implicitron.tools.config import (\n",
        "    Configurable,\n",
        "    ReplaceableBase,\n",
        "    expand_args_fields,\n",
        "    get_default_args,\n",
        "    registry,\n",
        "    run_auto_creation,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "a638bf90-eb6b-424d-b53d-eae11954a717",
        "showInput": false
      },
      "source": [
        "## 1. Introducing dataclasses \n",
        "\n",
        "[Type hints](https://docs.python.org/3/library/typing.html) give a taxonomy of types in Python. [Dataclasses](https://docs.python.org/3/library/dataclasses.html) let you create a class from a list of members, each with a name, a type and possibly a default value. The `__init__` function is generated automatically and, as its final step, calls a `__post_init__` method if one is defined. For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454972732,
        "executionStopTime": 1659454972739,
        "originalKey": "71eaad5e-e198-492e-8610-24b0da9dd4ae",
        "requestMsgId": "71eaad5e-e198-492e-8610-24b0da9dd4ae",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "@dataclass\n",
        "class MyDataclass:\n",
        "    a: int\n",
        "    b: int = 8\n",
        "    c: Optional[Tuple[int, ...]] = None\n",
        "\n",
        "    def __post_init__(self):\n",
        "        print(f\"created with a = {self.a}\")\n",
        "        self.d = 2 * self.b"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454973051,
        "executionStopTime": 1659454973077,
        "originalKey": "83202a18-a3d3-44ec-a62d-b3360a302645",
        "requestMsgId": "83202a18-a3d3-44ec-a62d-b3360a302645",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "my_dataclass_instance = MyDataclass(a=18)\n",
        "assert my_dataclass_instance.d == 16"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "b67ccb9f-dc6c-4994-9b99-b5a1bcfebd70",
        "showInput": false
      },
      "source": [
        "👷 Note that the `dataclass` decorator here is a function which modifies the definition of the class itself.\n",
        "It runs immediately after the class is defined.\n",
        "Our config system requires Implicitron library code to contain classes whose modified versions must be aware of user-defined implementations, which are only known later.\n",
        "Therefore the modification of the class needs to be delayed, so we don't use a decorator.\n"
      ]
    },
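    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal illustration using only the standard library, the next cell applies `dataclass` as an ordinary function call some time after the class has been defined, which is the kind of delayed modification described above. The class `Point` is a hypothetical example and is not part of Implicitron."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from dataclasses import dataclass, is_dataclass\n",
        "\n",
        "\n",
        "class Point:  # hypothetical example class\n",
        "    x: int\n",
        "    y: int = 0\n",
        "\n",
        "\n",
        "# Only the class body has run so far: the annotations have not been processed.\n",
        "assert not is_dataclass(Point)\n",
        "\n",
        "# Calling the decorator by hand, later, has the same effect as writing @dataclass\n",
        "# above the class. It modifies the class in place and returns it.\n",
        "Point = dataclass(Point)\n",
        "assert is_dataclass(Point)\n",
        "assert Point(x=1) == Point(1, 0)"
      ]
    },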
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "3e90f664-99df-4387-9c45-a1ad7939ef3a",
        "showInput": false
      },
      "source": [
        "## 2. Introducing omegaconf and OmegaConf.structured\n",
        "\n",
        "The [omegaconf](https://github.com/omry/omegaconf/) library provides a `DictConfig` class, which is like a `dict` with `str` keys but with extra features that make it easy to use as a configuration system."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659451341683,
        "executionStopTime": 1659451341690,
        "originalKey": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174",
        "requestMsgId": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "dc = DictConfig({\"a\": 2, \"b\": True, \"c\": None, \"d\": \"hello\"})\n",
        "assert dc.a == dc[\"a\"] == 2"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "3b5b76a9-4b76-4784-96ff-2a1212e48e48",
        "showInput": false
      },
      "source": [
        "OmegaConf supports serialization to and from YAML. The [Hydra](https://hydra.cc/) library relies on this for its configuration files."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659451411835,
        "executionStopTime": 1659451411936,
        "originalKey": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61",
        "requestMsgId": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "print(OmegaConf.to_yaml(dc))\n",
        "assert OmegaConf.create(OmegaConf.to_yaml(dc)) == dc"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "777fecdd-8bf6-4fd8-827b-cb8af5477fa8",
        "showInput": false
      },
      "source": [
        "`OmegaConf.structured` creates a `DictConfig` from a dataclass or an instance of a dataclass. Unlike a normal `DictConfig`, it is type-checked and only known keys can be added."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455098879,
        "executionStopTime": 1659455098900,
        "originalKey": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162",
        "requestMsgId": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured = OmegaConf.structured(MyDataclass)\n",
        "assert isinstance(structured, DictConfig)\n",
        "print(structured)\n",
        "print()\n",
        "print(OmegaConf.to_yaml(structured))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "be4446da-e536-4139-9ba3-37669a5b5e61",
        "showInput": false
      },
      "source": [
        "`structured` knows it is missing a value for `a`."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "864811e8-1a75-4932-a85e-f681b0541ae9",
        "showInput": false
      },
      "source": [
        "Such an object has members compatible with the dataclass, so an initialisation can be performed as follows."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455580491,
        "executionStopTime": 1659455580501,
        "originalKey": "eb88aaa0-c22f-4ffb-813a-ca957b490acb",
        "requestMsgId": "eb88aaa0-c22f-4ffb-813a-ca957b490acb",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured.a = 21\n",
        "my_dataclass_instance2 = MyDataclass(**structured)\n",
        "print(my_dataclass_instance2)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "2d08c81c-9d18-4de9-8464-0da2d89f94f3",
        "showInput": false
      },
      "source": [
        "You can also call OmegaConf.structured on an instance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455594700,
        "executionStopTime": 1659455594737,
        "originalKey": "5e469bac-32a4-475d-9c09-8b64ba3f2155",
        "requestMsgId": "5e469bac-32a4-475d-9c09-8b64ba3f2155",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured_from_instance = OmegaConf.structured(my_dataclass_instance)\n",
        "my_dataclass_instance3 = MyDataclass(**structured_from_instance)\n",
        "print(my_dataclass_instance3)"
      ]
    },
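    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The claims above, that a structured config knows which values are missing, is type-checked, and only accepts known keys, can be checked directly. This is a minimal sketch using a hypothetical dataclass `Example`, so that the `structured` object from the earlier cells is left unchanged. The exact exception classes may vary between omegaconf versions, so the last check catches them broadly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from dataclasses import dataclass\n",
        "\n",
        "from omegaconf import OmegaConf\n",
        "from omegaconf.errors import ValidationError\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class Example:  # hypothetical example class\n",
        "    a: int\n",
        "    b: int = 8\n",
        "\n",
        "\n",
        "cfg = OmegaConf.structured(Example)\n",
        "\n",
        "# A field with no default is stored as the special missing value \"???\".\n",
        "assert OmegaConf.is_missing(cfg, \"a\")\n",
        "\n",
        "# Assignments are checked against the type annotations.\n",
        "try:\n",
        "    cfg.b = \"not an int\"\n",
        "except ValidationError as err:\n",
        "    print(\"rejected bad type:\", type(err).__name__)\n",
        "\n",
        "# Keys that the dataclass does not declare cannot be added.\n",
        "# (Recent omegaconf versions raise ConfigAttributeError, a subclass of AttributeError.)\n",
        "try:\n",
        "    cfg.z = 1\n",
        "except (AttributeError, KeyError) as err:\n",
        "    print(\"rejected unknown key:\", type(err).__name__)"
      ]
    },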
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659452594203,
        "executionStopTime": 1659452594333,
        "originalKey": "2ed559e3-8552-465a-938f-30c72a321184",
        "requestMsgId": "2ed559e3-8552-465a-938f-30c72a321184",
        "showInput": false
      },
      "source": [
        "## 3. Our approach to OmegaConf.structured\n",
        "\n",
        "We provide functions which are equivalent to `OmegaConf.structured` but support more features.\n",
        "The same result as above can be achieved with our functions as follows.\n",
        "Note that we indicate configurable classes using a special base class `Configurable`, not a decorator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454053323,
        "executionStopTime": 1659454061629,
        "originalKey": "9888afbd-e617-4596-ab7a-fc1073f58656",
        "requestMsgId": "9888afbd-e617-4596-ab7a-fc1073f58656",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class MyConfigurable(Configurable):\n",
        "    a: int\n",
        "    b: int = 8\n",
        "    c: Optional[Tuple[int, ...]] = None\n",
        "\n",
        "    def __post_init__(self):\n",
        "        print(f\"created with a = {self.a}\")\n",
        "        self.d = 2 * self.b"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454784912,
        "executionStopTime": 1659454784928,
        "originalKey": "e43155b4-3da5-4df1-a2f5-da1d0369eec9",
        "requestMsgId": "e43155b4-3da5-4df1-a2f5-da1d0369eec9",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# The expand_args_fields function modifies the class like @dataclasses.dataclass.\n",
        "# If it has not been called on a Configurable class before the class is instantiated,\n",
        "# it will be called automatically.\n",
        "expand_args_fields(MyConfigurable)\n",
        "my_configurable_instance = MyConfigurable(a=18)\n",
        "assert my_configurable_instance.d == 16"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460669541,
        "executionStopTime": 1659460669566,
        "originalKey": "96eaae18-dce4-4ee1-b451-1466fea51b9f",
        "requestMsgId": "96eaae18-dce4-4ee1-b451-1466fea51b9f",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# get_default_args also calls expand_args_fields automatically\n",
        "our_structured = get_default_args(MyConfigurable)\n",
        "assert isinstance(our_structured, DictConfig)\n",
        "print(OmegaConf.to_yaml(our_structured))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460454020,
        "executionStopTime": 1659460454032,
        "originalKey": "359f7925-68de-42cd-bd34-79a099b1c210",
        "requestMsgId": "359f7925-68de-42cd-bd34-79a099b1c210",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "our_structured.a = 21\n",
        "print(MyConfigurable(**our_structured))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460599142,
        "executionStopTime": 1659460599149,
        "originalKey": "eac7d385-9365-4098-acf9-4f0a0dbdcb85",
        "requestMsgId": "eac7d385-9365-4098-acf9-4f0a0dbdcb85",
        "showInput": false
      },
      "source": [
        "## 4. First enhancement: nested types 🪺\n",
        "\n",
        "Our system allows Configurable classes to contain each other. \n",
        "One thing to remember: add a call to `run_auto_creation` in `__post_init__`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465752418,
        "executionStopTime": 1659465752976,
        "originalKey": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a",
        "requestMsgId": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Inner(Configurable):\n",
        "    a: int = 8\n",
        "    b: bool = True\n",
        "    c: Tuple[int, ...] = (2, 3, 4, 6)\n",
        "\n",
        "\n",
        "class Outer(Configurable):\n",
        "    inner: Inner\n",
        "    x: str = \"hello\"\n",
        "    xx: bool = False\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465762326,
        "executionStopTime": 1659465762339,
        "originalKey": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7",
        "requestMsgId": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "outer_dc = get_default_args(Outer)\n",
        "print(OmegaConf.to_yaml(outer_dc))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465772894,
        "executionStopTime": 1659465772911,
        "originalKey": "0254204b-8c7a-4d40-bba6-5132185f63d7",
        "requestMsgId": "0254204b-8c7a-4d40-bba6-5132185f63d7",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "outer = Outer(**outer_dc)\n",
        "assert isinstance(outer, Outer)\n",
        "assert isinstance(outer.inner, Inner)\n",
        "print(vars(outer))\n",
        "print(outer.inner)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "44a78c13-ec92-4a87-808a-c4674b320c22",
        "showInput": false
      },
      "source": [
        "Note that `inner_args` is an extra member of `outer`. `run_auto_creation(self)` is equivalent to\n",
        "```\n",
        "    self.inner = Inner(**self.inner_args)\n",
        "```"
      ]
    },
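    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The equivalence above can be checked directly. This is a minimal sketch with two hypothetical classes, `Leaf` and `Tree`, which builds the member by hand from the extra `leaf_args` member, assuming the same `<member>_args` naming that produces `inner_args` above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from pytorch3d.implicitron.tools.config import (\n",
        "    Configurable,\n",
        "    get_default_args,\n",
        "    run_auto_creation,\n",
        ")\n",
        "\n",
        "\n",
        "class Leaf(Configurable):  # hypothetical example class\n",
        "    size: int = 3\n",
        "\n",
        "\n",
        "class Tree(Configurable):  # hypothetical example class\n",
        "    leaf: Leaf\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "\n",
        "tree_args = get_default_args(Tree)\n",
        "# The config for the member `leaf` is stored under the key `leaf_args`.\n",
        "assert \"leaf_args\" in tree_args\n",
        "\n",
        "tree = Tree(**tree_args)\n",
        "# Build the member by hand, which is what run_auto_creation did for us.\n",
        "manual_leaf = Leaf(**tree.leaf_args)\n",
        "assert isinstance(tree.leaf, Leaf)\n",
        "assert manual_leaf.size == tree.leaf.size == 3"
      ]
    },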
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659461071129,
        "executionStopTime": 1659461071137,
        "originalKey": "af0ec78b-7888-4b0d-9346-63d970d43293",
        "requestMsgId": "af0ec78b-7888-4b0d-9346-63d970d43293",
        "showInput": false
      },
      "source": [
        "## 5. Second enhancement: pluggable/replaceable components 🔌\n",
        "\n",
        "If a class uses `ReplaceableBase` as a base class instead of `Configurable`, we call it a replaceable.\n",
        "This indicates that it is designed for child classes to be used in its place.\n",
        "We might use `NotImplementedError` to indicate functionality which subclasses are expected to implement.\n",
        "The system maintains a global `registry` containing the subclasses of each `ReplaceableBase`.\n",
        "Subclasses register themselves with it using the `@registry.register` decorator.\n",
        "\n",
        "A configurable class (i.e. a class which uses our system, i.e. a child of `Configurable` or `ReplaceableBase`) which has a member whose type is a `ReplaceableBase` must also\n",
        "contain a corresponding `str` field named `<member>_class_type` (e.g. `inner_class_type` for the member `inner`), which indicates which concrete child class to use."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463453457,
        "executionStopTime": 1659463453467,
        "originalKey": "f2898703-d147-4394-978e-fc7f1f559395",
        "requestMsgId": "f2898703-d147-4394-978e-fc7f1f559395",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class InnerBase(ReplaceableBase):\n",
        "    def say_something(self):\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class Inner1(InnerBase):\n",
        "    a: int = 1\n",
        "    b: str = \"h\"\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from an Inner1\")\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class Inner2(InnerBase):\n",
        "    a: int = 2\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from an Inner2\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463453514,
        "executionStopTime": 1659463453592,
        "originalKey": "6f171599-51ee-440f-82d7-a59f84d24624",
        "requestMsgId": "6f171599-51ee-440f-82d7-a59f84d24624",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Out(Configurable):\n",
        "    inner: InnerBase\n",
        "    inner_class_type: str = \"Inner1\"\n",
        "    x: int = 19\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def talk(self):\n",
        "        self.inner.say_something()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463191360,
        "executionStopTime": 1659463191428,
        "originalKey": "7abaecec-96e6-44df-8c8d-69c36a14b913",
        "requestMsgId": "7abaecec-96e6-44df-8c8d-69c36a14b913",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "Out_dc = get_default_args(Out)\n",
        "print(OmegaConf.to_yaml(Out_dc))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463192717,
        "executionStopTime": 1659463192754,
        "originalKey": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c",
        "requestMsgId": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "Out_dc.inner_class_type = \"Inner2\"\n",
        "out = Out(**Out_dc)\n",
        "print(out.inner)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463193751,
        "executionStopTime": 1659463193791,
        "originalKey": "aa0e1b04-963a-4724-81b7-5748b598b541",
        "requestMsgId": "aa0e1b04-963a-4724-81b7-5748b598b541",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "out.talk()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "4f78a56c-39cd-4563-a97e-041e5f360f6b",
        "showInput": false
      },
      "source": [
        "Note that in this case there are many `args` members (one for each registered implementation of `InnerBase`). It is usually fine to ignore them in the code; they are needed for the config."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462145294,
        "executionStopTime": 1659462145307,
        "originalKey": "ce7069d5-a813-4286-a7cd-6ff40362105a",
        "requestMsgId": "ce7069d5-a813-4286-a7cd-6ff40362105a",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "print(vars(out))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462231114,
        "executionStopTime": 1659462231130,
        "originalKey": "c7f051ff-c264-4b89-80dc-36cf179aafaf",
        "requestMsgId": "c7f051ff-c264-4b89-80dc-36cf179aafaf",
        "showInput": false
      },
      "source": [
        "## 6. Example with torch.nn.Module 🔥\n",
        "Typically in implicitron, we use this system in combination with [`Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)s.\n",
        "Note that in this case it is necessary to call `Module.__init__` explicitly in `__post_init__` (below, via `super().__init__()`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462645018,
        "executionStopTime": 1659462645037,
        "originalKey": "42d210d6-09e0-4daf-8ccb-411d30f268f4",
        "requestMsgId": "42d210d6-09e0-4daf-8ccb-411d30f268f4",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class MyLinear(torch.nn.Module, Configurable):\n",
        "    d_in: int = 2\n",
        "    d_out: int = 200\n",
        "\n",
        "    def __post_init__(self):\n",
        "        super().__init__()\n",
        "        self.linear = torch.nn.Linear(in_features=self.d_in, out_features=self.d_out)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.linear(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462692309,
        "executionStopTime": 1659462692346,
        "originalKey": "546781fe-5b95-4e48-9cb5-34a634a31313",
        "requestMsgId": "546781fe-5b95-4e48-9cb5-34a634a31313",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "my_linear = MyLinear()\n",
        "input = torch.zeros(2)\n",
        "output = my_linear(input)\n",
        "print(\"output shape:\", output.shape)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462738302,
        "executionStopTime": 1659462738419,
        "originalKey": "b6cb71e1-1d54-4e89-a422-0a70772c5c03",
        "requestMsgId": "b6cb71e1-1d54-4e89-a422-0a70772c5c03",
        "showInput": false
      },
      "source": [
        "`my_linear` has all the usual features of a Module.\n",
        "E.g. it can be saved and loaded with `torch.save` and `torch.load`.\n",
        "It has parameters:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462821485,
        "executionStopTime": 1659462821501,
        "originalKey": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3",
        "requestMsgId": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "for name, value in my_linear.named_parameters():\n",
        "    print(name, value.shape)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463222379,
        "executionStopTime": 1659463222409,
        "originalKey": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b",
        "requestMsgId": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b",
        "showInput": false
      },
      "source": [
        "## 7. Example of implementing your own pluggable component\n",
        "Let's say I am using a library with `Out` like in section **5**, but I want to implement my own child of `InnerBase`.\n",
        "All I need to do is register its definition, but I need to do this before `expand_args_fields` is explicitly or implicitly called on `Out`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463694644,
        "executionStopTime": 1659463694653,
        "originalKey": "d9635511-a52b-43d5-8dae-d5c1a3dd9157",
        "requestMsgId": "d9635511-a52b-43d5-8dae-d5c1a3dd9157",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "@registry.register\n",
        "class UserImplementedInner(InnerBase):\n",
        "    a: int = 200\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from the user\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "f1511aa2-56b8-4ed0-a453-17e2bbfeefe7",
        "showInput": false
      },
      "source": [
        "At this point, we need to redefine the class `Out`.\n",
        "Otherwise, if it has already been expanded without `UserImplementedInner`, the following would not work,\n",
        "because the implementations known to a class are fixed when it is expanded.\n",
        "\n",
        "If you are running experiments from a script, the thing to remember is that you must import your own modules, which register your own implementations,\n",
        "before you *use* the library classes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463745967,
        "executionStopTime": 1659463745986,
        "originalKey": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406",
        "requestMsgId": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Out(Configurable):\n",
        "    inner: InnerBase\n",
        "    inner_class_type: str = \"Inner1\"\n",
        "    x: int = 19\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def talk(self):\n",
        "        self.inner.say_something()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463747398,
        "executionStopTime": 1659463747431,
        "originalKey": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979",
        "requestMsgId": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "out2 = Out(inner_class_type=\"UserImplementedInner\")\n",
        "print(out2.inner)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464033633,
        "executionStopTime": 1659464033643,
        "originalKey": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9",
        "requestMsgId": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9",
        "showInput": false
      },
      "source": [
        "## 8. Example of making a subcomponent pluggable\n",
        "\n",
        "Let's look at what needs to happen if we have a subcomponent which we make pluggable, to allow users to supply their own."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464709922,
        "executionStopTime": 1659464709933,
        "originalKey": "e37227b2-6897-4033-8560-9f2040abdeeb",
        "requestMsgId": "e37227b2-6897-4033-8560-9f2040abdeeb",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class SubComponent(Configurable):\n",
        "    x: float = 0.25\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        return a + self.x\n",
        "\n",
        "\n",
        "class LargeComponent(Configurable):\n",
        "    repeats: int = 4\n",
        "    subcomponent: SubComponent\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        for _ in range(self.repeats):\n",
        "            a = self.subcomponent.apply(a)\n",
        "        return a"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464710339,
        "executionStopTime": 1659464710459,
        "originalKey": "cab4c121-350e-443f-9a49-bd542a9735a2",
        "requestMsgId": "cab4c121-350e-443f-9a49-bd542a9735a2",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "large_component = LargeComponent()\n",
        "assert large_component.apply(3) == 4\n",
        "print(OmegaConf.to_yaml(LargeComponent))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "be60323a-badf-46e4-a259-72cae1391028",
        "showInput": false
      },
      "source": [
        "The same components, with the subcomponent made generic (pluggable):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464717226,
        "executionStopTime": 1659464717261,
        "originalKey": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8",
        "requestMsgId": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class SubComponentBase(ReplaceableBase):\n",
        "    def apply(self, a: float) -> float:\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class SubComponent(SubComponentBase):\n",
        "    x: float = 0.25\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        return a + self.x\n",
        "\n",
        "\n",
        "class LargeComponent(Configurable):\n",
        "    repeats: int = 4\n",
        "    subcomponent: SubComponentBase\n",
        "    subcomponent_class_type: str = \"SubComponent\"\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        for _ in range(self.repeats):\n",
        "            a = self.subcomponent.apply(a)\n",
        "        return a"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464725473,
        "executionStopTime": 1659464725587,
        "originalKey": "bbc3d321-6b49-4356-be75-1a173b1fc3a5",
        "requestMsgId": "bbc3d321-6b49-4356-be75-1a173b1fc3a5",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "large_component = LargeComponent()\n",
        "assert large_component.apply(3) == 4\n",
        "print(OmegaConf.to_yaml(LargeComponent))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464672680,
        "executionStopTime": 1659464673231,
        "originalKey": "5115453a-1d96-4022-97e7-46433e6dcf60",
        "requestMsgId": "5115453a-1d96-4022-97e7-46433e6dcf60",
        "showInput": false
      },
      "source": [
        "The following things had to change:\n",
        "* The base class `SubComponentBase` was defined.\n",
        "* `SubComponent` gained a `@registry.register` decoration and had its base class changed to the new one.\n",
        "* `subcomponent_class_type` was added as a member of the outer class.\n",
        "* In any saved configuration yaml files, the key `subcomponent_args` had to be changed to `subcomponent_SubComponent_args`."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462041307,
        "executionStopTime": 1659462041637,
        "originalKey": "0739269e-5c0e-4551-b06f-f4aab386ba54",
        "requestMsgId": "0739269e-5c0e-4551-b06f-f4aab386ba54",
        "showInput": false
      },
      "source": [
        "## Appendix: gotchas ⚠️\n",
        "\n",
        "* Forgetting to define `__post_init__`, or forgetting to call `run_auto_creation` in it.\n",
        "* Omitting a type annotation on a field. Without the annotation, the value is a plain class attribute rather than a dataclass field, so the config system does not see it. For example, writing\n",
        "```\n",
        "    subcomponent_class_type = \"SubComponent\"\n",
        "```\n",
        "instead of\n",
        "```\n",
        "    subcomponent_class_type: str = \"SubComponent\"\n",
        "```\n",
        "\n"
      ]
    }
  ],
  "metadata": {
    "bento_stylesheets": {
      "bento/extensions/flow/main.css": true,
      "bento/extensions/kernel_selector/main.css": true,
      "bento/extensions/kernel_ui/main.css": true,
      "bento/extensions/new_kernel/main.css": true,
      "bento/extensions/system_usage/main.css": true,
      "bento/extensions/theme/main.css": true
    },
    "captumWidgetMessage": {},
    "dataExplorerConfig": {},
    "kernelspec": {
      "display_name": "pytorch3d",
      "language": "python",
      "metadata": {
        "cinder_runtime": false,
        "fbpkg_supported": true,
        "is_prebuilt": true,
        "kernel_name": "bento_kernel_pytorch3d",
        "nightly_builds": true
      },
      "name": "bento_kernel_pytorch3d"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3"
    },
    "last_base_url": "https://9177.od.fbinfra.net:443/",
    "last_kernel_id": "90755407-3729-46f4-ab67-ff2cb1daa5cb",
    "last_msg_id": "f61034eb-826226915ad9548ffbe495ba_6317",
    "last_server_session_id": "d6b46f14-cee7-44c1-8c51-39a38a4ea4c2",
    "outputWidgetContext": {}
  },
  "nbformat": 4,
  "nbformat_minor": 2
}