Commit 27de2523 (unverified), authored by YuliangLiu0306, committed by GitHub
Browse files

[autoparallel] fix conv handler numerical test (#1771)

parent 1e88811c
...@@ -141,14 +141,31 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -141,14 +141,31 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0, logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK) comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias and self.is_param("bias"): else:
bias_comm_action = self.get_communication_action( other_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"], sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0, logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK) comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping["bias"] = bias_comm_action communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name, return self.get_sharding_strategy(name=name,
...@@ -180,14 +197,31 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -180,14 +197,31 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0, logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK) comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias and self.is_param("bias"): else:
bias_comm_action = self.get_communication_action( other_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"], sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0, logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK) comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping["bias"] = bias_comm_action communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name, return self.get_sharding_strategy(name=name,
...@@ -230,14 +264,29 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -230,14 +264,29 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0, logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK) comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias and self.is_param("bias"): else:
bias_comm_action = self.get_communication_action( other_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"], sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0, logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK) comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping["bias"] = bias_comm_action communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name, return self.get_sharding_strategy(name=name,
...@@ -277,7 +326,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -277,7 +326,7 @@ class ConvStrategyGenerator(StrategyGenerator):
input_comm_action = self.get_communication_action( input_comm_action = self.get_communication_action(
sharding_spec_mapping["input"], sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0, logical_process_axis=mesh_dim_1,
comm_type=CommType.BEFORE, comm_type=CommType.BEFORE,
arg_index=0) arg_index=0)
...@@ -399,14 +448,30 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -399,14 +448,30 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1], logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK) comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action else:
other_comm_action = self.get_communication_action(
if self.has_bias and self.is_param("bias"): sharding_spec_mapping["other"],
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1], logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK) comm_type=CommType.BEFORE,
arg_index=1)
communication_action_mapping["other"] = other_comm_action
if self.has_bias:
if self.is_param("bias"):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.HOOK)
else:
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.BEFORE,
key_for_kwarg='bias')
communication_action_mapping["bias"] = bias_comm_action communication_action_mapping["bias"] = bias_comm_action
return self.get_sharding_strategy(name=name, return self.get_sharding_strategy(name=name,
......
...@@ -290,7 +290,6 @@ def check_conv_function_handler(rank, bias, world_size, port): ...@@ -290,7 +290,6 @@ def check_conv_function_handler(rank, bias, world_size, port):
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1] assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
@pytest.mark.skip("some cases need to be fixed")
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist @pytest.mark.dist
# We temporarily ban the bias option before doing bias add # We temporarily ban the bias option before doing bias add
...@@ -303,7 +302,6 @@ def test_conv_module_handler(bias=False): ...@@ -303,7 +302,6 @@ def test_conv_module_handler(bias=False):
mp.spawn(run_func, nprocs=world_size) mp.spawn(run_func, nprocs=world_size)
@pytest.mark.skip("some cases need to be fixed")
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist @pytest.mark.dist
# We temporarily ban the bias option before doing bias add # We temporarily ban the bias option before doing bias add
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment