Commit e0014ba7 authored by wooway777's avatar wooway777
Browse files

issue/824 - reduce test sizes and improve rearrange_tensor

parent 6a5b1119
...@@ -391,7 +391,7 @@ def rearrange_tensor(tensor, new_strides): ...@@ -391,7 +391,7 @@ def rearrange_tensor(tensor, new_strides):
new_positions += offset new_positions += offset
# Copy the original data to the new tensor # Copy the original data to the new tensor
new_tensor.view(-1).index_add_(0, new_positions, tensor.view(-1)) new_tensor.reshape(-1).index_add_(0, new_positions, tensor.reshape(-1))
new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides)) new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
return new_tensor return new_tensor
......
...@@ -34,8 +34,8 @@ _TEST_CASES_DATA = [ ...@@ -34,8 +34,8 @@ _TEST_CASES_DATA = [
20, 20,
16, 16,
128, 128,
(20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1), (655360, 8192, 256, 1),
(20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1), (655360, 8192, 256, 1),
RopeAlgo.GPT_NEOX, RopeAlgo.GPT_NEOX,
), ),
( (
...@@ -43,26 +43,26 @@ _TEST_CASES_DATA = [ ...@@ -43,26 +43,26 @@ _TEST_CASES_DATA = [
20, 20,
16, 16,
128, 128,
(20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1), (655360, 8192, 256, 1),
(20 * 16 * 128 * 16, 16 * 128 * 4, 128 * 2, 1), (655360, 8192, 256, 1),
RopeAlgo.GPT_J, RopeAlgo.GPT_J,
), ),
( (
4, 4,
50, 50,
32, 32,
256, 8,
(50 * 32 * 256 * 16, 32 * 256 * 4, 256 * 2, 1), (204800, 1024, 16, 1),
(50 * 32 * 256 * 36, 32 * 256 * 6, 256 * 3, 1), (460800, 1536, 24, 1),
RopeAlgo.GPT_NEOX, RopeAlgo.GPT_NEOX,
), ),
( (
4, 4,
50, 50,
32, 32,
256, 8,
(50 * 32 * 256 * 16, 32 * 256 * 4, 256 * 2, 1), (204800, 1024, 16, 1),
(50 * 32 * 256 * 36, 32 * 256 * 6, 256 * 3, 1), (460800, 1536, 24, 1),
RopeAlgo.GPT_J, RopeAlgo.GPT_J,
), ),
( (
...@@ -70,8 +70,8 @@ _TEST_CASES_DATA = [ ...@@ -70,8 +70,8 @@ _TEST_CASES_DATA = [
64, 64,
8, 8,
128, 128,
(64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1), (1048576, 4096, 256, 1),
(64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1), (1048576, 4096, 256, 1),
RopeAlgo.GPT_NEOX, RopeAlgo.GPT_NEOX,
), ),
( (
...@@ -79,26 +79,62 @@ _TEST_CASES_DATA = [ ...@@ -79,26 +79,62 @@ _TEST_CASES_DATA = [
64, 64,
8, 8,
128, 128,
(64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1), (1048576, 4096, 256, 1),
(64 * 8 * 128 * 16, 8 * 128 * 4, 128 * 2, 1), (1048576, 4096, 256, 1),
RopeAlgo.GPT_J, RopeAlgo.GPT_J,
), ),
( (
64, 64,
128, 17,
32, 32,
64, 64,
(128 * 32 * 64 * 16, 32 * 64 * 4, 64 * 2, 1), (557056, 8192, 128, 1),
(128 * 32 * 64 * 36, 32 * 64 * 6, 64 * 3, 1), (1253376, 12288, 192, 1),
RopeAlgo.GPT_NEOX, RopeAlgo.GPT_NEOX,
), ),
( (
64, 64,
128, 17,
32,
64,
(557056, 8192, 128, 1),
(1253376, 12288, 192, 1),
RopeAlgo.GPT_J,
),
(
8,
20,
4,
64,
(1048576, 64, 262144, 1),
(1048576, 64, 262144, 1),
RopeAlgo.GPT_NEOX,
),
(
8,
20,
4,
64,
(1048576, 64, 262144, 1),
(1048576, 64, 262144, 1),
RopeAlgo.GPT_J,
),
(
8,
20,
32,
64,
(40960, 64, 1280, 1),
(40960, 64, 1280, 1),
RopeAlgo.GPT_NEOX,
),
(
8,
20,
32, 32,
64, 64,
(128 * 32 * 64 * 16, 32 * 64 * 4, 64 * 2, 1), (40960, 64, 1280, 1),
(128 * 32 * 64 * 36, 32 * 64 * 6, 64 * 3, 1), (40960, 64, 1280, 1),
RopeAlgo.GPT_J, RopeAlgo.GPT_J,
), ),
] ]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment