"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "581d8aacf72c2ea759ec1642f28b1df8febd131e"
Unverified Commit 3e3e5521 authored by Patrick von Platen, committed by GitHub

[Reformer] fix reformer num buckets (#4564)

* fix reformer num buckets

* fix

* adapt docs

* set num buckets in config
parent 3dea40b8
@@ -62,7 +62,7 @@ For more information, see the `original Paper <https://arxiv.org/abs/2001.04451>
Note that ``config.num_buckets`` can also be factorized into a ``list`` :math:`(n_{\text{buckets}}^1, n_{\text{buckets}}^2)`. This way instead of assigning the query key embedding vectors to one of :math:`(1,\ldots, n_{\text{buckets}})` they are assigned to one of :math:`(1-1,\ldots, n_{\text{buckets}}^1-1, \ldots, 1-n_{\text{buckets}}^2, \ldots, n_{\text{buckets}}^1-n_{\text{buckets}}^2)`. This is crucial for very long sequences to save memory.
-It is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length, a good value for ``num_buckets`` are calculated on the fly.
+When training a model from scratch, it is recommended to leave ``config.num_buckets=None``, so that depending on the sequence length a good value for ``num_buckets`` is calculated on the fly. This value will then automatically be saved in the config and should be reused for inference.
Using LSH self attention, the memory and time complexity of the query-key matmul operation can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times \log(n_s))`, which usually represents the memory and time bottleneck in a transformer model, with :math:`n_s` being the sequence length.
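As a quick illustration of the recommendation above, here is a minimal usage sketch assuming the ``transformers`` ``ReformerConfig`` API; the example values are ours and are not part of this commit:

from transformers import ReformerConfig

# Recommended: leave `num_buckets` unset so a good value is computed from the
# sequence length on the fly (and, with this change, written back into the config).
config = ReformerConfig(num_buckets=None)

# For very long sequences, `num_buckets` can instead be factorized into a list,
# e.g. 32 * 64 = 2048 effective buckets, to save memory.
config_factorized = ReformerConfig(num_buckets=[32, 64])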
......
@@ -110,10 +110,10 @@ class ReformerConfig(PretrainedConfig):
            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
        num_attention_heads (:obj:`int`, optional, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
-       num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `64`):
+       num_buckets (:obj:`int` or :obj:`list(int)`, optional, defaults to `None`):
            Number of buckets, the key query vectors can be "hashed into" using the locality sensitive hashing scheme. Each query key vector is hashed into a hash in `1, ..., num_buckets`.
            The number of buckets can also be factorized into a list for improved memory complexity. In this case, each query key vector is hashed into a hash in `1-1, 1-2, ..., num_buckets[0]-1, ..., num_buckets[0]-num_buckets[1]` if `num_buckets` is factorized into two factors.
-           The number of buckets (or the product the factors) should approximately equal sequence length / lsh_chunk_length.
+           The number of buckets (or the product the factors) should approximately equal sequence length / lsh_chunk_length. If `num_buckets` is set to `None`, a good value for `num_buckets` is calculated on the fly.
        num_hashes (:obj:`int`, optional, defaults to 1):
            Number of hashing rounds (e.g. number of random rotations) in Local Sensitive Hashing scheme.
            The higher `num_hashes`, the more accurate the `LSHSelfAttention` becomes, but also the more memory and time intensive the hashing becomes.
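For concreteness, a quick numeric check of the "product of the factors" rule above; the numbers are illustrative and not taken from this commit:

sequence_length = 131072
lsh_attn_chunk_length = 64
num_buckets = [32, 64]                               # factorized bucket count
effective_buckets = num_buckets[0] * num_buckets[1]  # 2048 pairwise hash values
assert effective_buckets == sequence_length // lsh_attn_chunk_length  # 2048 == 2048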
@@ -172,7 +172,7 @@ class ReformerConfig(PretrainedConfig):
        lsh_num_chunks_after=0,
        max_position_embeddings=4096,
        num_attention_heads=2,
-       num_buckets=32,
+       num_buckets=None,
        num_hashes=1,
        pad_token_id=0,
        vocab_size=320,
......
@@ -283,6 +283,8 @@ class EfficientAttentionMixin:
class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
    def __init__(self, config):
        super().__init__()
+       self.config = config
        self.chunk_length = config.lsh_attn_chunk_length
        self.num_hashes = config.num_hashes
        self.num_buckets = config.num_buckets
@@ -532,15 +534,22 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
        return sorted_bucket_idx, undo_sorted_bucket_idx

    def _set_num_buckets(self, sequence_length):
-       # recommended `num_buckets` from paper
-       num_buckets = 2 * sequence_length // self.chunk_length
+       # `num_buckets` should be set to 2 * sequence_length // chunk_length as recommended in paper
+       num_buckets_pow_2 = (2 * (sequence_length // self.chunk_length)).bit_length() - 1
+       # make sure buckets are power of 2
+       num_buckets = 2 ** num_buckets_pow_2
        # factorize `num_buckets` if `num_buckets` becomes too large
-       num_buckets_limit = max(int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,)
-       if num_buckets > 2 * num_buckets_limit:
-           num_buckets = [num_buckets_limit, num_buckets // num_buckets_limit + 1]
+       num_buckets_limit = 2 * max(
+           int((self.max_position_embeddings // self.chunk_length) ** (0.5)), self.chunk_length,
+       )
+       if num_buckets > num_buckets_limit:
+           num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]
        logger.warning("config.num_buckets is not set. Setting config.num_buckets to {}...".format(num_buckets))
+       # set num buckets in config to be properly saved
+       self.config.num_buckets = num_buckets
        self.num_buckets = num_buckets

    def _attend(
......
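To make the new bucket computation easier to follow, here is a self-contained sketch that mirrors the `_set_num_buckets` logic above; the helper name `compute_num_buckets` and the example numbers are ours, not part of the commit:

def compute_num_buckets(sequence_length, chunk_length, max_position_embeddings):
    # round 2 * sequence_length / chunk_length down to the nearest power of 2
    num_buckets_pow_2 = (2 * (sequence_length // chunk_length)).bit_length() - 1
    num_buckets = 2 ** num_buckets_pow_2
    # factorize into two roughly equal powers of 2 if the bucket count grows too large
    num_buckets_limit = 2 * max(int((max_position_embeddings // chunk_length) ** 0.5), chunk_length)
    if num_buckets > num_buckets_limit:
        num_buckets = [2 ** (num_buckets_pow_2 // 2), 2 ** (num_buckets_pow_2 - num_buckets_pow_2 // 2)]
    return num_buckets

# With chunk_length=64 and max_position_embeddings=65536:
print(compute_num_buckets(4096, 64, 65536))   # 128       (power of 2, below the limit)
print(compute_num_buckets(65536, 64, 65536))  # [32, 64]  (2048 buckets, factorized)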