{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "f0af2d90-cb21-4ab4-b4cb-0fd00dbfb77b",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "4e15bfa2-5404-40d0-98b6-eb2732c8b72b",
        "showInput": false
      },
      "source": [
        "# Implicitron's config system"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "287be985-423d-42e0-a2af-1e8c585e723c",
        "showInput": false
      },
      "source": [
        "Implicitron's components are all based on a unified hierarchical configuration system. \n",
        "This allows configurable variables and all defaults to be defined separately for each new component.\n",
        "All configs relevant to an experiment are then automatically composed into a single configuration file that fully specifies the experiment.\n",
        "An especially important feature is extension points where users can insert their own sub-classes of Implicitron's base components.\n",
        "\n",
        "The file which defines this system is [here](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/implicitron/tools/config.py) in the PyTorch3D repo.\n",
        "The Implicitron volumes tutorial contains a simple example of using the config system.\n",
        "This tutorial provides detailed hands-on experience in using and modifying Implicitron's configurable components.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "fde300a2-99cb-4d52-9d5b-4464a2083e0b",
        "showInput": false
      },
      "source": [
        "## 0. Install and import modules\n",
        "\n",
        "Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "ad6e94a7-e114-43d3-b038-a5210c7d34c9",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import torch\n",
        "import subprocess\n",
        "need_pytorch3d=False\n",
        "try:\n",
        "    import pytorch3d\n",
        "except ModuleNotFoundError:\n",
        "    need_pytorch3d=True\n",
        "if need_pytorch3d:\n",
        "    pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
        "    version_str=\"\".join([\n",
        "        f\"py3{sys.version_info.minor}_cu\",\n",
        "        torch.version.cuda.replace(\".\",\"\"),\n",
        "        f\"_pyt{pyt_version_str}\"\n",
        "    ])\n",
        "    !pip install fvcore iopath\n",
        "    if sys.platform.startswith(\"linux\"):\n",
        "        print(\"Trying to install wheel for PyTorch3D\")\n",
        "        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
        "        pip_list = !pip freeze\n",
        "        need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
        "    if need_pytorch3d:\n",
        "        print(f\"failed to find/install wheel for {version_str}\")\n",
        "if need_pytorch3d:\n",
        "    print(\"Installing PyTorch3D from source\")\n",
        "    !pip install ninja\n",
        "    !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "609896c0-9e2e-4716-b074-b565f0170e32",
        "showInput": false
      },
      "source": [
        "Ensure omegaconf is installed. If not, run this cell. (It should not be necessary to restart the runtime.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "customInput": null,
        "customOutput": null,
        "originalKey": "d1c1851e-b9f2-4236-93c3-19aa4d63041c",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "!pip install omegaconf"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "code_folding": [],
        "collapsed": false,
        "customOutput": null,
        "executionStartTime": 1659465468717,
        "executionStopTime": 1659465468738,
        "hidden_ranges": [],
        "originalKey": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f",
        "requestMsgId": "5ac7ef23-b74c-46b2-b8d3-799524d7ba4f"
      },
      "outputs": [],
      "source": [
        "from dataclasses import dataclass\n",
        "from typing import Optional, Tuple\n",
        "\n",
        "import torch\n",
        "from omegaconf import DictConfig, OmegaConf\n",
        "from pytorch3d.implicitron.tools.config import (\n",
        "    Configurable,\n",
        "    ReplaceableBase,\n",
        "    expand_args_fields,\n",
        "    get_default_args,\n",
        "    registry,\n",
        "    run_auto_creation,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "a638bf90-eb6b-424d-b53d-eae11954a717",
        "showInput": false
      },
      "source": [
        "## 1. Introducing dataclasses \n",
        "\n",
        "[Type hints](https://docs.python.org/3/library/typing.html) give a taxonomy of types in Python. [Dataclasses](https://docs.python.org/3/library/dataclasses.html) let you create a class based on a list of members which have names, types and possibly default values. The `__init__` function is created automatically and, as a final step, calls `__post_init__` if it is defined. For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454972732,
        "executionStopTime": 1659454972739,
        "originalKey": "71eaad5e-e198-492e-8610-24b0da9dd4ae",
        "requestMsgId": "71eaad5e-e198-492e-8610-24b0da9dd4ae",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "@dataclass\n",
        "class MyDataclass:\n",
        "    a: int\n",
        "    b: int = 8\n",
        "    c: Optional[Tuple[int, ...]] = None\n",
        "\n",
        "    def __post_init__(self):\n",
        "        print(f\"created with a = {self.a}\")\n",
        "        self.d = 2 * self.b"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454973051,
        "executionStopTime": 1659454973077,
        "originalKey": "83202a18-a3d3-44ec-a62d-b3360a302645",
        "requestMsgId": "83202a18-a3d3-44ec-a62d-b3360a302645",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "my_dataclass_instance = MyDataclass(a=18)\n",
        "assert my_dataclass_instance.d == 16"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "b67ccb9f-dc6c-4994-9b99-b5a1bcfebd70",
        "showInput": false
      },
      "source": [
        "👷 Note that the `dataclass` decorator here is a function which modifies the definition of the class itself.\n",
        "It runs immediately after the definition.\n",
        "Our config system requires that Implicitron library code contains classes whose modified versions need to be aware of user-defined implementations.\n",
        "Therefore we need the modification of the class to be delayed, so we don't use a decorator.\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "3e90f664-99df-4387-9c45-a1ad7939ef3a",
        "showInput": false
      },
      "source": [
        "## 2. Introducing omegaconf and OmegaConf.structured\n",
        "\n",
        "The [omegaconf](https://github.com/omry/omegaconf/) library provides a DictConfig class which is like a `dict` with str keys, but with extra features for ease-of-use as a configuration system."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659451341683,
        "executionStopTime": 1659451341690,
        "originalKey": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174",
        "requestMsgId": "81c73c9b-27ee-4aab-b55e-fb0dd67fe174",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "dc = DictConfig({\"a\": 2, \"b\": True, \"c\": None, \"d\": \"hello\"})\n",
        "assert dc.a == dc[\"a\"] == 2"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "3b5b76a9-4b76-4784-96ff-2a1212e48e48",
        "showInput": false
      },
      "source": [
        "OmegaConf has a serialization to and from yaml. The [Hydra](https://hydra.cc/) library relies on this for its configuration files."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659451411835,
        "executionStopTime": 1659451411936,
        "originalKey": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61",
        "requestMsgId": "d7a25ec1-caea-46bc-a1da-4b1f040c4b61",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "print(OmegaConf.to_yaml(dc))\n",
        "assert OmegaConf.create(OmegaConf.to_yaml(dc)) == dc"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "777fecdd-8bf6-4fd8-827b-cb8af5477fa8",
        "showInput": false
      },
      "source": [
        "OmegaConf.structured provides a DictConfig from a dataclass or instance of a dataclass. Unlike a normal DictConfig, it is type-checked and only known keys can be added."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455098879,
        "executionStopTime": 1659455098900,
        "originalKey": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162",
        "requestMsgId": "de36efb4-0b08-4fb8-bb3a-be1b2c0cd162",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured = OmegaConf.structured(MyDataclass)\n",
        "assert isinstance(structured, DictConfig)\n",
        "print(structured)\n",
        "print()\n",
        "print(OmegaConf.to_yaml(structured))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "be4446da-e536-4139-9ba3-37669a5b5e61",
        "showInput": false
      },
      "source": [
        "`structured` knows it is missing a value for `a`."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "864811e8-1a75-4932-a85e-f681b0541ae9",
        "showInput": false
      },
      "source": [
        "Such an object has members compatible with the dataclass, so an initialisation can be performed as follows."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455580491,
        "executionStopTime": 1659455580501,
        "originalKey": "eb88aaa0-c22f-4ffb-813a-ca957b490acb",
        "requestMsgId": "eb88aaa0-c22f-4ffb-813a-ca957b490acb",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured.a = 21\n",
        "my_dataclass_instance2 = MyDataclass(**structured)\n",
        "print(my_dataclass_instance2)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "2d08c81c-9d18-4de9-8464-0da2d89f94f3",
        "showInput": false
      },
      "source": [
        "You can also call OmegaConf.structured on an instance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659455594700,
        "executionStopTime": 1659455594737,
        "originalKey": "5e469bac-32a4-475d-9c09-8b64ba3f2155",
        "requestMsgId": "5e469bac-32a4-475d-9c09-8b64ba3f2155",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "structured_from_instance = OmegaConf.structured(my_dataclass_instance)\n",
        "my_dataclass_instance3 = MyDataclass(**structured_from_instance)\n",
        "print(my_dataclass_instance3)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659452594203,
        "executionStopTime": 1659452594333,
        "originalKey": "2ed559e3-8552-465a-938f-30c72a321184",
        "requestMsgId": "2ed559e3-8552-465a-938f-30c72a321184",
        "showInput": false
      },
      "source": [
        "## 3. Our approach to OmegaConf.structured\n",
        "\n",
        "We provide functions which are equivalent to `OmegaConf.structured` but support more features.\n",
        "The following cells reproduce the example above using our functions.\n",
        "Note that we indicate configurable classes using a special base class `Configurable`, not a decorator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454053323,
        "executionStopTime": 1659454061629,
        "originalKey": "9888afbd-e617-4596-ab7a-fc1073f58656",
        "requestMsgId": "9888afbd-e617-4596-ab7a-fc1073f58656",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class MyConfigurable(Configurable):\n",
        "    a: int\n",
        "    b: int = 8\n",
        "    c: Optional[Tuple[int, ...]] = None\n",
        "\n",
        "    def __post_init__(self):\n",
        "        print(f\"created with a = {self.a}\")\n",
        "        self.d = 2 * self.b"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659454784912,
        "executionStopTime": 1659454784928,
        "originalKey": "e43155b4-3da5-4df1-a2f5-da1d0369eec9",
        "requestMsgId": "e43155b4-3da5-4df1-a2f5-da1d0369eec9",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# The expand_args_fields function modifies the class like @dataclasses.dataclass.\n",
        "# If it has not been called on a Configurable class before the class is first instantiated,\n",
        "# it will be called automatically.\n",
        "expand_args_fields(MyConfigurable)\n",
        "my_configurable_instance = MyConfigurable(a=18)\n",
        "assert my_configurable_instance.d == 16"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460669541,
        "executionStopTime": 1659460669566,
        "originalKey": "96eaae18-dce4-4ee1-b451-1466fea51b9f",
        "requestMsgId": "96eaae18-dce4-4ee1-b451-1466fea51b9f",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "# get_default_args also calls expand_args_fields automatically\n",
        "our_structured = get_default_args(MyConfigurable)\n",
        "assert isinstance(our_structured, DictConfig)\n",
        "print(OmegaConf.to_yaml(our_structured))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460454020,
        "executionStopTime": 1659460454032,
        "originalKey": "359f7925-68de-42cd-bd34-79a099b1c210",
        "requestMsgId": "359f7925-68de-42cd-bd34-79a099b1c210",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "our_structured.a = 21\n",
        "print(MyConfigurable(**our_structured))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659460599142,
        "executionStopTime": 1659460599149,
        "originalKey": "eac7d385-9365-4098-acf9-4f0a0dbdcb85",
        "requestMsgId": "eac7d385-9365-4098-acf9-4f0a0dbdcb85",
        "showInput": false
      },
      "source": [
        "## 4. First enhancement: nested types 🪺\n",
        "\n",
        "Our system allows Configurable classes to contain each other. \n",
        "One thing to remember: add a call to `run_auto_creation` in `__post_init__`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465752418,
        "executionStopTime": 1659465752976,
        "originalKey": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a",
        "requestMsgId": "9bd70ee5-4ec1-4021-bce5-9638b5088c0a",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Inner(Configurable):\n",
        "    a: int = 8\n",
        "    b: bool = True\n",
        "    c: Tuple[int, ...] = (2, 3, 4, 6)\n",
        "\n",
        "\n",
        "class Outer(Configurable):\n",
        "    inner: Inner\n",
        "    x: str = \"hello\"\n",
        "    xx: bool = False\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465762326,
        "executionStopTime": 1659465762339,
        "originalKey": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7",
        "requestMsgId": "9f2b9f98-b54b-46cc-9b02-9e902cb279e7",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "outer_dc = get_default_args(Outer)\n",
        "print(OmegaConf.to_yaml(outer_dc))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659465772894,
        "executionStopTime": 1659465772911,
        "originalKey": "0254204b-8c7a-4d40-bba6-5132185f63d7",
        "requestMsgId": "0254204b-8c7a-4d40-bba6-5132185f63d7",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "outer = Outer(**outer_dc)\n",
        "assert isinstance(outer, Outer)\n",
        "assert isinstance(outer.inner, Inner)\n",
        "print(vars(outer))\n",
        "print(outer.inner)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "44a78c13-ec92-4a87-808a-c4674b320c22",
        "showInput": false
      },
      "source": [
        "Note how `inner_args` is an extra member of `outer`. Here `run_auto_creation(self)` is equivalent to\n",
        "```\n",
        "    self.inner = Inner(**self.inner_args)\n",
        "```"
      ]
    },
    {
621
      "attachments": {},
Jeremy Reizenstein's avatar
Jeremy Reizenstein committed
622
623
624
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
625
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659461071129,
        "executionStopTime": 1659461071137,
        "originalKey": "af0ec78b-7888-4b0d-9346-63d970d43293",
        "requestMsgId": "af0ec78b-7888-4b0d-9346-63d970d43293",
        "showInput": false
      },
      "source": [
        "## 5. Second enhancement: pluggable/replaceable components 🔌\n",
        "\n",
        "If a class uses `ReplaceableBase` as a base class instead of `Configurable`, we call it a replaceable.\n",
        "This indicates that it is designed so that child classes can be used in its place.\n",
        "We might use `NotImplementedError` to indicate functionality which subclasses are expected to implement.\n",
        "The system maintains a global `registry` containing the subclasses of each `ReplaceableBase`.\n",
        "Subclasses register themselves with it using a decorator.\n",
        "\n",
        "A configurable class (i.e. a class which uses our system, meaning a child of `Configurable` or `ReplaceableBase`) which contains a `ReplaceableBase` must also \n",
        "contain a corresponding `class_type` field of type `str` which indicates which concrete child class to use."
      ]
    },
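    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough mental model only, and not Implicitron's actual implementation, the registry idea can be sketched in plain Python.\n",
        "The names `_TOY_REGISTRY`, `toy_register`, `toy_get`, `ToyBase` and `ToyA` are hypothetical and exist only for this illustration:\n",
        "a decorator records each subclass under its base class, and a `class_type`-style string later selects the concrete class.\n",
        "```python\n",
        "from typing import Callable, Dict\n",
        "\n",
        "# Hypothetical toy registry: base class -> {subclass name: subclass}.\n",
        "_TOY_REGISTRY: Dict[type, Dict[str, type]] = {}\n",
        "\n",
        "\n",
        "def toy_register(base: type) -> Callable[[type], type]:\n",
        "    # Build a decorator that records the decorated class under `base`.\n",
        "    def decorator(cls: type) -> type:\n",
        "        _TOY_REGISTRY.setdefault(base, {})[cls.__name__] = cls\n",
        "        return cls\n",
        "\n",
        "    return decorator\n",
        "\n",
        "\n",
        "def toy_get(base: type, class_type: str) -> type:\n",
        "    # Resolve a class_type string to the registered concrete class.\n",
        "    return _TOY_REGISTRY[base][class_type]\n",
        "\n",
        "\n",
        "class ToyBase:\n",
        "    def say_something(self) -> str:\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@toy_register(ToyBase)\n",
        "class ToyA(ToyBase):\n",
        "    def say_something(self) -> str:\n",
        "        return 'hello from ToyA'\n",
        "\n",
        "\n",
        "chosen = toy_get(ToyBase, 'ToyA')()\n",
        "print(chosen.say_something())\n",
        "```\n",
        "Unlike this sketch, `@registry.register` in the cells below is applied without naming the base class."
      ]
    },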
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463453457,
        "executionStopTime": 1659463453467,
        "originalKey": "f2898703-d147-4394-978e-fc7f1f559395",
        "requestMsgId": "f2898703-d147-4394-978e-fc7f1f559395",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class InnerBase(ReplaceableBase):\n",
        "    def say_something(self):\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class Inner1(InnerBase):\n",
        "    a: int = 1\n",
        "    b: str = \"h\"\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from an Inner1\")\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class Inner2(InnerBase):\n",
        "    a: int = 2\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from an Inner2\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463453514,
        "executionStopTime": 1659463453592,
        "originalKey": "6f171599-51ee-440f-82d7-a59f84d24624",
        "requestMsgId": "6f171599-51ee-440f-82d7-a59f84d24624",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Out(Configurable):\n",
        "    inner: InnerBase\n",
        "    inner_class_type: str = \"Inner1\"\n",
        "    x: int = 19\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def talk(self):\n",
        "        self.inner.say_something()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463191360,
        "executionStopTime": 1659463191428,
        "originalKey": "7abaecec-96e6-44df-8c8d-69c36a14b913",
        "requestMsgId": "7abaecec-96e6-44df-8c8d-69c36a14b913",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "Out_dc = get_default_args(Out)\n",
        "print(OmegaConf.to_yaml(Out_dc))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463192717,
        "executionStopTime": 1659463192754,
        "originalKey": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c",
        "requestMsgId": "c82dc2ca-ba8f-4a44-aed3-43f6b52ec28c",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "Out_dc.inner_class_type = \"Inner2\"\n",
        "out = Out(**Out_dc)\n",
        "print(out.inner)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463193751,
        "executionStopTime": 1659463193791,
        "originalKey": "aa0e1b04-963a-4724-81b7-5748b598b541",
        "requestMsgId": "aa0e1b04-963a-4724-81b7-5748b598b541",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "out.talk()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "4f78a56c-39cd-4563-a97e-041e5f360f6b",
        "showInput": false
      },
      "source": [
        "Note that in this case there are many `args` members. It is usually fine to ignore them in the code; they are needed for the config."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462145294,
        "executionStopTime": 1659462145307,
        "originalKey": "ce7069d5-a813-4286-a7cd-6ff40362105a",
        "requestMsgId": "ce7069d5-a813-4286-a7cd-6ff40362105a",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "print(vars(out))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462231114,
        "executionStopTime": 1659462231130,
        "originalKey": "c7f051ff-c264-4b89-80dc-36cf179aafaf",
        "requestMsgId": "c7f051ff-c264-4b89-80dc-36cf179aafaf",
        "showInput": false
      },
      "source": [
        "## 6. Example with torch.nn.Module 🔥\n",
        "Typically in Implicitron, we use this system in combination with [`Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)s. \n",
        "Note that in this case it is necessary to call `Module.__init__` explicitly in `__post_init__`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462645018,
        "executionStopTime": 1659462645037,
        "originalKey": "42d210d6-09e0-4daf-8ccb-411d30f268f4",
        "requestMsgId": "42d210d6-09e0-4daf-8ccb-411d30f268f4",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class MyLinear(torch.nn.Module, Configurable):\n",
        "    d_in: int = 2\n",
        "    d_out: int = 200\n",
        "\n",
        "    def __post_init__(self):\n",
        "        super().__init__()\n",
        "        self.linear = torch.nn.Linear(in_features=self.d_in, out_features=self.d_out)\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.linear(x)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462692309,
        "executionStopTime": 1659462692346,
        "originalKey": "546781fe-5b95-4e48-9cb5-34a634a31313",
        "requestMsgId": "546781fe-5b95-4e48-9cb5-34a634a31313",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "my_linear = MyLinear()\n",
        "input = torch.zeros(2)\n",
        "output = my_linear(input)\n",
        "print(\"output shape:\", output.shape)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462738302,
        "executionStopTime": 1659462738419,
        "originalKey": "b6cb71e1-1d54-4e89-a422-0a70772c5c03",
        "requestMsgId": "b6cb71e1-1d54-4e89-a422-0a70772c5c03",
        "showInput": false
      },
      "source": [
        "`my_linear` has all the usual features of a Module.\n",
        "E.g. it can be saved and loaded with `torch.save` and `torch.load`.\n",
        "It has parameters:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462821485,
        "executionStopTime": 1659462821501,
        "originalKey": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3",
        "requestMsgId": "47e8c53e-2d2c-4b41-8aa3-65aa3ea8a7d3",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "for name, value in my_linear.named_parameters():\n",
        "    print(name, value.shape)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463222379,
        "executionStopTime": 1659463222409,
        "originalKey": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b",
        "requestMsgId": "a01f0ea7-55f2-4af9-8e81-45dddf40f13b",
        "showInput": false
      },
      "source": [
        "## 7. Example of implementing your own pluggable component\n",
        "Let's say I am using a library with `Out` like in section **5** but I want to implement my own child of `InnerBase`. \n",
        "All I need to do is register its definition, but I need to do this before `expand_args_fields` is explicitly or implicitly called on `Out`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463694644,
        "executionStopTime": 1659463694653,
        "originalKey": "d9635511-a52b-43d5-8dae-d5c1a3dd9157",
        "requestMsgId": "d9635511-a52b-43d5-8dae-d5c1a3dd9157",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "@registry.register\n",
        "class UserImplementedInner(InnerBase):\n",
        "    a: int = 200\n",
        "\n",
        "    def say_something(self):\n",
        "        print(\"hello from the user\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "f1511aa2-56b8-4ed0-a453-17e2bbfeefe7",
        "showInput": false
      },
      "source": [
        "At this point, we need to redefine the class `Out`. \n",
        "Otherwise, if it has already been expanded without `UserImplementedInner`, the following would not work,\n",
        "because the implementations known to a class are fixed when it is expanded.\n",
        "\n",
        "If you are running experiments from a script, the thing to remember here is that you must import your own modules, which register your own implementations,\n",
        "before you *use* the library classes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463745967,
        "executionStopTime": 1659463745986,
        "originalKey": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406",
        "requestMsgId": "c7bb5a6e-682b-4eb0-a214-e0f5990b9406",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class Out(Configurable):\n",
        "    inner: InnerBase\n",
        "    inner_class_type: str = \"Inner1\"\n",
        "    x: int = 19\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def talk(self):\n",
        "        self.inner.say_something()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659463747398,
        "executionStopTime": 1659463747431,
        "originalKey": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979",
        "requestMsgId": "b6ecdc86-4b7b-47c6-9f45-a7e557c94979",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "out2 = Out(inner_class_type=\"UserImplementedInner\")\n",
        "print(out2.inner)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464033633,
        "executionStopTime": 1659464033643,
        "originalKey": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9",
        "requestMsgId": "c7fe0df3-da13-40b8-9b06-6b1f37f37bb9",
        "showInput": false
      },
      "source": [
        "## 8. Example of making a subcomponent pluggable\n",
        "\n",
        "Let's look at what needs to happen if we have a subcomponent which we make pluggable, to allow users to supply their own."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464709922,
        "executionStopTime": 1659464709933,
        "originalKey": "e37227b2-6897-4033-8560-9f2040abdeeb",
        "requestMsgId": "e37227b2-6897-4033-8560-9f2040abdeeb",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class SubComponent(Configurable):\n",
        "    x: float = 0.25\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        return a + self.x\n",
        "\n",
        "\n",
        "class LargeComponent(Configurable):\n",
        "    repeats: int = 4\n",
        "    subcomponent: SubComponent\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        for _ in range(self.repeats):\n",
        "            a = self.subcomponent.apply(a)\n",
        "        return a"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464710339,
        "executionStopTime": 1659464710459,
        "originalKey": "cab4c121-350e-443f-9a49-bd542a9735a2",
        "requestMsgId": "cab4c121-350e-443f-9a49-bd542a9735a2",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "large_component = LargeComponent()\n",
        "assert large_component.apply(3) == 4\n",
        "print(OmegaConf.to_yaml(LargeComponent))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "customInput": null,
        "originalKey": "be60323a-badf-46e4-a259-72cae1391028",
        "showInput": false
      },
      "source": [
        "The same components, with the subcomponent made pluggable:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464717226,
        "executionStopTime": 1659464717261,
        "originalKey": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8",
        "requestMsgId": "fc0d8cdb-4627-4427-b92a-17ac1c1b37b8",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "class SubComponentBase(ReplaceableBase):\n",
        "    def apply(self, a: float) -> float:\n",
        "        raise NotImplementedError\n",
        "\n",
        "\n",
        "@registry.register\n",
        "class SubComponent(SubComponentBase):\n",
        "    x: float = 0.25\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        return a + self.x\n",
        "\n",
        "\n",
        "class LargeComponent(Configurable):\n",
        "    repeats: int = 4\n",
        "    subcomponent: SubComponentBase\n",
        "    subcomponent_class_type: str = \"SubComponent\"\n",
        "\n",
        "    def __post_init__(self):\n",
        "        run_auto_creation(self)\n",
        "\n",
        "    def apply(self, a: float) -> float:\n",
        "        for _ in range(self.repeats):\n",
        "            a = self.subcomponent.apply(a)\n",
        "        return a"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464725473,
        "executionStopTime": 1659464725587,
        "originalKey": "bbc3d321-6b49-4356-be75-1a173b1fc3a5",
        "requestMsgId": "bbc3d321-6b49-4356-be75-1a173b1fc3a5",
        "showInput": true
      },
      "outputs": [],
      "source": [
        "large_component = LargeComponent()\n",
        "assert large_component.apply(3) == 4\n",
        "print(OmegaConf.to_yaml(LargeComponent))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659464672680,
        "executionStopTime": 1659464673231,
        "originalKey": "5115453a-1d96-4022-97e7-46433e6dcf60",
        "requestMsgId": "5115453a-1d96-4022-97e7-46433e6dcf60",
        "showInput": false
      },
      "source": [
        "The following things had to change:\n",
        "* The base class `SubComponentBase` was defined.\n",
        "* `SubComponent` gained a `@registry.register` decorator and had its base class changed to the new one.\n",
        "* `subcomponent_class_type` was added as a member of the outer class.\n",
        "* In any saved configuration yaml files, the key `subcomponent_args` had to be changed to `subcomponent_SubComponent_args`."
      ]
    },
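    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The last change above, renaming a key in saved configs, could be scripted along the following lines.\n",
        "This is a minimal sketch on a plain `dict`: `rename_args_key` is a hypothetical helper, not part of the library,\n",
        "and the shape of `old_config` is assumed from the defaults of the first `LargeComponent` in this section.\n",
        "```python\n",
        "from typing import Any, Dict\n",
        "\n",
        "OLD_KEY = 'subcomponent_args'\n",
        "NEW_KEY = 'subcomponent_SubComponent_args'\n",
        "\n",
        "\n",
        "def rename_args_key(config: Dict[str, Any]) -> Dict[str, Any]:\n",
        "    # Hypothetical helper: copy the config, renaming the old key if present.\n",
        "    updated = dict(config)\n",
        "    if OLD_KEY in updated:\n",
        "        updated[NEW_KEY] = updated.pop(OLD_KEY)\n",
        "    return updated\n",
        "\n",
        "\n",
        "# Assumed shape of a config saved before the subcomponent became pluggable.\n",
        "old_config = {'repeats': 4, 'subcomponent_args': {'x': 0.25}}\n",
        "new_config = rename_args_key(old_config)\n",
        "print(new_config)\n",
        "```\n",
        "A config loaded from yaml into a `dict` could be passed through such a helper before it is used with the new class."
      ]
    },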
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "customInput": null,
        "customOutput": null,
        "executionStartTime": 1659462041307,
        "executionStopTime": 1659462041637,
        "originalKey": "0739269e-5c0e-4551-b06f-f4aab386ba54",
        "requestMsgId": "0739269e-5c0e-4551-b06f-f4aab386ba54",
        "showInput": false
      },
      "source": [
        "## Appendix: gotchas ⚠️\n",
        "\n",
        "* Omitting to define `__post_init__` or not calling `run_auto_creation` in it.\n",
        "* Omitting a type annotation on a field. For example, writing \n",
        "```\n",
        "    subcomponent_class_type = \"SubComponent\"\n",
        "```\n",
        "instead of \n",
        "```\n",
        "    subcomponent_class_type: str = \"SubComponent\"\n",
        "```\n",
        "\n"
      ]
    }
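    ,
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The second gotcha above can be reproduced with a plain dataclass.\n",
        "This is a minimal sketch, assuming configurable classes are processed as dataclasses; `Annotated` and `Unannotated` are hypothetical names used only here.\n",
        "```python\n",
        "from dataclasses import dataclass, fields\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class Annotated:\n",
        "    subcomponent_class_type: str = 'SubComponent'\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class Unannotated:\n",
        "    # No annotation, so this is a plain class attribute, not a field.\n",
        "    subcomponent_class_type = 'SubComponent'\n",
        "\n",
        "\n",
        "print([f.name for f in fields(Annotated)])\n",
        "print([f.name for f in fields(Unannotated)])\n",
        "```\n",
        "The unannotated attribute is not listed among the fields, so anything built from the fields, such as a generated config, would not see it."
      ]
    }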
  ],
  "metadata": {
    "bento_stylesheets": {
      "bento/extensions/flow/main.css": true,
      "bento/extensions/kernel_selector/main.css": true,
      "bento/extensions/kernel_ui/main.css": true,
      "bento/extensions/new_kernel/main.css": true,
      "bento/extensions/system_usage/main.css": true,
      "bento/extensions/theme/main.css": true
    },
    "captumWidgetMessage": {},
    "dataExplorerConfig": {},
    "kernelspec": {
      "display_name": "pytorch3d",
      "language": "python",
      "metadata": {
        "cinder_runtime": false,
        "fbpkg_supported": true,
        "is_prebuilt": true,
        "kernel_name": "bento_kernel_pytorch3d",
        "nightly_builds": true
      },
      "name": "bento_kernel_pytorch3d"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3"
    },
    "last_base_url": "https://9177.od.fbinfra.net:443/",
    "last_kernel_id": "90755407-3729-46f4-ab67-ff2cb1daa5cb",
    "last_msg_id": "f61034eb-826226915ad9548ffbe495ba_6317",
    "last_server_session_id": "d6b46f14-cee7-44c1-8c51-39a38a4ea4c2",
    "outputWidgetContext": {}
  },
  "nbformat": 4,
  "nbformat_minor": 2
}