Unverified Commit b3a03e4b authored by Ziyue Jiang's avatar Ziyue Jiang Committed by GitHub
Browse files

[Tensor] fix equal assert (#1091)

* fix equal assert

* polish
parent 50ec3a7e
...@@ -15,7 +15,7 @@ from colossalai.context import ParallelMode ...@@ -15,7 +15,7 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.nn.optimizer import ColoOptimizer from colossalai.nn.optimizer import ColoOptimizer
from functools import partial from functools import partial
from _utils import set_seed from _utils import tensor_equal, tensor_shard_equal, set_seed
def init_1d_row_linear(weight): def init_1d_row_linear(weight):
...@@ -144,20 +144,8 @@ def run_1d_hybrid_tp(model_name): ...@@ -144,20 +144,8 @@ def run_1d_hybrid_tp(model_name):
with torch.no_grad(): with torch.no_grad():
# check param # check param
for p1, p2 in zip(model.parameters(), model_torch.parameters()): for p, torch_p in zip(model.parameters(), model_torch.parameters()):
if p1.size() == p2.size(): assert tensor_shard_equal(torch_p, p)
assert torch.allclose(p1, p2)
else:
# TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
if p1.size(-1) < p2.size(-1): # col
world_size = p2.size(-1) // p1.size(-1)
split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
elif p1.size(0) < p2.size(0): # row
world_size = p2.size(0) // p1.size(0)
split_p2 = torch.chunk(p2, world_size, dim=0)[0]
assert torch.allclose(p1, split_p2)
if i > 5: if i > 5:
break break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment