"""
A PyTorch implementation of Google's BERT Model.

From "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding"
By Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova
Link: http://arxiv.org/abs/1810.04805

Adapted from HuggingFace's OpenAI PyTorch code and its adaptation by AllenNLP.
"""

# pylint: disable=invalid-name,arguments-differ
from typing import NamedTuple, List
import copy
import io
import json
import logging
import math
import pathlib
import re
import tarfile

import numpy as np
import torch
from torch.nn import Parameter

logger = logging.getLogger(__name__)


class ConfigurationError(Exception):
    """Raised when a serialized model archive is missing an expected member."""

# pylint: disable=line-too-long
_PARAMETER_NAMES = ["model/we:0",
                    "model/h0/attn/c_attn/w:0", "model/h0/attn/c_attn/b:0", "model/h0/attn/c_proj/w:0",
                    "model/h0/attn/c_proj/b:0", "model/h0/ln_1/g:0", "model/h0/ln_1/b:0", "model/h0/mlp/c_fc/w:0",
                    "model/h0/mlp/c_fc/b:0", "model/h0/mlp/c_proj/w:0", "model/h0/mlp/c_proj/b:0", "model/h0/ln_2/g:0",
                    "model/h0/ln_2/b:0", "model/h1/attn/c_attn/w:0", "model/h1/attn/c_attn/b:0", "model/h1/attn/c_proj/w:0",
                    "model/h1/attn/c_proj/b:0", "model/h1/ln_1/g:0", "model/h1/ln_1/b:0", "model/h1/mlp/c_fc/w:0",
                    "model/h1/mlp/c_fc/b:0", "model/h1/mlp/c_proj/w:0", "model/h1/mlp/c_proj/b:0", "model/h1/ln_2/g:0",
                    "model/h1/ln_2/b:0", "model/h2/attn/c_attn/w:0", "model/h2/attn/c_attn/b:0", "model/h2/attn/c_proj/w:0",
                    "model/h2/attn/c_proj/b:0", "model/h2/ln_1/g:0", "model/h2/ln_1/b:0", "model/h2/mlp/c_fc/w:0",
                    "model/h2/mlp/c_fc/b:0", "model/h2/mlp/c_proj/w:0", "model/h2/mlp/c_proj/b:0", "model/h2/ln_2/g:0",
                    "model/h2/ln_2/b:0", "model/h3/attn/c_attn/w:0", "model/h3/attn/c_attn/b:0", "model/h3/attn/c_proj/w:0",
                    "model/h3/attn/c_proj/b:0", "model/h3/ln_1/g:0", "model/h3/ln_1/b:0", "model/h3/mlp/c_fc/w:0",
                    "model/h3/mlp/c_fc/b:0", "model/h3/mlp/c_proj/w:0", "model/h3/mlp/c_proj/b:0", "model/h3/ln_2/g:0",
                    "model/h3/ln_2/b:0", "model/h4/attn/c_attn/w:0", "model/h4/attn/c_attn/b:0", "model/h4/attn/c_proj/w:0",
                    "model/h4/attn/c_proj/b:0", "model/h4/ln_1/g:0", "model/h4/ln_1/b:0", "model/h4/mlp/c_fc/w:0",
                    "model/h4/mlp/c_fc/b:0", "model/h4/mlp/c_proj/w:0", "model/h4/mlp/c_proj/b:0", "model/h4/ln_2/g:0",
                    "model/h4/ln_2/b:0", "model/h5/attn/c_attn/w:0", "model/h5/attn/c_attn/b:0", "model/h5/attn/c_proj/w:0",
                    "model/h5/attn/c_proj/b:0", "model/h5/ln_1/g:0", "model/h5/ln_1/b:0", "model/h5/mlp/c_fc/w:0",
                    "model/h5/mlp/c_fc/b:0", "model/h5/mlp/c_proj/w:0", "model/h5/mlp/c_proj/b:0", "model/h5/ln_2/g:0",
                    "model/h5/ln_2/b:0", "model/h6/attn/c_attn/w:0", "model/h6/attn/c_attn/b:0", "model/h6/attn/c_proj/w:0",
                    "model/h6/attn/c_proj/b:0", "model/h6/ln_1/g:0", "model/h6/ln_1/b:0", "model/h6/mlp/c_fc/w:0",
                    "model/h6/mlp/c_fc/b:0", "model/h6/mlp/c_proj/w:0", "model/h6/mlp/c_proj/b:0", "model/h6/ln_2/g:0",
                    "model/h6/ln_2/b:0", "model/h7/attn/c_attn/w:0", "model/h7/attn/c_attn/b:0", "model/h7/attn/c_proj/w:0",
                    "model/h7/attn/c_proj/b:0", "model/h7/ln_1/g:0", "model/h7/ln_1/b:0", "model/h7/mlp/c_fc/w:0",
                    "model/h7/mlp/c_fc/b:0", "model/h7/mlp/c_proj/w:0", "model/h7/mlp/c_proj/b:0", "model/h7/ln_2/g:0",
                    "model/h7/ln_2/b:0", "model/h8/attn/c_attn/w:0", "model/h8/attn/c_attn/b:0", "model/h8/attn/c_proj/w:0",
                    "model/h8/attn/c_proj/b:0", "model/h8/ln_1/g:0", "model/h8/ln_1/b:0", "model/h8/mlp/c_fc/w:0",
                    "model/h8/mlp/c_fc/b:0", "model/h8/mlp/c_proj/w:0", "model/h8/mlp/c_proj/b:0", "model/h8/ln_2/g:0",
                    "model/h8/ln_2/b:0", "model/h9/attn/c_attn/w:0", "model/h9/attn/c_attn/b:0", "model/h9/attn/c_proj/w:0",
                    "model/h9/attn/c_proj/b:0", "model/h9/ln_1/g:0", "model/h9/ln_1/b:0", "model/h9/mlp/c_fc/w:0",
                    "model/h9/mlp/c_fc/b:0", "model/h9/mlp/c_proj/w:0", "model/h9/mlp/c_proj/b:0", "model/h9/ln_2/g:0",
                    "model/h9/ln_2/b:0", "model/h10/attn/c_attn/w:0", "model/h10/attn/c_attn/b:0", "model/h10/attn/c_proj/w:0",
                    "model/h10/attn/c_proj/b:0", "model/h10/ln_1/g:0", "model/h10/ln_1/b:0", "model/h10/mlp/c_fc/w:0",
                    "model/h10/mlp/c_fc/b:0", "model/h10/mlp/c_proj/w:0", "model/h10/mlp/c_proj/b:0", "model/h10/ln_2/g:0",
                    "model/h10/ln_2/b:0", "model/h11/attn/c_attn/w:0", "model/h11/attn/c_attn/b:0", "model/h11/attn/c_proj/w:0",
                    "model/h11/attn/c_proj/b:0", "model/h11/ln_1/g:0", "model/h11/ln_1/b:0", "model/h11/mlp/c_fc/w:0",
                    "model/h11/mlp/c_fc/b:0", "model/h11/mlp/c_proj/w:0", "model/h11/mlp/c_proj/b:0", "model/h11/ln_2/g:0",
                    "model/h11/ln_2/b:0", "model/clf/w:0", "model/clf/b:0"]
# pylint: enable=line-too-long

def gelu(x: torch.Tensor) -> torch.Tensor:
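    """
    The Gaussian Error Linear Unit (GELU) activation, in its tanh approximation:
    0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
    """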
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

class BERTConfig(NamedTuple):
    """
    BERT's hyper-parameters
    """
    embedding_dim: int = 768
    num_heads: int = 12
    dropout: float = 0.1

class LayerNorm(torch.nn.Module):
    """
    A layernorm module in the Tensorflow style (with the epsilon inside the square root).
    """

    def __init__(self, n_state, e=1e-5):
        super().__init__()
        self.g = torch.nn.Parameter(torch.ones(n_state))
        self.b = torch.nn.Parameter(torch.zeros(n_state))
        self.e = e

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.e)
        return self.g * x + self.b
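
# Note: this LayerNorm is functionally equivalent to torch.nn.LayerNorm(n_state, eps=e):
# both normalize over the last dimension and place the epsilon inside the square root.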


class Conv1D(torch.nn.Module):
    """
    A batched linear layer using torch.addmm
    """
    def __init__(self, nf: int, rf: int, nx: int) -> None:
        super().__init__()
        self.rf = rf
        self.nf = nf
        w = torch.empty(nx, nf)
        torch.nn.init.normal_(w, std=0.02)
        self.w = Parameter(w)
        self.b = Parameter(torch.zeros(nf))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
        x = x.view(*size_out)
        return x
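
# Note: Conv1D(nf, 1, nx) acts as a dense layer, y = x @ w + b, over the last dimension,
# i.e. it is equivalent to torch.nn.Linear(nx, nf) with the weight stored transposed. The name
# mirrors the conv1d helper of the TF reference implementation (a convolution with receptive field 1).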

class Attention(torch.nn.Module):
    """
    A self-attention layer comprising a sequence of:
        - a linear layer: instance of the `Conv1D` class,
        - splitting the input into query, key, and value tensors (x.split),
        - reshaping the query, key, and value tensors according to the number of heads (self.split_heads),
        - applying self-attention (self._attn),
        - merging the per-head results back together (self.merge_heads),
        - a linear layer: instance of the `Conv1D` class,
        - a dropout layer: instance of `torch.nn.Dropout` class.

    See above for the details of Conv1D.
    """
    def __init__(self,
                 nx: int,
                 n_ctx: int,
                 config: BERTConfig,
                 scale: bool = False) -> None:
        super().__init__()
        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.num_heads == 0
        self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.num_heads
        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, 1, nx)
        self.c_proj = Conv1D(n_state, 1, nx)
        self.attn_dropout = torch.nn.Dropout(config.dropout)
        self.resid_dropout = torch.nn.Dropout(config.dropout)

    def _attn(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
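        # self.b is the fixed (1, 1, n_ctx, n_ctx) lower-triangular mask, so this broadcast
        # assumes the inputs are padded to the full context length n_ctx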
        w = torch.nn.functional.softmax(w, dim=-1)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)

    def merge_heads(self, x: torch.Tensor):
        # pylint: disable=no-self-use
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x: torch.Tensor, k: bool = False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        a = self._attn(query, key, value)
        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)
        return a
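
# Shape summary for Attention.forward: with batch size B, sequence length T (padded to n_ctx)
# and H = n_head, c_attn maps (B, T, nx) -> (B, T, 3 * nx); after splitting, query and value
# are (B, H, T, nx // H) and key is (B, H, nx // H, T), so _attn produces (B, H, T, T) attention
# weights and (B, H, T, nx // H) outputs, which merge_heads folds back into (B, T, nx).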


class MLP(torch.nn.Module):
    """
    A multi-layer perceptron layer comprising a sequence of:
        - a linear layer: instance of the `Conv1D` class,
        - an activation function: the `gelu` function,
        - another linear layer: instance of the `Conv1D` class,
        - a dropout layer: instance of `torch.nn.Dropout` class.

    See above for the details of Conv1D and the gelu function.
    """
    def __init__(self, n_state: int, config: BERTConfig) -> None:  # in MLP: n_state=3072 (4 * n_embd)
        super().__init__()
        self.c_fc = Conv1D(n_state, 1, config.embedding_dim)
        self.c_proj = Conv1D(config.embedding_dim, 1, n_state)
        self.act = gelu
        self.dropout = torch.nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(torch.nn.Module):
    """
    A Transformer Block comprising a sequence of:
        - a self-attention layer: instance of the `Attention` class,
        - a Layer Normalization layer: instance of the `LayerNorm` class,
        - a Multi-layer perceptron layer: instance of the `MLP` class,
        - another Layer Normalization layer: instance of the `LayerNorm` class.

    See above for the details of these classes.
    """
    def __init__(self,
                 n_ctx: int,
                 config: BERTConfig,
                 scale: bool = False) -> None:
        super().__init__()
        nx = config.embedding_dim
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_1 = LayerNorm(nx)
        self.mlp = MLP(4 * nx, config)
        self.ln_2 = LayerNorm(nx)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = self.attn(x)
        n = self.ln_1(x + a)
        m = self.mlp(n)
        h = self.ln_2(n + m)
        return h

class BERT(torch.nn.Module):
    """
    Google's BERT Model.
    Default parameters are the ones for Google's pretrained model.

    Parameters
    ----------
    vocab_size: ``int`` (optional, default: 40478)
        The size of the vocabulary (number of byte pair embeddings)
        excluding the n_special embeddings (if any), and the positional embeddings.
    n_ctx: ``int`` (optional, default: 512)
        The number of positional encodings to use for evaluation.
    embedding_dim: ``int`` (optional, default: 768)
        The dimension of the output embeddings.
    num_heads: ``int`` (optional, default: 12)
        How many "heads" the attention has.
    num_layers: ``int`` (optional, default: 12)
        How many layers of "blocks" the transformer has.
    dropout_probability: ``float`` (optional, default: 0.1)
        Dropout for all layers.
    model_path: ``str`` (optional, default: ``None``)
        A tar.gz file containing serialized model weights. If supplied,
        the weights will be loaded from that file.
    requires_grad: ``bool`` (optional, default: ``False``)
        If true, the transformer will be fine-tunable.
    n_special: ``int`` (optional, default: ``-1``)
        The number of special tokens added to the byte pair vocabulary.
    """
    def __init__(self,
                 vocab_size: int = 40478,
                 n_ctx: int = 512,
                 embedding_dim: int = 768,
                 num_heads: int = 12,
                 num_layers: int = 12,
                 dropout_probability: float = 0.1,
                 model_path: str = None,
                 requires_grad: bool = False,
                 n_special: int = -1) -> None:
        super().__init__()

        config = BERTConfig(
                embedding_dim,
                num_heads,
                dropout_probability,
        )

        # the embedding size is vocab_size + n_special embeddings + n_ctx
        embedding_size = vocab_size + max(n_special, 0) + n_ctx
        self.vocab_size = embedding_size
        self.n_ctx = n_ctx
        self.n_special = n_special

        self.num_output_layers = 1 + num_layers

        self.embed = torch.nn.Embedding(embedding_size, embedding_dim)
        self.drop = torch.nn.Dropout(dropout_probability)

        block = Block(n_ctx, config, scale=True)
        self.h = torch.nn.ModuleList([copy.deepcopy(block) for _ in range(num_layers)])
        self.decoder = torch.nn.Linear(embedding_dim, embedding_size, bias=False)
        self.decoder.weight = self.embed.weight  # Tied weights
        # To reproduce the noise_shape parameter of TF implementation

        torch.nn.init.normal_(self.embed.weight, std=0.02)

        for parameter in self.parameters():
            parameter.requires_grad = requires_grad

        if model_path:
            self.load_weights(model_path, n_special=n_special, n_ctx=n_ctx)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        #x = x.view(-1, x.size(2), x.size(3))

        # x is a (batch_size, sequence_length, 2) tensor pairing byte-pair ids with positional ids

        # e is (batch_size, sequence_length, 2, embedding_dim) tensor of embeddings
        e = self.embed(x)

        # h is (batch_size, sequence_length, embedding_dim)
        h = e.sum(dim=2)

        all_layers = [h]
        for block in self.h:
            h = block(h)
            all_layers.append(h)

        # result is list of (batch_size, sequence_length, embedding_dim)
        return all_layers

    def load_weights(self,
                     bert_model_path: str,
                     n_ctx: int = -1,
                     n_special: int = -1,
                     n_transfer: int = 12,
                     n_embd: int = 768,
                     names: List[str] = _PARAMETER_NAMES) -> None:
        # pylint: disable=dangerous-default-value

        logger.info(f"loading weights from {bert_model_path}")
        # note: `bert_model_path` is assumed to be a local archive; URL caching is not handled here

        with tarfile.open(bert_model_path) as tmp:
            num_params_files = len([member for member in tmp.getmembers() if member.name.endswith('.npy')])
            shapesfile = tmp.extractfile('model/params_shapes.json')
            if shapesfile:
                shapes = json.loads(shapesfile.read())
            else:
                raise ConfigurationError("unable to find model/params_shapes.json in the archive")

            # numpy can't read from a tarfile directly, so we need a workaround
            # https://github.com/numpy/numpy/issues/7989#issuecomment-341656702
            init_params: List[np.ndarray] = []
            for n in range(num_params_files):
                array_file = io.BytesIO()
                array_file.write(tmp.extractfile(f'model/params_{n}.npy').read())
                array_file.seek(0)
                # each np.load is a (11653478,) numpy array
                init_params.append(np.load(array_file))

        # init_params is a list of 10 arrays of size (11653478,)
        # shapes are [[512, 768], [40478, 768], [1, 768, 2304], [2304], ...  # 146 elts
        # products are [512 * 768, 40478 * 768, ...]
        # offsets is [512 * 768, 512 * 768 + 40478 * 768, ...]
        offsets = np.cumsum([np.prod(shape) for shape in shapes])

        # split into the 146 subarrays corresponding to shapes
        init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]

        # reshape
        init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
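        # For example, with shapes [[2, 3], [4]] the flat parameter vector has 6 + 4 = 10
        # entries, offsets is [6, 10], np.split(...)[:-1] yields the two flat pieces, and the
        # reshape recovers arrays of shape (2, 3) and (4,).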

        # truncate if necessary
        if n_ctx > 0:
            # init_params[0] holds the positional embeddings:
            # (512, 768) = (n_ctx, embedding_dim), truncated here to the requested n_ctx
            init_params[0] = init_params[0][:n_ctx]

        # combine init_params[1] and init_params[0]
        if n_special > 0:
            # init_params[1] is (40478, 768)
            # special is (n_special, 768)
            # init_params[0] is (512, 768)
            # result is (40990 + n_special, 768)
            init_params[0] = np.concatenate(
                    [init_params[1],
                     (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
                     init_params[0]],
                    0
            )
        else:
            # result is (40990, 768)
            init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
        del init_params[1]

        # number of parameter arrays to transfer: 12 per layer, plus one for the embedding matrix
        if n_transfer == -1:
            n_transfer = 0
        else:
            n_transfer = 1 + n_transfer * 12

        # squeeze out leading singleton dimensions (e.g. a [1, 768, 2304] kernel becomes [768, 2304])
        init_params = [arr.squeeze() for arr in init_params]

        # embedding.weight is (vocab_size, embedding_dim)
        # make sure init_params[0] has the same shape
        try:
            assert self.embed.weight.shape == init_params[0].shape
        except AssertionError as e:
            e.args += (self.embed.weight.shape, init_params[0].shape)
            raise

        # and then assign it
        self.embed.weight.data = torch.from_numpy(init_params[0])
        self.decoder.weight = self.embed.weight

        # for each (name, array) pair to transfer over
        for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
                                            # "model/h0/attn/c_attn/w:0"
            name = name[6:]                 # "h0/attn/c_attn/w:0"
            assert name[-2:] == ":0"
            name = name[:-2]                # "h0/attn/c_attn/w"
            name_parts = name.split('/')    # ['h0', 'attn', 'c_attn', 'w']

            pointer = self
            for m_name in name_parts:
                if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                    l = re.split(r'(\d+)', m_name)   # ['h', '0', '']
                else:
                    l = [m_name]                     # ['attn']
                pointer = getattr(pointer, l[0])
                if len(l) >= 2:
                    num = int(l[1])
                    pointer = pointer[num]
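            # e.g. ['h0', 'attn', 'c_attn', 'w'] walks self.h[0].attn.c_attn.w, so `pointer`
            # now refers to the Parameter that should receive this array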
            try:
                assert pointer.shape == ip.shape
            except AssertionError as e:
                e.args += (pointer.shape, ip.shape)
                raise

            pointer.data = torch.from_numpy(ip)  # pylint: disable=attribute-defined-outside-init

    def dump_weights(self,
                     output_dir: str,
                     num_pieces: int = 10) -> None:
        output_path = pathlib.Path(output_dir) / 'model'
        output_path.mkdir(exist_ok=True, parents=True)  # pylint: disable=no-member

        named_parameters = list(self.named_parameters())

        # embedding weights get special treatment
        _, array = named_parameters[0]
        num_bpe = self.vocab_size - self.n_ctx
        byte_pair_embeddings = array[:num_bpe]
        positional_embeddings = array[num_bpe:]

        arrays = [positional_embeddings.detach().numpy().ravel(), byte_pair_embeddings.detach().numpy().ravel()]
        shapes = [positional_embeddings.shape, byte_pair_embeddings.shape]
        names = ["model/we:0"]

        for param_name, tensor in named_parameters[1:]:
            param_name = param_name.replace('h.', 'h', 1)  # 'h.0.attn.c_attn.w' -> 'h0.attn.c_attn.w'
            parts = param_name.split(".")                  # ['h0', 'attn', 'c_attn', 'w']
            name = "model/" + '/'.join(parts) + ':0'       # 'model/h0/attn/c_attn/w:0'
            array = tensor.detach().numpy().ravel()

            arrays.append(array)
            shapes.append(list(tensor.shape))
            names.append(name)

        # write out the arrays
        big_array = np.concatenate(arrays)
        total_size = len(big_array)
        batch_size = math.ceil(total_size / num_pieces)

        for i in range(num_pieces):
            filename = output_path / f"params_{i}.npy"
            start = i * batch_size
            end = start + batch_size
            subarray = big_array[start:end]

            np.save(filename, subarray)

        # write out the shapes
        with open(output_path / 'params_shapes.json', 'w') as shapes_file:
            json.dump(shapes, shapes_file)
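

if __name__ == "__main__":
    # Minimal smoke-test sketch: build a small model (illustrative hyper-parameters, not the
    # pretrained configuration) and run a forward pass. Inputs pair byte-pair ids with
    # positional ids and are padded to the full context length, as the attention mask expects.
    demo_vocab_size, demo_n_ctx = 100, 16
    model = BERT(vocab_size=demo_vocab_size,
                 n_ctx=demo_n_ctx,
                 embedding_dim=48,
                 num_heads=6,
                 num_layers=2)
    model.eval()  # disable dropout so the check is deterministic

    batch_size = 2
    token_ids = torch.randint(0, demo_vocab_size, (batch_size, demo_n_ctx))
    # positional ids occupy the range [vocab_size, vocab_size + n_ctx) of the embedding table
    position_ids = torch.arange(demo_vocab_size, demo_vocab_size + demo_n_ctx).expand(batch_size, demo_n_ctx)
    x = torch.stack([token_ids, position_ids], dim=-1)

    layers = model(x)
    print(len(layers), layers[-1].shape)  # 3 layers, the last of shape (2, 16, 48)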