Unverified Commit cec1c636 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Fix test (#13608)

parent 5c593718
......@@ -33,7 +33,7 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
outputs = feature_extractor("This is a test")
self.assertEqual(
nested_simplify(outputs),
[[[-0.454, 0.966, 0.619, 0.262, 0.669, -0.661, -0.066, -0.513, -0.768, -0.177, 1.771, -0.665, -0.649, 0.219, 0.236, -0.375, 1.155, -1.07, 0.208, -0.799, 1.065, -1.223, 0.554, 1.274, 0.458, 2.292, -0.481, -0.928, -2.469, -1.692, 0.182, 1.06], [-0.187, -1.277, 0.849, -0.439, -0.967, -1.347, 1.063, 0.469, 1.086, -1.253, 0.349, 0.057, 1.031, -1.903, -0.432, -1.377, 0.379, 0.733, -1.043, 1.307, 0.865, 0.229, 1.373, 1.671, -0.285, 0.599, -1.418, -1.179, -0.369, 1.039, -0.705, 1.082], [-1.735, 1.102, 0.398, -0.245, 1.452, 0.46, -1.734, -0.746, 1.831, 0.562, 1.464, -0.342, -0.619, -0.455, 0.127, -1.209, -0.686, -0.395, -0.316, 2.467, -0.379, 0.328, 0.639, 0.4, -1.097, -0.096, 0.397, -0.806, -1.621, 1.127, -0.345, 0.074], [0.296, -0.638, 1.938, -0.151, -1.19, 1.445, 1.318, 0.711, -0.125, 0.127, -2.179, 0.481, -1.019, 1.178, 0.318, 1.858, -1.646, 0.185, -0.072, -0.979, 0.82, -1.374, 0.836, -1.019, 0.043, -0.156, -0.095, 0.641, -0.195, -0.076, -1.554, 0.275], [-0.266, 0.971, 0.745, -0.37, 1.42, -0.5, -0.53, 0.061, 1.311, -0.1, 1.796, 0.53, -0.739, -0.325, 0.28, -1.72, 0.382, -1.118, 0.442, 1.84, -2.497, 1.003, -0.788, -0.224, -0.604, -1.259, -0.475, 1.18, -1.356, 0.695, 0.201, 0.016], [-0.618, -1.495, -0.67, -0.106, -1.265, -0.51, -1.752, 1.018, 0.674, 0.181, 0.297, 0.479, -0.185, 0.081, -2.44, -0.239, 1.081, -1.38, 0.679, 0.878, 1.336, -1.347, 0.969, -0.847, 0.293, 0.476, 1.647, -0.641, 0.66, 1.236, 0.761, 0.751]]]) # fmt: skip
[[[0.216, 0.686, 1.908, 0.125, 1.548, -0.869, 0.78, -0.527, -1.555, -1.012, -0.08, -2.646, -1.333, 1.2, -0.529, 1.109, -0.694, 0.209, -0.986, 0.475, -0.44, -0.641, -0.409, 1.204, -1.311, 0.316, -0.08, 0.454, 0.106, 0.923, 0.745, 1.106], [1.488, -0.283, 0.691, -2.345, -0.144, -1.454, -0.535, 0.976, -1.304, 0.134, -1.707, -0.18, 0.33, 0.982, -1.026, 0.076, 1.223, 0.819, 1.437, 0.549, 0.257, 0.307, -0.304, 0.154, -0.075, -0.583, 0.157, 0.11, 0.921, -2.434, 0.739, 1.024], [-0.328, 0.284, -0.666, 1.846, 0.158, -1.723, -0.865, -0.143, 0.09, -0.517, 0.96, -0.847, -1.069, 0.099, -0.796, 0.384, 1.594, 0.764, -1.596, 0.055, -0.484, 0.208, -0.529, 0.849, 0.051, 2.725, 2.043, -0.864, -0.497, -0.866, -0.209, -0.113], [-0.628, 0.513, -0.434, -0.906, 1.02, -1.155, 1.308, -0.144, 0.861, 0.825, 2.051, 1.127, 1.513, 0.367, 0.575, 0.72, 0.471, 0.36, -0.861, -1.835, -0.026, -0.646, 1.192, -2.123, -0.759, 0.634, -0.296, -0.161, -0.633, -0.698, -1.741, -0.492], [-0.444, 0.616, -0.252, 0.594, -0.219, 0.417, -1.118, -0.088, 1.127, -1.674, 0.762, -0.156, 1.655, 0.965, -1.03, -0.853, -0.037, 1.38, -0.726, -0.469, 0.635, 0.898, -1.506, -0.519, -0.669, 0.406, 1.767, 2.215, -1.425, -1.238, -0.878, -0.137], [1.36, -0.347, -0.009, -0.146, -1.669, 0.735, -2.14, -0.82, 0.739, 0.962, 0.393, -0.371, 0.055, -1.619, -0.605, 1.487, 1.335, 0.697, 0.867, -1.043, 0.833, -1.281, -1.389, 0.815, 0.934, 0.009, 1.075, -0.37, 0.804, -0.635, -1.327, 0.671], [-0.529, 0.78, -1.844, 1.068, 0.689, 1.022, 0.282, -1.327, 0.28, 0.119, 2.389, 0.272, 0.947, 0.404, -0.84, -0.364, -1.209, 0.759, -0.953, 0.369, -0.866, -0.406, -1.054, -0.758, -0.877, 1.837, 1.59, 0.338, -1.003, -0.211, -1.322, 0.42], [0.332, 1.92, 1.158, -0.218, -0.812, -0.676, 1.813, 0.185, -0.191, -0.654, 0.93, 0.845, -0.365, -1.043, -1.189, 0.503, 1.384, 0.161, -0.259, -1.295, 0.694, -1.925, 2.034, 0.57, 0.541, -0.76, 0.199, -1.324, -1.216, -0.444, -0.88, -0.016], [-0.743, -0.093, -0.205, 0.105, -1.328, -0.001, 0.461, 0.787, 0.118, -0.707, 0.101, -0.134, -1.015, 0.494, -1.198, 0.477, 1.271, 0.056, 2.37, 1.827, 1.16, 0.485, -1.197, -0.191, -0.113, -0.012, -1.599, -0.125, 1.035, -2.789, 0.411, 0.293], [-0.612, 1.765, 0.316, -0.193, -0.349, -0.249, -0.168, 0.96, 0.037, 1.451, -0.089, 0.811, 0.166, 0.87, 0.079, 0.885, 1.08, -0.043, 0.258, 0.577, -2.287, -1.271, -2.109, 1.513, -0.846, -1.92, -0.116, 1.476, -0.279, -0.39, -0.815, -0.509], [-0.591, 0.048, -0.422, 1.068, -0.005, 1.891, 0.094, -1.655, -0.857, -0.981, 3.114, 0.789, 0.363, -1.244, -0.046, -0.964, 0.169, 0.806, -0.213, 0.566, 0.25, -0.205, -1.564, 0.263, 0.079, 0.565, 0.597, 1.045, -1.327, -0.293, 0.103, -1.443], [2.397, 1.062, -0.573, 1.594, 0.521, 0.241, 0.896, -1.182, -0.563, -0.935, -0.992, -0.655, 1.825, -1.338, -0.335, -0.868, 1.297, -0.299, 0.411, -0.638, -0.404, -1.499, 0.753, 1.012, -0.144, 0.704, -1.337, 0.454, -1.034, 0.564, -0.943, 0.008], [0.039, 1.009, 1.153, -0.478, -0.411, 0.065, -1.489, -0.562, 0.796, 0.388, -0.451, 2.051, 0.514, -0.756, 0.661, -1.563, 1.441, 0.116, 1.161, -0.769, 0.152, 0.301, -1.23, 0.185, -0.848, -1.246, 0.008, -1.513, -0.283, -0.811, 2.601, -0.233]]]) # fmt: skip
@require_tf
def test_small_model_tf(self):
......@@ -43,7 +43,7 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
outputs = feature_extractor("This is a test")
self.assertEqual(
nested_simplify(outputs),
[[[-0.454, 0.966, 0.619, 0.262, 0.669, -0.661, -0.066, -0.513, -0.768, -0.177, 1.771, -0.665, -0.649, 0.219, 0.236, -0.375, 1.155, -1.07, 0.208, -0.799, 1.065, -1.223, 0.554, 1.274, 0.458, 2.292, -0.481, -0.928, -2.469, -1.692, 0.182, 1.06], [-0.187, -1.277, 0.849, -0.439, -0.967, -1.347, 1.063, 0.469, 1.086, -1.253, 0.349, 0.057, 1.031, -1.903, -0.432, -1.377, 0.379, 0.733, -1.043, 1.307, 0.865, 0.229, 1.373, 1.671, -0.285, 0.599, -1.418, -1.179, -0.369, 1.039, -0.705, 1.082], [-1.735, 1.102, 0.398, -0.245, 1.452, 0.46, -1.734, -0.746, 1.831, 0.562, 1.464, -0.342, -0.619, -0.455, 0.127, -1.209, -0.686, -0.395, -0.316, 2.467, -0.379, 0.328, 0.639, 0.4, -1.097, -0.096, 0.397, -0.806, -1.621, 1.127, -0.345, 0.074], [0.296, -0.638, 1.938, -0.151, -1.19, 1.445, 1.318, 0.711, -0.125, 0.127, -2.179, 0.481, -1.019, 1.178, 0.318, 1.858, -1.646, 0.185, -0.072, -0.979, 0.82, -1.374, 0.836, -1.019, 0.043, -0.156, -0.095, 0.641, -0.195, -0.076, -1.554, 0.275], [-0.266, 0.971, 0.745, -0.37, 1.42, -0.5, -0.53, 0.061, 1.311, -0.1, 1.796, 0.53, -0.739, -0.325, 0.28, -1.72, 0.382, -1.118, 0.442, 1.84, -2.497, 1.003, -0.788, -0.224, -0.604, -1.259, -0.475, 1.18, -1.356, 0.695, 0.201, 0.016], [-0.618, -1.495, -0.67, -0.106, -1.265, -0.51, -1.752, 1.018, 0.674, 0.181, 0.297, 0.479, -0.185, 0.081, -2.44, -0.239, 1.081, -1.38, 0.679, 0.878, 1.336, -1.347, 0.969, -0.847, 0.293, 0.476, 1.647, -0.641, 0.66, 1.236, 0.761, 0.751]]]) # fmt: skip
[[[0.216, 0.686, 1.908, 0.125, 1.548, -0.869, 0.78, -0.527, -1.555, -1.012, -0.08, -2.646, -1.333, 1.2, -0.529, 1.109, -0.694, 0.209, -0.986, 0.475, -0.44, -0.641, -0.409, 1.204, -1.311, 0.316, -0.08, 0.454, 0.106, 0.923, 0.745, 1.106], [1.488, -0.283, 0.691, -2.345, -0.144, -1.454, -0.535, 0.976, -1.304, 0.134, -1.707, -0.18, 0.33, 0.982, -1.026, 0.076, 1.223, 0.819, 1.437, 0.549, 0.257, 0.307, -0.304, 0.154, -0.075, -0.583, 0.157, 0.11, 0.921, -2.434, 0.739, 1.024], [-0.328, 0.284, -0.666, 1.846, 0.158, -1.723, -0.865, -0.143, 0.09, -0.517, 0.96, -0.847, -1.069, 0.099, -0.796, 0.384, 1.594, 0.764, -1.596, 0.055, -0.484, 0.208, -0.529, 0.849, 0.051, 2.725, 2.043, -0.864, -0.497, -0.866, -0.209, -0.113], [-0.628, 0.513, -0.434, -0.906, 1.02, -1.155, 1.308, -0.144, 0.861, 0.825, 2.051, 1.127, 1.513, 0.367, 0.575, 0.72, 0.471, 0.36, -0.861, -1.835, -0.026, -0.646, 1.192, -2.123, -0.759, 0.634, -0.296, -0.161, -0.633, -0.698, -1.741, -0.492], [-0.444, 0.616, -0.252, 0.594, -0.219, 0.417, -1.118, -0.088, 1.127, -1.674, 0.762, -0.156, 1.655, 0.965, -1.03, -0.853, -0.037, 1.38, -0.726, -0.469, 0.635, 0.898, -1.506, -0.519, -0.669, 0.406, 1.767, 2.215, -1.425, -1.238, -0.878, -0.137], [1.36, -0.347, -0.009, -0.146, -1.669, 0.735, -2.14, -0.82, 0.739, 0.962, 0.393, -0.371, 0.055, -1.619, -0.605, 1.487, 1.335, 0.697, 0.867, -1.043, 0.833, -1.281, -1.389, 0.815, 0.934, 0.009, 1.075, -0.37, 0.804, -0.635, -1.327, 0.671], [-0.529, 0.78, -1.844, 1.068, 0.689, 1.022, 0.282, -1.327, 0.28, 0.119, 2.389, 0.272, 0.947, 0.404, -0.84, -0.364, -1.209, 0.759, -0.953, 0.369, -0.866, -0.406, -1.054, -0.758, -0.877, 1.837, 1.59, 0.338, -1.003, -0.211, -1.322, 0.42], [0.332, 1.92, 1.158, -0.218, -0.812, -0.676, 1.813, 0.185, -0.191, -0.654, 0.93, 0.845, -0.365, -1.043, -1.189, 0.503, 1.384, 0.161, -0.259, -1.295, 0.694, -1.925, 2.034, 0.57, 0.541, -0.76, 0.199, -1.324, -1.216, -0.444, -0.88, -0.016], [-0.743, -0.093, -0.205, 0.105, -1.328, -0.001, 0.461, 0.787, 0.118, -0.707, 0.101, -0.134, -1.015, 0.494, -1.198, 0.477, 1.271, 0.056, 2.37, 1.827, 1.16, 0.485, -1.197, -0.191, -0.113, -0.012, -1.599, -0.125, 1.035, -2.789, 0.411, 0.293], [-0.612, 1.765, 0.316, -0.193, -0.349, -0.249, -0.168, 0.96, 0.037, 1.451, -0.089, 0.811, 0.166, 0.87, 0.079, 0.885, 1.08, -0.043, 0.258, 0.577, -2.287, -1.271, -2.109, 1.513, -0.846, -1.92, -0.116, 1.476, -0.279, -0.39, -0.815, -0.509], [-0.591, 0.048, -0.422, 1.068, -0.005, 1.891, 0.094, -1.655, -0.857, -0.981, 3.114, 0.789, 0.363, -1.244, -0.046, -0.964, 0.169, 0.806, -0.213, 0.566, 0.25, -0.205, -1.564, 0.263, 0.079, 0.565, 0.597, 1.045, -1.327, -0.293, 0.103, -1.443], [2.397, 1.062, -0.573, 1.594, 0.521, 0.241, 0.896, -1.182, -0.563, -0.935, -0.992, -0.655, 1.825, -1.338, -0.335, -0.868, 1.297, -0.299, 0.411, -0.638, -0.404, -1.499, 0.753, 1.012, -0.144, 0.704, -1.337, 0.454, -1.034, 0.564, -0.943, 0.008], [0.039, 1.009, 1.153, -0.478, -0.411, 0.065, -1.489, -0.562, 0.796, 0.388, -0.451, 2.051, 0.514, -0.756, 0.661, -1.563, 1.441, 0.116, 1.161, -0.769, 0.152, 0.301, -1.23, 0.185, -0.848, -1.246, 0.008, -1.513, -0.283, -0.811, 2.601, -0.233]]]) # fmt: skip
def get_shape(self, input_, shape=None):
if shape is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment