Unverified Commit b68d408f authored by Nouamane Tazi's avatar Nouamane Tazi Committed by GitHub
Browse files

add ONNX support for BLOOM (#17961)



* add onnx support for BLOOM

* use TYPE_CHECKING for type annotations

* fix past_shape for bloom (different from gpt2)

* use logical_or instead of `+` for onnx support

* bigger `atol_for_validation` for larger bloom models

* copied -> taken because it's no longer an exact copy

* remove "copied from" comment
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 462b7f3a
......@@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
- BigBird-Pegasus
- Blenderbot
- BlenderbotSmall
- BLOOM
- CamemBERT
- CodeGen
- ConvBERT
......
......@@ -22,10 +22,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_
_import_structure = {
"configuration_bloom": [
"BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP",
"BloomConfig",
],
"configuration_bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig", "BloomOnnxConfig"],
}
try:
if not is_tokenizers_available():
......@@ -51,7 +48,7 @@ else:
]
if TYPE_CHECKING:
from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig
from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig
try:
if not is_tokenizers_available():
......
......@@ -13,7 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Bloom configuration"""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
from transformers import is_torch_available
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, TensorType
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...utils import logging
......@@ -153,3 +163,88 @@ class BloomConfig(PretrainedConfig):
self.slow_but_exact = slow_but_exact
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
class BloomOnnxConfig(OnnxConfigWithPast):
def __init__(
self,
config: PretrainedConfig,
task: str = "default",
patching_specs: List[PatchingSpec] = None,
use_past: bool = False,
):
super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
if not getattr(self._config, "pad_token_id", None):
# TODO: how to do that better?
self._config.pad_token_id = 0
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
else:
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
return common_inputs
@property
def num_layers(self) -> int:
return self._config.n_layer
@property
def num_attention_heads(self) -> int:
return self._config.n_head
@property
def atol_for_validation(self) -> float:
return 1e-3
def generate_dummy_inputs(
self,
tokenizer: "PreTrainedTokenizer",
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional["TensorType"] = None,
) -> Mapping[str, Any]:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
# We need to order the input in the way they appears in the forward()
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
# Need to add the past_keys
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, seqlen = common_inputs["input_ids"].shape
# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
past_shape = (
batch,
past_key_values_length,
self.num_attention_heads,
self._config.hidden_size // self.num_attention_heads,
)
ordered_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
]
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
mask_dtype = ordered_inputs["attention_mask"].dtype
ordered_inputs["attention_mask"] = torch.cat(
[ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
return ordered_inputs
@property
def default_onnx_opset(self) -> int:
return 13
......@@ -78,17 +78,14 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=
def attention_mask_func(attention_scores, attention_mask, causal_mask):
if attention_mask.dtype == torch.bool:
attention_mask_bool = ~attention_mask
else:
attention_mask_bool = (1 - attention_mask).bool()
attention_mask_bool = ~attention_mask.bool()
query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
padded_causal_mask = (
attention_mask_bool[:, None, key_length - query_length : key_length, None]
+ ~causal_mask[:, :, key_length - query_length : key_length, :key_length]
).bool()
padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool()
padded_causal_mask = torch.logical_or(
attention_mask_bool[:, None, key_length - query_length : key_length, None],
~causal_mask[:, :, key_length - query_length : key_length, :key_length].bool(),
)
padded_causal_mask = torch.logical_or(padded_causal_mask, attention_mask_bool[:, None, None, :key_length])
# Make use of floats
return (
attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
......@@ -296,11 +293,8 @@ class BloomScaledSoftmax(nn.Module):
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
mask = mask.to(input.device)
causal_mask = (
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
.view(1, 1, max_positions, max_positions)
.to(input.device)
)
seq_ids = torch.arange(max_positions, device=input.device)
causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions).to(input.device)
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
......
......@@ -182,6 +182,15 @@ class FeaturesManager:
"seq2seq-lm-with-past",
onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
),
"bloom": supported_features_mapping(
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"sequence-classification",
"token-classification",
onnx_config_cls="models.bloom.BloomOnnxConfig",
),
"camembert": supported_features_mapping(
"default",
"masked-lm",
......
......@@ -205,6 +205,7 @@ PYTORCH_EXPORT_MODELS = {
}
PYTORCH_EXPORT_WITH_PAST_MODELS = {
("bloom", "bigscience/bloom-350m"),
("gpt2", "gpt2"),
("gpt-neo", "EleutherAI/gpt-neo-125M"),
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment