Unverified Commit e4293e50 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[hotfix] update test for latest version (#2060)

parent 19438ea0
...@@ -126,12 +126,13 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh): ...@@ -126,12 +126,13 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
if method in (torch.Tensor.view, torch.Tensor.reshape): if method in (torch.Tensor.view, torch.Tensor.reshape):
for arg in node.args: for arg in node.args:
if isinstance(arg, Node): if isinstance(arg, Node):
if isinstance(arg._meta_data, int): if isinstance(arg._meta_data, (int, tuple, list)):
new_args.append(arg._meta_data) new_args.append(arg._meta_data)
else: else:
new_args.append(arg) new_args.append(arg)
else: else:
assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.' assert isinstance(
arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
new_args.append(arg) new_args.append(arg)
for dim, shard_dims in output_dim_partition_dict.items(): for dim, shard_dims in output_dim_partition_dict.items():
......
...@@ -102,12 +102,12 @@ def check_linear_module_handler(rank, bias, world_size, port): ...@@ -102,12 +102,12 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert len(strategy_name_list) > 8 assert len(strategy_name_list) > 8
# SS = SR x RS # SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list assert 'S1S0 = S1R x RS0_0' in strategy_name_list
# SR = SS x SR # SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list assert 'S1R = S1S0 x S0R_0' in strategy_name_list
# RS = RS x SS # RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list assert 'RS0 = RS1 x S1S0' in strategy_name_list
......
...@@ -95,12 +95,12 @@ def check_linear_module_handler(rank, bias, world_size, port): ...@@ -95,12 +95,12 @@ def check_linear_module_handler(rank, bias, world_size, port):
assert len(strategy_name_list) > 8 assert len(strategy_name_list) > 8
# SS = SR x RS # SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list assert 'S1S0 = S1R x RS0_0' in strategy_name_list
# SR = SS x SR # SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list assert 'S1R = S1S0 x S0R_0' in strategy_name_list
# RS = RS x SS # RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list assert 'RS0 = RS1 x S1S0' in strategy_name_list
...@@ -212,12 +212,12 @@ def check_linear_function_handler(rank, bias, world_size, port): ...@@ -212,12 +212,12 @@ def check_linear_function_handler(rank, bias, world_size, port):
assert len(strategy_name_list) > 8 assert len(strategy_name_list) > 8
# SS = SR x RS # SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list assert 'S0S1 = S0R x RS1_0' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list assert 'S1S0 = S1R x RS0_0' in strategy_name_list
# SR = SS x SR # SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list assert 'S0R = S0S1 x S1R_0' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list assert 'S1R = S1S0 x S0R_0' in strategy_name_list
# RS = RS x SS # RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list assert 'RS0 = RS1 x S1S0' in strategy_name_list
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment