# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The ALBERT configurations."""

import six

from official.legacy.bert import configs


class AlbertConfig(configs.BertConfig):
  """Configuration for `ALBERT`.

  Extends `configs.BertConfig` with two ALBERT-specific constructor arguments,
  `num_hidden_groups` and `inner_group_num`. Both are only validated by the
  constructor (they must be 1); the constructor does not store them.
  """

  def __init__(self, num_hidden_groups=1, inner_group_num=1, **kwargs):
    """Constructs AlbertConfig.

    Args:
      num_hidden_groups: Number of groups for the hidden layers; parameters in
        the same group are shared. Note that this value and also the following
        'inner_group_num' have to be 1 for now, because all released ALBERT
        models set them to 1. We may support arbitrary valid values in future.
      inner_group_num: Number of inner repetitions of attention and ffn.
      **kwargs: The remaining arguments, forwarded unchanged to
        `configs.BertConfig.__init__`.

    Raises:
      ValueError: If `inner_group_num` or `num_hidden_groups` is not 1.
    """
    super(AlbertConfig, self).__init__(**kwargs)

    # TODO(chendouble): 'inner_group_num' and 'num_hidden_groups' are always 1
    # in the released ALBERT. Support other values in AlbertEncoder if needed.
    if inner_group_num != 1 or num_hidden_groups != 1:
      raise ValueError("We only support 'inner_group_num' and "
                       "'num_hidden_groups' as 1.")

  @classmethod
  def from_dict(cls, json_object):
    """Constructs a config of type `cls` from a Python dictionary of parameters.

    Args:
      json_object: Dictionary mapping attribute names to values. Every entry is
        written directly into the new instance's `__dict__`, so the values do
        not pass through the checks in `__init__`.

    Returns:
      A new instance of `cls` (`AlbertConfig`, or the subclass this method was
      called on) carrying every key of `json_object` as an attribute.
    """
    # Instantiate `cls` instead of hard-coding `AlbertConfig`, so a subclass
    # calling `from_dict` gets an instance of its own type back.
    config = cls(vocab_size=None)
    for (key, value) in six.iteritems(json_object):
      config.__dict__[key] = value
    return config