Unverified Commit 8f915c45 authored by Sanchit Gandhi, committed by GitHub

Unpin numba (#23162)

* fix for ragged list

* unpin numba

* make style

* np.object -> object

* propagate changes to tokenizer as well

* np.long -> "long"

* revert tokenization changes

* check with tokenization changes

* list/tuple logic

* catch numpy

* catch else case

* clean up

* up

* better check

* trigger ci

* Empty commit to trigger CI
parent d99f11e8
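
For context on the fixes below: NumPy 1.24 removed the long-deprecated `np.object` and `np.long` aliases and turned implicit ragged-array creation into a hard error, and the `numba<0.57.0` pin had the side effect of keeping NumPy below 1.24, since earlier numba releases don't support it. A minimal sketch of the affected patterns and their replacements, assuming NumPy >= 1.24:

```python
import numpy as np

# On NumPy 1.24 the np.object and np.long aliases are gone; the builtin
# `object` and the dtype string "long" are drop-in replacements.
records = np.array([b"first", b"second"], dtype=object)
block_ids = np.array([0, 3], dtype="long")

# Ragged nested sequences: implicit conversion now raises ValueError,
# so the ragged case has to be boxed into an object array explicitly.
try:
    np.asarray([[1, 2, 3], [4, 5]])
except ValueError:
    ragged = np.asarray([np.asarray(row) for row in [[1, 2, 3], [4, 5]]], dtype=object)
```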
@@ -132,7 +132,6 @@ _deps = [
"librosa",
"nltk",
"natten>=0.14.6",
"numba<0.57.0", # Can be removed once unpinned.
"numpy>=1.17",
"onnxconverter-common",
"onnxruntime-tools>=1.4.2",
@@ -286,8 +285,7 @@ extras["sigopt"] = deps_list("sigopt")
extras["integrations"] = extras["optuna"] + extras["ray"] + extras["sigopt"]
extras["serving"] = deps_list("pydantic", "uvicorn", "fastapi", "starlette")
-# numba can be removed here once unpinned
-extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm", "numba")
+extras["audio"] = deps_list("librosa", "pyctcdecode", "phonemizer", "kenlm")
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
extras["speech"] = deps_list("torchaudio") + extras["audio"]
extras["torch-speech"] = deps_list("torchaudio") + extras["audio"]
@@ -37,7 +37,6 @@ deps = {
"librosa": "librosa",
"nltk": "nltk",
"natten": "natten>=0.14.6",
"numba": "numba<0.57.0",
"numpy": "numpy>=1.17",
"onnxconverter-common": "onnxconverter-common",
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
@@ -156,7 +156,15 @@ class BatchFeature(UserDict):
as_tensor = jnp.array
is_tensor = is_jax_tensor
else:
-as_tensor = np.asarray
+def as_tensor(value, dtype=None):
+    if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
+        value_lens = [len(val) for val in value]
+        if len(set(value_lens)) > 1 and dtype is None:
+            # we have a ragged list so handle explicitly
+            value = as_tensor([np.asarray(val) for val in value], dtype=object)
+    return np.asarray(value, dtype=dtype)
is_tensor = is_numpy_array
# Do the tensor conversion in batch
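
A standalone sketch of the fallback above, with made-up inputs, showing that rectangular batches still become ordinary dense arrays while ragged ones drop to a one-dimensional object array:

```python
import numpy as np

def as_tensor(value, dtype=None):
    # Same shape check as the hunk above: nested sequences with unequal
    # lengths are converted element-wise and boxed into an object array;
    # rectangular input falls straight through to a dense np.asarray.
    if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
        value_lens = [len(val) for val in value]
        if len(set(value_lens)) > 1 and dtype is None:
            value = as_tensor([np.asarray(val) for val in value], dtype=object)
    return np.asarray(value, dtype=dtype)

dense = as_tensor([[1, 2, 3], [4, 5, 6]])   # shape (2, 3), integer dtype
ragged = as_tensor([[1, 2, 3], [4, 5]])     # shape (2,), dtype=object
```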
@@ -705,7 +705,15 @@ class BatchEncoding(UserDict):
as_tensor = jnp.array
is_tensor = is_jax_tensor
else:
-as_tensor = np.asarray
+def as_tensor(value, dtype=None):
+    if isinstance(value, (list, tuple)) and isinstance(value[0], (list, tuple, np.ndarray)):
+        value_lens = [len(val) for val in value]
+        if len(set(value_lens)) > 1 and dtype is None:
+            # we have a ragged list so handle explicitly
+            value = as_tensor([np.asarray(val) for val in value], dtype=object)
+    return np.asarray(value, dtype=dtype)
is_tensor = is_numpy_array
# Do the tensor conversion in batch
@@ -392,7 +392,7 @@ class RealmModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
b"This is the fourth record.",
b"This is the fifth record.",
],
-dtype=np.object,
+dtype=object,
)
retriever = RealmRetriever(block_records, tokenizer)
model = RealmForOpenQA(openqa_config, retriever)
@@ -100,7 +100,7 @@ class RealmRetrieverTest(TestCase):
b"This is the fifth record",
b"This is a longer longer longer record",
],
-dtype=np.object,
+dtype=object,
)
return block_records
@@ -116,7 +116,7 @@ class RealmRetrieverTest(TestCase):
retriever = self.get_dummy_retriever()
tokenizer = retriever.tokenizer
-retrieved_block_ids = np.array([0, 3], dtype=np.long)
+retrieved_block_ids = np.array([0, 3], dtype="long")
question_input_ids = tokenizer(["Test question"]).input_ids
answer_ids = tokenizer(
["the fourth"],
@@ -151,7 +151,7 @@ class RealmRetrieverTest(TestCase):
retriever = self.get_dummy_retriever()
tokenizer = retriever.tokenizer
-retrieved_block_ids = np.array([0, 3, 5], dtype=np.long)
+retrieved_block_ids = np.array([0, 3, 5], dtype="long")
question_input_ids = tokenizer(["Test question"]).input_ids
answer_ids = tokenizer(
["the fourth", "longer longer"],