Unverified Commit b68d408f authored by Nouamane Tazi's avatar Nouamane Tazi Committed by GitHub
Browse files

add ONNX support for BLOOM (#17961)



* add onnx support for BLOOM

* use TYPE_CHECKING for type annotations

* fix past_shape for bloom (different from gpt2)

* use logical_or instead of `+` for onnx support

* bigger `atol_for_validation` for larger bloom models

* copied -> taken because it's no longer an exact copy

* remove "copied from" comment
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 462b7f3a
...@@ -53,6 +53,7 @@ Ready-made configurations include the following architectures: ...@@ -53,6 +53,7 @@ Ready-made configurations include the following architectures:
- BigBird-Pegasus - BigBird-Pegasus
- Blenderbot - Blenderbot
- BlenderbotSmall - BlenderbotSmall
- BLOOM
- CamemBERT - CamemBERT
- CodeGen - CodeGen
- ConvBERT - ConvBERT
......
...@@ -22,10 +22,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_ ...@@ -22,10 +22,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_
_import_structure = { _import_structure = {
"configuration_bloom": [ "configuration_bloom": ["BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP", "BloomConfig", "BloomOnnxConfig"],
"BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP",
"BloomConfig",
],
} }
try: try:
if not is_tokenizers_available(): if not is_tokenizers_available():
...@@ -51,7 +48,7 @@ else: ...@@ -51,7 +48,7 @@ else:
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig
try: try:
if not is_tokenizers_available(): if not is_tokenizers_available():
......
...@@ -13,7 +13,17 @@ ...@@ -13,7 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" Bloom configuration""" """ Bloom configuration"""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, List, Mapping, Optional
from transformers import is_torch_available
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, TensorType
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfigWithPast, PatchingSpec
from ...utils import logging from ...utils import logging
...@@ -153,3 +163,88 @@ class BloomConfig(PretrainedConfig): ...@@ -153,3 +163,88 @@ class BloomConfig(PretrainedConfig):
self.slow_but_exact = slow_but_exact self.slow_but_exact = slow_but_exact
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
class BloomOnnxConfig(OnnxConfigWithPast):
def __init__(
self,
config: PretrainedConfig,
task: str = "default",
patching_specs: List[PatchingSpec] = None,
use_past: bool = False,
):
super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past)
if not getattr(self._config, "pad_token_id", None):
# TODO: how to do that better?
self._config.pad_token_id = 0
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}})
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"}
else:
common_inputs["attention_mask"] = {0: "batch", 1: "sequence"}
return common_inputs
@property
def num_layers(self) -> int:
return self._config.n_layer
@property
def num_attention_heads(self) -> int:
return self._config.n_head
@property
def atol_for_validation(self) -> float:
return 1e-3
def generate_dummy_inputs(
self,
tokenizer: "PreTrainedTokenizer",
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional["TensorType"] = None,
) -> Mapping[str, Any]:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
# We need to order the input in the way they appears in the forward()
ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]})
# Need to add the past_keys
if self.use_past:
if not is_torch_available():
raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
else:
import torch
batch, seqlen = common_inputs["input_ids"].shape
# Not using the same length for past_key_values
past_key_values_length = seqlen + 2
past_shape = (
batch,
past_key_values_length,
self.num_attention_heads,
self._config.hidden_size // self.num_attention_heads,
)
ordered_inputs["past_key_values"] = [
(torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers)
]
ordered_inputs["attention_mask"] = common_inputs["attention_mask"]
if self.use_past:
mask_dtype = ordered_inputs["attention_mask"].dtype
ordered_inputs["attention_mask"] = torch.cat(
[ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
)
return ordered_inputs
@property
def default_onnx_opset(self) -> int:
return 13
...@@ -78,17 +78,14 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks= ...@@ -78,17 +78,14 @@ def split_tensor_along_last_dim(tensor, num_partitions, contiguous_split_chunks=
def attention_mask_func(attention_scores, attention_mask, causal_mask): def attention_mask_func(attention_scores, attention_mask, causal_mask):
if attention_mask.dtype == torch.bool: attention_mask_bool = ~attention_mask.bool()
attention_mask_bool = ~attention_mask
else:
attention_mask_bool = (1 - attention_mask).bool()
query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1) query_length, key_length, n_heads = attention_scores.size(2), attention_scores.size(3), attention_scores.size(1)
padded_causal_mask = ( padded_causal_mask = torch.logical_or(
attention_mask_bool[:, None, key_length - query_length : key_length, None] attention_mask_bool[:, None, key_length - query_length : key_length, None],
+ ~causal_mask[:, :, key_length - query_length : key_length, :key_length] ~causal_mask[:, :, key_length - query_length : key_length, :key_length].bool(),
).bool() )
padded_causal_mask = padded_causal_mask + attention_mask_bool[:, None, None, :key_length].bool() padded_causal_mask = torch.logical_or(padded_causal_mask, attention_mask_bool[:, None, None, :key_length])
# Make use of floats # Make use of floats
return ( return (
attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0), attention_scores.masked_fill_(padded_causal_mask.expand(-1, n_heads, -1, -1), -10000.0),
...@@ -296,11 +293,8 @@ class BloomScaledSoftmax(nn.Module): ...@@ -296,11 +293,8 @@ class BloomScaledSoftmax(nn.Module):
mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device) mask = torch.ones(input.shape[0], max_positions, dtype=torch.bool, device=input.device)
mask = mask.to(input.device) mask = mask.to(input.device)
causal_mask = ( seq_ids = torch.arange(max_positions, device=input.device)
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)) causal_mask = (seq_ids[None, :] <= seq_ids[:, None]).view(1, 1, max_positions, max_positions).to(input.device)
.view(1, 1, max_positions, max_positions)
.to(input.device)
)
mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask) mask_output, padded_causal_mask = self.mask_func(input, mask, causal_mask)
probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask) probs = nn.functional.softmax(mask_output, dim=-1, dtype=softmax_dtype) * (~padded_causal_mask)
......
...@@ -182,6 +182,15 @@ class FeaturesManager: ...@@ -182,6 +182,15 @@ class FeaturesManager:
"seq2seq-lm-with-past", "seq2seq-lm-with-past",
onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig", onnx_config_cls="models.blenderbot_small.BlenderbotSmallOnnxConfig",
), ),
"bloom": supported_features_mapping(
"default",
"default-with-past",
"causal-lm",
"causal-lm-with-past",
"sequence-classification",
"token-classification",
onnx_config_cls="models.bloom.BloomOnnxConfig",
),
"camembert": supported_features_mapping( "camembert": supported_features_mapping(
"default", "default",
"masked-lm", "masked-lm",
......
...@@ -205,6 +205,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -205,6 +205,7 @@ PYTORCH_EXPORT_MODELS = {
} }
PYTORCH_EXPORT_WITH_PAST_MODELS = { PYTORCH_EXPORT_WITH_PAST_MODELS = {
("bloom", "bigscience/bloom-350m"),
("gpt2", "gpt2"), ("gpt2", "gpt2"),
("gpt-neo", "EleutherAI/gpt-neo-125M"), ("gpt-neo", "EleutherAI/gpt-neo-125M"),
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment