"...text-generation-inference.git" did not exist on "d1f257ac56f7a38ac964d499932e5ab5242b3363"
Unverified Commit afce73bd authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix ModelOutput instantiation when there is only one tuple (#20416)

parent 993a187c
...@@ -227,12 +227,20 @@ class ModelOutput(OrderedDict): ...@@ -227,12 +227,20 @@ class ModelOutput(OrderedDict):
# if we provided an iterator as first field and the iterator is a (key, value) iterator # if we provided an iterator as first field and the iterator is a (key, value) iterator
# set the associated fields # set the associated fields
if first_field_iterator: if first_field_iterator:
for element in iterator: for idx, element in enumerate(iterator):
if ( if (
not isinstance(element, (list, tuple)) not isinstance(element, (list, tuple))
or not len(element) == 2 or not len(element) == 2
or not isinstance(element[0], str) or not isinstance(element[0], str)
): ):
if idx == 0:
# If we do not have an iterator of key/values, set it as attribute
self[class_fields[0].name] = first_field
else:
# If we have a mixed iterator, raise an error
raise ValueError(
f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
)
break break
setattr(self, element[0], element[1]) setattr(self, element[0], element[1])
if element[1] is not None: if element[1] is not None:
......
...@@ -107,3 +107,16 @@ class ModelOutputTester(unittest.TestCase): ...@@ -107,3 +107,16 @@ class ModelOutputTester(unittest.TestCase):
self.assertEqual(list(x.keys()), ["a", "b"]) self.assertEqual(list(x.keys()), ["a", "b"])
self.assertEqual(x.a, 30) self.assertEqual(x.a, 30)
self.assertEqual(x.b, 10) self.assertEqual(x.b, 10)
def test_instantiate_from_iterator(self):
x = ModelOutputTest([("a", 30), ("b", 10)])
self.assertEqual(list(x.keys()), ["a", "b"])
self.assertEqual(x.a, 30)
self.assertEqual(x.b, 10)
with self.assertRaises(ValueError):
_ = ModelOutputTest([("a", 30), (10, 10)])
x = ModelOutputTest(a=(30, 30))
self.assertEqual(list(x.keys()), ["a"])
self.assertEqual(x.a, (30, 30))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment