"""This function is copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2466>`__ .
Training parameters to avoid padding with random_spans_noise_mask.
When training a model with random_spans_noise_mask, we would like to set the other
training hyperparmeters in a way that avoids padding.
This function helps us compute these hyperparameters.
The number of noise tokens and the number of noise spans and non-noise spans
"""This function is inspired from `random_spans_noise_mask <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
Noise mask consisting of random spans of noise tokens.
Spans alternate between non-noise and noise, beginning with non-noise.
Args:
inputs_length: int32 scalar
targets_length: int32 scalar
num_noise_spans: int32 scalar
Returns:
a int8 tensor with shape [num_noise_spans]
a boolean tensor with shape [length]
"""
# # pick the lengths of the noise spans and the non-noise spans
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
importenum
classLayerType(enum.Enum):
encoder=1
decoder=2
classAttnType(enum.Enum):
self_attn=1
cross_attn=2
classAttnMaskType(enum.Enum):
padding=1
causal=2# Overrides `attention_mask` to be a lower triangular matrix
prefix=3
custom=4# Forces one to pass an `attention_mask` that's 1 if we need to mask. Tensor that can be broadcast to [micro_batch_size, n_head, seq_length, seq_length]