Unverified Commit e9310363 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

Add bloom flax (#25094)



* First commit

* step 1 working

* add alibi

* placeholder for `scan`

* add matrix mult alibi

* beta scaling factor for bmm

* working v1 - simple forward pass

* move layer_number from attribute to arg in call

* partial functioning scan

* hacky working scan

* add more modifs

* add test

* update scan for new kwarg order

* fix position_ids problem

* fix bug in attention layer

* small fix

- do the alibi broadcasting only once

* prelim refactor

* finish refactor

* alibi shifting

* incorporate dropout_add to attention module

* make style

* make padding work again

* update

* remove bogus file

* up

* get generation to work

* clean code a bit

* added small tests

* adding albii test

* make CI tests pass:

- change init weight
- add correct tuple for output attention
- add scan test
- make CI tests work

* fix few nits

* fix nit onnx

* fix onnx nit

* add missing dtype args to nn.Modules

* remove debugging statements

* fix scan generate

* Update modeling_flax_bloom.py

* Update test_modeling_flax_bloom.py

* Update test_modeling_flax_bloom.py

* Update test_modeling_flax_bloom.py

* fix small test issue + make style

* clean up

* Update tests/models/bloom/test_modeling_flax_bloom.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* fix function name

* small fix test

* forward contrib credits from PR17761

* Fix failing test

* fix small typo documentation

* fix non passing test

- remove device from build alibi

* refactor call

- refactor `FlaxBloomBlockCollection` module

* make style

* upcast to fp32

* cleaner way to upcast

* remove unused args

* remove layer number

* fix scan test

* make style

* fix i4 casting

* fix slow test

* Update src/transformers/models/bloom/modeling_flax_bloom.py
Co-authored-by: default avatarSanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* remove `layer_past`

* refactor a bit

* fix `scan` slow test

* remove useless import

* major changes

- remove unused code
- refactor a bit
- revert import `torch`

* major refactoring

- change build alibi

* remove scan

* fix tests

* make style

* clean-up alibi

* add integration tests

* up

* fix batch norm conversion

* style

* style

* update pt-fx cross tests

* update copyright

* Update src/transformers/modeling_flax_pytorch_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* per-weight check

* style

* line formats

---------
Co-authored-by: default avataryounesbelkada <younesbelkada@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarYounes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: default avatarhaileyschoelkopf <haileyschoelkopf@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 0c790ddb
...@@ -218,7 +218,7 @@ Flax), PyTorch, und/oder TensorFlow haben. ...@@ -218,7 +218,7 @@ Flax), PyTorch, und/oder TensorFlow haben.
| BigBird-Pegasus | ❌ | ❌ | ✅ | ❌ | ❌ | | BigBird-Pegasus | ❌ | ❌ | ✅ | ❌ | ❌ |
| Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ | | Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ |
| BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ | | BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ |
| BLOOM | ❌ | ✅ | ✅ | ❌ | | | BLOOM | ❌ | ✅ | ✅ | ❌ | |
| CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ | | CamemBERT | ✅ | ✅ | ✅ | ✅ | ❌ |
| CANINE | ✅ | ❌ | ✅ | ❌ | ❌ | | CANINE | ✅ | ❌ | ✅ | ❌ | ❌ |
| CLIP | ✅ | ✅ | ✅ | ✅ | ✅ | | CLIP | ✅ | ✅ | ✅ | ✅ | ✅ |
......
...@@ -300,7 +300,7 @@ Flax), PyTorch, and/or TensorFlow. ...@@ -300,7 +300,7 @@ Flax), PyTorch, and/or TensorFlow.
| BlenderbotSmall | ✅ | ✅ | ✅ | | BlenderbotSmall | ✅ | ✅ | ✅ |
| BLIP | ✅ | ✅ | ❌ | | BLIP | ✅ | ✅ | ❌ |
| BLIP-2 | ✅ | ❌ | ❌ | | BLIP-2 | ✅ | ❌ | ❌ |
| BLOOM | ✅ | ❌ | | | BLOOM | ✅ | ❌ | |
| BridgeTower | ✅ | ❌ | ❌ | | BridgeTower | ✅ | ❌ | ❌ |
| CamemBERT | ✅ | ✅ | ❌ | | CamemBERT | ✅ | ✅ | ❌ |
| CANINE | ✅ | ❌ | ❌ | | CANINE | ✅ | ❌ | ❌ |
......
...@@ -85,3 +85,13 @@ See also: ...@@ -85,3 +85,13 @@ See also:
[[autodoc]] BloomForQuestionAnswering [[autodoc]] BloomForQuestionAnswering
- forward - forward
## FlaxBloomModel
[[autodoc]] FlaxBloomModel
- __call__
## FlaxBloomForCausalLM
[[autodoc]] FlaxBloomForCausalLM
- __call__
...@@ -3892,6 +3892,13 @@ else: ...@@ -3892,6 +3892,13 @@ else:
"FlaxBlenderbotSmallPreTrainedModel", "FlaxBlenderbotSmallPreTrainedModel",
] ]
) )
_import_structure["models.bloom"].extend(
[
"FlaxBloomForCausalLM",
"FlaxBloomModel",
"FlaxBloomPreTrainedModel",
]
)
_import_structure["models.clip"].extend( _import_structure["models.clip"].extend(
[ [
"FlaxCLIPModel", "FlaxCLIPModel",
...@@ -7275,6 +7282,7 @@ if TYPE_CHECKING: ...@@ -7275,6 +7282,7 @@ if TYPE_CHECKING:
FlaxBlenderbotSmallModel, FlaxBlenderbotSmallModel,
FlaxBlenderbotSmallPreTrainedModel, FlaxBlenderbotSmallPreTrainedModel,
) )
from .models.bloom import FlaxBloomForCausalLM, FlaxBloomModel, FlaxBloomPreTrainedModel
from .models.clip import ( from .models.clip import (
FlaxCLIPModel, FlaxCLIPModel,
FlaxCLIPPreTrainedModel, FlaxCLIPPreTrainedModel,
......
...@@ -135,7 +135,21 @@ def rename_key_and_reshape_tensor( ...@@ -135,7 +135,21 @@ def rename_key_and_reshape_tensor(
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# convert pytorch tensor to numpy # convert pytorch tensor to numpy
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()} # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
try:
import torch # noqa: F401
except ImportError:
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
" https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
" instructions."
)
raise
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
pt_state_dict = {
k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
}
model_prefix = flax_model.base_model_prefix model_prefix = flax_model.base_model_prefix
...@@ -163,6 +177,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): ...@@ -163,6 +177,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
# Need to change some parameters name to match Flax names # Need to change some parameters name to match Flax names
for pt_key, pt_tensor in pt_state_dict.items(): for pt_key, pt_tensor in pt_state_dict.items():
pt_tuple_key = tuple(pt_key.split(".")) pt_tuple_key = tuple(pt_key.split("."))
is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16
# remove base model prefix if necessary # remove base model prefix if necessary
has_base_model_prefix = pt_tuple_key[0] == model_prefix has_base_model_prefix = pt_tuple_key[0] == model_prefix
...@@ -197,11 +212,15 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model): ...@@ -197,11 +212,15 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
continue continue
# also add unexpected weight so that warning is thrown # also add unexpected weight so that warning is thrown
flax_state_dict[("params",) + flax_key] = jnp.asarray(flax_tensor) flax_state_dict[("params",) + flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)
else: else:
# also add unexpected weight so that warning is thrown # also add unexpected weight so that warning is thrown
flax_state_dict[flax_key] = jnp.asarray(flax_tensor) flax_state_dict[flax_key] = (
jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
)
return unflatten_dict(flax_state_dict) return unflatten_dict(flax_state_dict)
......
...@@ -35,6 +35,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict( ...@@ -35,6 +35,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("big_bird", "FlaxBigBirdModel"), ("big_bird", "FlaxBigBirdModel"),
("blenderbot", "FlaxBlenderbotModel"), ("blenderbot", "FlaxBlenderbotModel"),
("blenderbot-small", "FlaxBlenderbotSmallModel"), ("blenderbot-small", "FlaxBlenderbotSmallModel"),
("bloom", "FlaxBloomModel"),
("clip", "FlaxCLIPModel"), ("clip", "FlaxCLIPModel"),
("distilbert", "FlaxDistilBertModel"), ("distilbert", "FlaxDistilBertModel"),
("electra", "FlaxElectraModel"), ("electra", "FlaxElectraModel"),
...@@ -139,6 +140,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ...@@ -139,6 +140,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
("bart", "FlaxBartForCausalLM"), ("bart", "FlaxBartForCausalLM"),
("bert", "FlaxBertForCausalLM"), ("bert", "FlaxBertForCausalLM"),
("big_bird", "FlaxBigBirdForCausalLM"), ("big_bird", "FlaxBigBirdForCausalLM"),
("bloom", "FlaxBloomForCausalLM"),
("electra", "FlaxElectraForCausalLM"), ("electra", "FlaxElectraForCausalLM"),
("gpt-sw3", "FlaxGPT2LMHeadModel"), ("gpt-sw3", "FlaxGPT2LMHeadModel"),
("gpt2", "FlaxGPT2LMHeadModel"), ("gpt2", "FlaxGPT2LMHeadModel"),
......
...@@ -14,7 +14,13 @@ ...@@ -14,7 +14,13 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tokenizers_available, is_torch_available from ...utils import (
OptionalDependencyNotAvailable,
_LazyModule,
is_flax_available,
is_tokenizers_available,
is_torch_available,
)
_import_structure = { _import_structure = {
...@@ -44,6 +50,19 @@ else: ...@@ -44,6 +50,19 @@ else:
"BloomForQuestionAnswering", "BloomForQuestionAnswering",
] ]
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
_import_structure["modeling_flax_bloom"] = [
"FlaxBloomForCausalLM",
"FlaxBloomModel",
"FlaxBloomPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig from .configuration_bloom import BLOOM_PRETRAINED_CONFIG_ARCHIVE_MAP, BloomConfig, BloomOnnxConfig
...@@ -71,6 +90,13 @@ if TYPE_CHECKING: ...@@ -71,6 +90,13 @@ if TYPE_CHECKING:
BloomPreTrainedModel, BloomPreTrainedModel,
) )
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
else:
from .modeling_flax_bloom import FlaxBloomForCausalLM, FlaxBloomModel, FlaxBloomPreTrainedModel
else: else:
import sys import sys
......
This diff is collapsed.
...@@ -194,7 +194,8 @@ class FlaxGPTJAttention(nn.Module): ...@@ -194,7 +194,8 @@ class FlaxGPTJAttention(nn.Module):
cached_value.value = value cached_value.value = value
num_updated_cache_vectors = query.shape[1] num_updated_cache_vectors = query.shape[1]
cache_index.value = cache_index.value + num_updated_cache_vectors cache_index.value = cache_index.value + num_updated_cache_vectors
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. # causal mask for cached decoder self-attention: our single query position should only attend to those key
# positions that have already been generated and cached, not the remaining zero elements.
pad_mask = jnp.broadcast_to( pad_mask = jnp.broadcast_to(
jnp.arange(max_length) < cur_index + num_updated_cache_vectors, jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
......
...@@ -520,6 +520,27 @@ class FlaxBlenderbotSmallPreTrainedModel(metaclass=DummyObject): ...@@ -520,6 +520,27 @@ class FlaxBlenderbotSmallPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxBloomForCausalLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBloomModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxBloomPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxCLIPModel(metaclass=DummyObject): class FlaxCLIPModel(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
......
...@@ -487,6 +487,33 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi ...@@ -487,6 +487,33 @@ class BloomModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
tokenizer.decode(greedy_output_without_pad[0, :-3], skip_special_tokens=True), tokenizer.decode(greedy_output_without_pad[0, :-3], skip_special_tokens=True),
) )
@slow
@require_torch_gpu
def test_batch_generated_text(self):
path_560m = "bigscience/bloom-560m"
model = BloomForCausalLM.from_pretrained(path_560m, use_cache=True, revision="gs555750").cuda()
model = model.eval()
tokenizer = BloomTokenizerFast.from_pretrained(path_560m, padding_side="left")
input_sentences = [
"Hello what is",
"Running a quick test with the",
]
inputs = tokenizer(input_sentences, return_tensors="pt", padding=True, truncation=True)
generated_ids = model.generate(
inputs["input_ids"].cuda(), attention_mask=inputs["attention_mask"], max_length=20
)
generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# these generations match those of the PyTorch model
EXPECTED_GENERATIONS = [
"Hello what is the best way to get the data from the server? I have tried",
"Running a quick test with the following command:\nsudo apt-get install python3\nsudo apt-get install python2",
]
self.assertListEqual(generated_text, EXPECTED_GENERATIONS)
@require_torch @require_torch
class BloomEmbeddingTest(unittest.TestCase): class BloomEmbeddingTest(unittest.TestCase):
......
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import BloomConfig, BloomTokenizerFast, is_flax_available
from transformers.testing_utils import require_flax, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
if is_flax_available():
import os
# The slow tests are often failing with OOM error on GPU
# This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
# but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import jax.numpy as jnp
from transformers import FlaxBloomForCausalLM, FlaxBloomModel
def prepare_bloom_inputs_dict(config, input_ids, attention_mask=None):
if attention_mask is None:
attention_mask = np.where(input_ids != config.pad_token_id, 1, 0)
return {"input_ids": input_ids, "attention_mask": attention_mask}
@require_flax
class FlaxBloomModelTester:
def __init__(
self,
parent,
batch_size=13,
seq_length=7,
is_training=True,
use_labels=False,
vocab_size=99,
hidden_size=16,
n_layer=2,
n_head=4,
hidden_act="gelu",
hidden_dropout=0.1,
attention_probs_dropout_prob=0.1,
eos_token_id=2,
pad_token_id=1,
bos_token_id=0,
initializer_range=0.02,
apply_residual_connection_post_layernorm=False,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = n_layer
self.num_attention_heads = n_head
self.hidden_act = hidden_act
self.hidden_dropout = hidden_dropout
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.eos_token_id = eos_token_id
self.pad_token_id = pad_token_id
self.bos_token_id = bos_token_id
self.initializer_range = initializer_range
self.is_encoder_decoder = False
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
def prepare_config_and_inputs(self):
input_ids = np.clip(ids_tensor([self.batch_size, self.seq_length - 1], self.vocab_size), 3, self.vocab_size)
input_ids = np.concatenate((input_ids, 2 * np.ones((self.batch_size, 1), dtype=np.int64)), -1)
config = BloomConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
n_layer=self.num_hidden_layers,
n_head=self.num_attention_heads,
hidden_dropout=self.hidden_dropout,
attention_dropout=self.attention_probs_dropout_prob,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
pad_token_id=self.pad_token_id,
is_encoder_decoder=False,
use_cache=False,
)
inputs_dict = prepare_bloom_inputs_dict(config, input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
return config, inputs_dict
def check_use_cache_forward(self, model_class_name, config, inputs_dict):
max_length = 20
model = model_class_name(config)
input_ids = inputs_dict["input_ids"]
attention_mask = jnp.ones((input_ids.shape[0], max_length), dtype="i4")
past_key_values = model.init_cache(input_ids.shape[0], max_length)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask,
past_key_values=past_key_values,
)
outputs_cache_next = model(
input_ids[:, -1:],
attention_mask=attention_mask,
past_key_values=outputs_cache.past_key_values,
)
outputs = model(input_ids)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
max_length = 20
model = model_class_name(config)
input_ids, attention_mask = (
inputs_dict["input_ids"],
inputs_dict["attention_mask"],
)
attention_mask_cache = jnp.concatenate(
[
attention_mask,
jnp.zeros((attention_mask.shape[0], max_length - attention_mask.shape[1])),
],
axis=-1,
)
past_key_values = model.init_cache(input_ids.shape[0], max_length)
outputs_cache = model(
input_ids[:, :-1],
attention_mask=attention_mask_cache,
past_key_values=past_key_values,
)
outputs_cache_next = model(
input_ids[:, -1:],
past_key_values=outputs_cache.past_key_values,
attention_mask=attention_mask_cache,
)
outputs = model(input_ids, attention_mask=attention_mask)
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
@require_flax
class FlaxBloomModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGenerationTesterMixin):
all_model_classes = (FlaxBloomModel, FlaxBloomForCausalLM) if is_flax_available() else ()
all_generative_model_classes = () if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxBloomModelTester(self)
def test_use_cache_forward(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
self.model_tester.check_use_cache_forward(model_class, config, inputs_dict)
def test_use_cache_forward_with_attn_mask(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
self.model_tester.check_use_cache_forward_with_attn_mask(model_class, config, inputs_dict)
@slow
def test_model_from_pretrained(self):
for model_class_name in self.all_model_classes:
model = model_class_name.from_pretrained("bigscience/bloom-560m")
input_ids = np.ones((1, 1)) * model.config.eos_token_id
outputs = model(input_ids)
self.assertIsNotNone(outputs)
@slow
@require_flax
class FlaxBloomGenerationTest(unittest.TestCase):
all_model_classes = (FlaxBloomForCausalLM) if is_flax_available() else ()
all_generative_model_classes = () if is_flax_available() else ()
def setUp(self):
self.model_id = "bigscience/bloom-560m"
self.tokenizer = BloomTokenizerFast.from_pretrained(self.model_id, padding_side="left")
self.model_tester = FlaxBloomModelTester(self)
self.model = FlaxBloomForCausalLM.from_pretrained(self.model_id, from_pt=True, revision="gs555750")
def test_model_batched_gen(self):
# tests if the model outputs the same generation for the same batched input
input_sentences = [
"Hello there is this string is definitely longer I believe that",
"Hello there is this string is definitely longer I believe that",
]
inputs = self.tokenizer(input_sentences, return_tensors="np", padding=True, truncation=True)
sequences_fx = self.model.generate(**inputs, max_length=20).sequences
self.assertEqual(sequences_fx[0].tolist(), sequences_fx[1].tolist())
def test_model_batched_padding_left(self):
# tests if the model outputs the same generation for an input that is part of a batch
# and a single input
input_sentences_batch = [
"Hello there is this string is definitely longer I believe that",
"Hi I want to order",
]
inputs = self.tokenizer(input_sentences_batch, return_tensors="np", padding=True, truncation=True)
sequences_fx_batch = self.model.generate(**inputs, max_length=20).sequences
input_sentence_simple = "Hi I want to order"
inputs_simple = self.tokenizer(input_sentence_simple, return_tensors="np")
sequences_fx_simple = self.model.generate(**inputs_simple, max_length=20).sequences
self.assertEqual(sequences_fx_batch[1][6:].tolist(), sequences_fx_simple[0][:-6].tolist())
def test_batch_generated_text(self):
input_sentences = [
"Hello what is",
"Running a quick test with the",
]
inputs = self.tokenizer(input_sentences, return_tensors="np", padding=True, truncation=True)
generated_ids = self.model.generate(**inputs, max_length=20).sequences
generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# these generations match those of the PyTorch model, ensuring correctness
EXPECTED_GENERATIONS = [
"Hello what is the best way to get the data from the server? I have tried",
"Running a quick test with the following command:\nsudo apt-get install python3\nsudo apt-get install python2",
]
self.assertListEqual(generated_text, EXPECTED_GENERATIONS)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment