Unverified Commit 300ee0c7 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Updated tiny distilbert models (#13631)

parent afb07a79
......@@ -33,7 +33,7 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
outputs = feature_extractor("This is a test")
self.assertEqual(
nested_simplify(outputs),
[[[0.216, 0.686, 1.908, 0.125, 1.548, -0.869, 0.78, -0.527, -1.555, -1.012, -0.08, -2.646, -1.333, 1.2, -0.529, 1.109, -0.694, 0.209, -0.986, 0.475, -0.44, -0.641, -0.409, 1.204, -1.311, 0.316, -0.08, 0.454, 0.106, 0.923, 0.745, 1.106], [1.488, -0.283, 0.691, -2.345, -0.144, -1.454, -0.535, 0.976, -1.304, 0.134, -1.707, -0.18, 0.33, 0.982, -1.026, 0.076, 1.223, 0.819, 1.437, 0.549, 0.257, 0.307, -0.304, 0.154, -0.075, -0.583, 0.157, 0.11, 0.921, -2.434, 0.739, 1.024], [-0.328, 0.284, -0.666, 1.846, 0.158, -1.723, -0.865, -0.143, 0.09, -0.517, 0.96, -0.847, -1.069, 0.099, -0.796, 0.384, 1.594, 0.764, -1.596, 0.055, -0.484, 0.208, -0.529, 0.849, 0.051, 2.725, 2.043, -0.864, -0.497, -0.866, -0.209, -0.113], [-0.628, 0.513, -0.434, -0.906, 1.02, -1.155, 1.308, -0.144, 0.861, 0.825, 2.051, 1.127, 1.513, 0.367, 0.575, 0.72, 0.471, 0.36, -0.861, -1.835, -0.026, -0.646, 1.192, -2.123, -0.759, 0.634, -0.296, -0.161, -0.633, -0.698, -1.741, -0.492], [-0.444, 0.616, -0.252, 0.594, -0.219, 0.417, -1.118, -0.088, 1.127, -1.674, 0.762, -0.156, 1.655, 0.965, -1.03, -0.853, -0.037, 1.38, -0.726, -0.469, 0.635, 0.898, -1.506, -0.519, -0.669, 0.406, 1.767, 2.215, -1.425, -1.238, -0.878, -0.137], [1.36, -0.347, -0.009, -0.146, -1.669, 0.735, -2.14, -0.82, 0.739, 0.962, 0.393, -0.371, 0.055, -1.619, -0.605, 1.487, 1.335, 0.697, 0.867, -1.043, 0.833, -1.281, -1.389, 0.815, 0.934, 0.009, 1.075, -0.37, 0.804, -0.635, -1.327, 0.671], [-0.529, 0.78, -1.844, 1.068, 0.689, 1.022, 0.282, -1.327, 0.28, 0.119, 2.389, 0.272, 0.947, 0.404, -0.84, -0.364, -1.209, 0.759, -0.953, 0.369, -0.866, -0.406, -1.054, -0.758, -0.877, 1.837, 1.59, 0.338, -1.003, -0.211, -1.322, 0.42], [0.332, 1.92, 1.158, -0.218, -0.812, -0.676, 1.813, 0.185, -0.191, -0.654, 0.93, 0.845, -0.365, -1.043, -1.189, 0.503, 1.384, 0.161, -0.259, -1.295, 0.694, -1.925, 2.034, 0.57, 0.541, -0.76, 0.199, -1.324, -1.216, -0.444, -0.88, -0.016], [-0.743, -0.093, -0.205, 0.105, -1.328, -0.001, 0.461, 0.787, 0.118, -0.707, 0.101, -0.134, -1.015, 0.494, -1.198, 0.477, 1.271, 0.056, 2.37, 1.827, 1.16, 0.485, -1.197, -0.191, -0.113, -0.012, -1.599, -0.125, 1.035, -2.789, 0.411, 0.293], [-0.612, 1.765, 0.316, -0.193, -0.349, -0.249, -0.168, 0.96, 0.037, 1.451, -0.089, 0.811, 0.166, 0.87, 0.079, 0.885, 1.08, -0.043, 0.258, 0.577, -2.287, -1.271, -2.109, 1.513, -0.846, -1.92, -0.116, 1.476, -0.279, -0.39, -0.815, -0.509], [-0.591, 0.048, -0.422, 1.068, -0.005, 1.891, 0.094, -1.655, -0.857, -0.981, 3.114, 0.789, 0.363, -1.244, -0.046, -0.964, 0.169, 0.806, -0.213, 0.566, 0.25, -0.205, -1.564, 0.263, 0.079, 0.565, 0.597, 1.045, -1.327, -0.293, 0.103, -1.443], [2.397, 1.062, -0.573, 1.594, 0.521, 0.241, 0.896, -1.182, -0.563, -0.935, -0.992, -0.655, 1.825, -1.338, -0.335, -0.868, 1.297, -0.299, 0.411, -0.638, -0.404, -1.499, 0.753, 1.012, -0.144, 0.704, -1.337, 0.454, -1.034, 0.564, -0.943, 0.008], [0.039, 1.009, 1.153, -0.478, -0.411, 0.065, -1.489, -0.562, 0.796, 0.388, -0.451, 2.051, 0.514, -0.756, 0.661, -1.563, 1.441, 0.116, 1.161, -0.769, 0.152, 0.301, -1.23, 0.185, -0.848, -1.246, 0.008, -1.513, -0.283, -0.811, 2.601, -0.233]]]) # fmt: skip
[[[2.287, 1.234, 0.042, 1.53, 1.306, 0.879, -0.526, -1.71, -1.276, 0.756, -0.775, -1.048, -0.25, -0.595, -0.137, -0.598, 2.022, -0.812, 0.284, -0.488, -0.391, -0.403, -0.525, -0.061, -0.228, 1.086, 0.378, -0.14, 0.599, -0.087, -2.259, -0.098], [1.676, 0.232, -1.508, -0.145, 1.798, -1.388, 1.331, -0.37, -0.939, 0.043, 0.06, -0.414, -1.408, 0.24, 0.622, -0.55, -0.569, 1.873, -0.706, 1.924, -0.254, 1.927, -0.423, 0.152, -0.952, 0.509, -0.496, -0.968, 0.093, -1.049, -0.65, 0.312], [0.207, -0.775, -1.822, 0.321, -0.71, -0.201, 0.3, 1.146, -0.233, -0.753, -0.305, 1.309, -1.47, -0.21, 1.802, -1.555, -1.175, 1.323, -0.303, 0.722, -0.076, 0.103, -1.406, 1.931, 0.091, 0.237, 1.172, 1.607, 0.253, -0.9, -1.068, 0.438], [0.615, 1.077, 0.171, -0.175, 1.3, 0.901, -0.653, -0.138, 0.341, -0.654, -0.184, -0.441, -0.424, 0.356, -0.075, 0.26, -1.023, 0.814, 0.524, -0.904, -0.204, -0.623, 1.234, -1.03, 2.594, 0.56, 1.831, -0.199, -1.508, -0.492, -1.687, -2.165], [0.129, 0.008, -1.279, -0.412, -0.004, 1.663, 0.196, 0.104, 0.123, 0.119, 0.635, 1.757, 2.334, -0.799, -1.626, -1.26, 0.595, -0.316, -1.399, 0.232, 0.264, 1.386, -1.171, -0.256, -0.256, -1.944, 1.168, -0.368, -0.714, -0.51, 0.454, 1.148], [-0.32, 0.29, -1.309, -0.177, 0.453, 0.636, -0.024, 0.509, 0.931, -1.754, -1.575, 0.786, 0.046, -1.165, -1.416, 1.373, 1.293, -0.285, -1.541, -1.186, -0.106, -0.994, 2.001, 0.972, -0.02, 1.654, -0.236, 0.643, 1.02, 0.572, -0.914, -0.154], [0.7, -0.937, 0.441, 0.25, 0.78, -0.022, 0.282, -0.095, 1.558, -0.336, 1.706, 0.884, 1.28, 0.198, -0.796, 1.218, -1.769, 1.197, -0.342, -0.177, -0.645, 1.364, 0.008, -0.597, -0.484, -2.772, -0.696, -0.632, -0.34, -1.527, -0.562, 0.862], [2.504, 0.831, -1.271, -0.033, 0.298, -0.735, 1.339, 1.74, 0.233, -1.424, -0.819, -0.761, 0.291, 0.853, -0.092, -0.885, 0.164, 1.025, 0.907, 0.749, -1.515, -0.545, -1.365, 0.271, 0.034, -2.005, 0.031, 0.244, 0.621, 0.176, 0.336, -1.196], [-0.711, 0.591, -1.001, -0.946, 0.784, -1.66, 1.545, 0.799, -0.857, 1.148, 0.213, -0.285, 0.464, -0.139, 0.79, -1.663, -1.121, 0.575, -0.178, -0.508, 1.565, -0.242, -0.346, 1.024, -1.135, -0.158, -2.101, 0.275, 2.009, -0.425, 0.716, 0.981], [0.912, -1.186, -0.846, -0.421, -1.315, -0.827, 0.309, 0.533, 1.029, -2.343, 1.513, -1.238, 1.487, -0.849, 0.896, -0.927, -0.459, 0.159, 0.177, 0.873, 0.935, 1.433, -0.485, 0.737, 1.327, -0.338, 1.608, -0.47, -0.445, -1.118, -0.213, -0.446], [-0.434, -1.362, -1.098, -1.068, 1.507, 0.003, 0.413, -0.395, 0.897, -0.237, 1.405, -0.344, 1.693, 0.677, 0.097, -0.257, -0.602, 1.026, -1.229, 0.855, -0.713, 1.014, 0.443, 0.238, 0.425, -2.184, 1.933, -1.157, -1.132, -0.597, -0.785, 0.967], [0.58, -0.971, 0.789, -0.468, -0.576, 1.779, 1.747, 1.715, -1.939, 0.125, 0.656, -0.042, -1.024, -1.767, 0.107, -0.408, -0.866, -1.774, 1.248, 0.939, -0.033, 1.523, 1.168, -0.744, 0.209, -0.168, -0.316, 0.207, -0.432, 0.047, -0.646, -0.664], [-0.185, -0.613, -1.695, 1.602, -0.32, -0.277, 0.967, 0.728, -0.965, -0.234, 1.069, -0.63, -1.631, 0.711, 0.426, 1.298, -0.191, -0.467, -0.771, 0.971, -0.118, -1.577, -2.064, -0.055, -0.59, 0.642, -0.997, 1.251, 0.538, 1.367, 0.106, 1.704]]]) # fmt: skip
@require_tf
def test_small_model_tf(self):
......@@ -43,7 +43,7 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
outputs = feature_extractor("This is a test")
self.assertEqual(
nested_simplify(outputs),
[[[0.216, 0.686, 1.908, 0.125, 1.548, -0.869, 0.78, -0.527, -1.555, -1.012, -0.08, -2.646, -1.333, 1.2, -0.529, 1.109, -0.694, 0.209, -0.986, 0.475, -0.44, -0.641, -0.409, 1.204, -1.311, 0.316, -0.08, 0.454, 0.106, 0.923, 0.745, 1.106], [1.488, -0.283, 0.691, -2.345, -0.144, -1.454, -0.535, 0.976, -1.304, 0.134, -1.707, -0.18, 0.33, 0.982, -1.026, 0.076, 1.223, 0.819, 1.437, 0.549, 0.257, 0.307, -0.304, 0.154, -0.075, -0.583, 0.157, 0.11, 0.921, -2.434, 0.739, 1.024], [-0.328, 0.284, -0.666, 1.846, 0.158, -1.723, -0.865, -0.143, 0.09, -0.517, 0.96, -0.847, -1.069, 0.099, -0.796, 0.384, 1.594, 0.764, -1.596, 0.055, -0.484, 0.208, -0.529, 0.849, 0.051, 2.725, 2.043, -0.864, -0.497, -0.866, -0.209, -0.113], [-0.628, 0.513, -0.434, -0.906, 1.02, -1.155, 1.308, -0.144, 0.861, 0.825, 2.051, 1.127, 1.513, 0.367, 0.575, 0.72, 0.471, 0.36, -0.861, -1.835, -0.026, -0.646, 1.192, -2.123, -0.759, 0.634, -0.296, -0.161, -0.633, -0.698, -1.741, -0.492], [-0.444, 0.616, -0.252, 0.594, -0.219, 0.417, -1.118, -0.088, 1.127, -1.674, 0.762, -0.156, 1.655, 0.965, -1.03, -0.853, -0.037, 1.38, -0.726, -0.469, 0.635, 0.898, -1.506, -0.519, -0.669, 0.406, 1.767, 2.215, -1.425, -1.238, -0.878, -0.137], [1.36, -0.347, -0.009, -0.146, -1.669, 0.735, -2.14, -0.82, 0.739, 0.962, 0.393, -0.371, 0.055, -1.619, -0.605, 1.487, 1.335, 0.697, 0.867, -1.043, 0.833, -1.281, -1.389, 0.815, 0.934, 0.009, 1.075, -0.37, 0.804, -0.635, -1.327, 0.671], [-0.529, 0.78, -1.844, 1.068, 0.689, 1.022, 0.282, -1.327, 0.28, 0.119, 2.389, 0.272, 0.947, 0.404, -0.84, -0.364, -1.209, 0.759, -0.953, 0.369, -0.866, -0.406, -1.054, -0.758, -0.877, 1.837, 1.59, 0.338, -1.003, -0.211, -1.322, 0.42], [0.332, 1.92, 1.158, -0.218, -0.812, -0.676, 1.813, 0.185, -0.191, -0.654, 0.93, 0.845, -0.365, -1.043, -1.189, 0.503, 1.384, 0.161, -0.259, -1.295, 0.694, -1.925, 2.034, 0.57, 0.541, -0.76, 0.199, -1.324, -1.216, -0.444, -0.88, -0.016], [-0.743, -0.093, -0.205, 0.105, -1.328, -0.001, 0.461, 0.787, 0.118, -0.707, 0.101, -0.134, -1.015, 0.494, -1.198, 0.477, 1.271, 0.056, 2.37, 1.827, 1.16, 0.485, -1.197, -0.191, -0.113, -0.012, -1.599, -0.125, 1.035, -2.789, 0.411, 0.293], [-0.612, 1.765, 0.316, -0.193, -0.349, -0.249, -0.168, 0.96, 0.037, 1.451, -0.089, 0.811, 0.166, 0.87, 0.079, 0.885, 1.08, -0.043, 0.258, 0.577, -2.287, -1.271, -2.109, 1.513, -0.846, -1.92, -0.116, 1.476, -0.279, -0.39, -0.815, -0.509], [-0.591, 0.048, -0.422, 1.068, -0.005, 1.891, 0.094, -1.655, -0.857, -0.981, 3.114, 0.789, 0.363, -1.244, -0.046, -0.964, 0.169, 0.806, -0.213, 0.566, 0.25, -0.205, -1.564, 0.263, 0.079, 0.565, 0.597, 1.045, -1.327, -0.293, 0.103, -1.443], [2.397, 1.062, -0.573, 1.594, 0.521, 0.241, 0.896, -1.182, -0.563, -0.935, -0.992, -0.655, 1.825, -1.338, -0.335, -0.868, 1.297, -0.299, 0.411, -0.638, -0.404, -1.499, 0.753, 1.012, -0.144, 0.704, -1.337, 0.454, -1.034, 0.564, -0.943, 0.008], [0.039, 1.009, 1.153, -0.478, -0.411, 0.065, -1.489, -0.562, 0.796, 0.388, -0.451, 2.051, 0.514, -0.756, 0.661, -1.563, 1.441, 0.116, 1.161, -0.769, 0.152, 0.301, -1.23, 0.185, -0.848, -1.246, 0.008, -1.513, -0.283, -0.811, 2.601, -0.233]]]) # fmt: skip
[[[2.287, 1.234, 0.042, 1.53, 1.306, 0.879, -0.526, -1.71, -1.276, 0.756, -0.775, -1.048, -0.25, -0.595, -0.137, -0.598, 2.022, -0.812, 0.284, -0.488, -0.391, -0.403, -0.525, -0.061, -0.228, 1.086, 0.378, -0.14, 0.599, -0.087, -2.259, -0.098], [1.676, 0.232, -1.508, -0.145, 1.798, -1.388, 1.331, -0.37, -0.939, 0.043, 0.06, -0.414, -1.408, 0.24, 0.622, -0.55, -0.569, 1.873, -0.706, 1.924, -0.254, 1.927, -0.423, 0.152, -0.952, 0.509, -0.496, -0.968, 0.093, -1.049, -0.65, 0.312], [0.207, -0.775, -1.822, 0.321, -0.71, -0.201, 0.3, 1.146, -0.233, -0.753, -0.305, 1.309, -1.47, -0.21, 1.802, -1.555, -1.175, 1.323, -0.303, 0.722, -0.076, 0.103, -1.406, 1.931, 0.091, 0.237, 1.172, 1.607, 0.253, -0.9, -1.068, 0.438], [0.615, 1.077, 0.171, -0.175, 1.3, 0.901, -0.653, -0.138, 0.341, -0.654, -0.184, -0.441, -0.424, 0.356, -0.075, 0.26, -1.023, 0.814, 0.524, -0.904, -0.204, -0.623, 1.234, -1.03, 2.594, 0.56, 1.831, -0.199, -1.508, -0.492, -1.687, -2.165], [0.129, 0.008, -1.279, -0.412, -0.004, 1.663, 0.196, 0.104, 0.123, 0.119, 0.635, 1.757, 2.334, -0.799, -1.626, -1.26, 0.595, -0.316, -1.399, 0.232, 0.264, 1.386, -1.171, -0.256, -0.256, -1.944, 1.168, -0.368, -0.714, -0.51, 0.454, 1.148], [-0.32, 0.29, -1.309, -0.177, 0.453, 0.636, -0.024, 0.509, 0.931, -1.754, -1.575, 0.786, 0.046, -1.165, -1.416, 1.373, 1.293, -0.285, -1.541, -1.186, -0.106, -0.994, 2.001, 0.972, -0.02, 1.654, -0.236, 0.643, 1.02, 0.572, -0.914, -0.154], [0.7, -0.937, 0.441, 0.25, 0.78, -0.022, 0.282, -0.095, 1.558, -0.336, 1.706, 0.884, 1.28, 0.198, -0.796, 1.218, -1.769, 1.197, -0.342, -0.177, -0.645, 1.364, 0.008, -0.597, -0.484, -2.772, -0.696, -0.632, -0.34, -1.527, -0.562, 0.862], [2.504, 0.831, -1.271, -0.033, 0.298, -0.735, 1.339, 1.74, 0.233, -1.424, -0.819, -0.761, 0.291, 0.853, -0.092, -0.885, 0.164, 1.025, 0.907, 0.749, -1.515, -0.545, -1.365, 0.271, 0.034, -2.005, 0.031, 0.244, 0.621, 0.176, 0.336, -1.196], [-0.711, 0.591, -1.001, -0.946, 0.784, -1.66, 1.545, 0.799, -0.857, 1.148, 0.213, -0.285, 0.464, -0.139, 0.79, -1.663, -1.121, 0.575, -0.178, -0.508, 1.565, -0.242, -0.346, 1.024, -1.135, -0.158, -2.101, 0.275, 2.009, -0.425, 0.716, 0.981], [0.912, -1.186, -0.846, -0.421, -1.315, -0.827, 0.309, 0.533, 1.029, -2.343, 1.513, -1.238, 1.487, -0.849, 0.896, -0.927, -0.459, 0.159, 0.177, 0.873, 0.935, 1.433, -0.485, 0.737, 1.327, -0.338, 1.608, -0.47, -0.445, -1.118, -0.213, -0.446], [-0.434, -1.362, -1.098, -1.068, 1.507, 0.003, 0.413, -0.395, 0.897, -0.237, 1.405, -0.344, 1.693, 0.677, 0.097, -0.257, -0.602, 1.026, -1.229, 0.855, -0.713, 1.014, 0.443, 0.238, 0.425, -2.184, 1.933, -1.157, -1.132, -0.597, -0.785, 0.967], [0.58, -0.971, 0.789, -0.468, -0.576, 1.779, 1.747, 1.715, -1.939, 0.125, 0.656, -0.042, -1.024, -1.767, 0.107, -0.408, -0.866, -1.774, 1.248, 0.939, -0.033, 1.523, 1.168, -0.744, 0.209, -0.168, -0.316, 0.207, -0.432, 0.047, -0.646, -0.664], [-0.185, -0.613, -1.695, 1.602, -0.32, -0.277, 0.967, 0.728, -0.965, -0.234, 1.069, -0.63, -1.631, 0.711, 0.426, 1.298, -0.191, -0.467, -0.771, 0.971, -0.118, -1.577, -2.064, -0.055, -0.59, 0.642, -0.997, 1.251, 0.538, 1.367, 0.106, 1.704]]]) # fmt: skip
def get_shape(self, input_, shape=None):
if shape is None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment