Unverified Commit 28398f1c authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

add overlap option (#2613)

parent cb3d1bef
......@@ -352,7 +352,7 @@ def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
return gm
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
"""
Apply the sharding action to the module parameters and buffers following the
instructions of solver solution.
......@@ -387,15 +387,18 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# register hook to the parameters
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec, stream):
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
if overlap:
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
else:
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(param, comm_spec_to_use, reduction_stream)
wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
sharded_buffer_dict = {}
# apply the sharding spec of buffers
......@@ -441,15 +444,18 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
# register hook to the parameters
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
def wrapper(param, comm_spec, stream):
def wrapper(param, comm_spec, stream, overlap):
def hook_fn(grad):
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
if overlap:
with torch.cuda.stream(stream):
_all_reduce(grad, comm_spec, async_op=True)
else:
_all_reduce(grad, comm_spec, async_op=False)
param.register_hook(hook_fn)
wrapper(target, comm_spec_to_use, reduction_stream)
wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap)
return gm
......@@ -463,13 +469,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
def runtime_preparation_pass(gm: torch.fx.GraphModule,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor = None):
strategies_constructor: StrategiesConstructor = None,
overlap=False):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
gm, solution, strategies_constructor)
gm = _size_value_converting(gm, device_mesh)
gm = _node_args_converting(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
# gm = implicit_comm_action_apply(gm)
gm = _module_params_sharding(gm, device_mesh)
gm = _module_params_sharding(gm, device_mesh, overlap=overlap)
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
......@@ -98,16 +98,22 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
return solution
def transform_to_sharded_model(gm: ColoGraphModule, solution: List[int], device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor):
def transform_to_sharded_model(gm: ColoGraphModule,
solution: List[int],
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
overlap: bool = False):
'''
This method is used to transform the original graph to the sharded graph.
The model parameters will be sharded according to the solution and the grad hooks
will be added to the sharded graph using the runtime_preparation_pass.
The communication node will be added into the graph using the runtime_apply_pass.
'''
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
gm, solution, device_mesh, strategies_constructor)
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm,
solution,
device_mesh,
strategies_constructor,
overlap=overlap)
gm = runtime_apply_pass(gm)
gm.recompile()
sharding_spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
......@@ -176,6 +182,7 @@ def initialize_model(model: nn.Module,
meta_args: Dict[str, torch.Tensor],
device_mesh: DeviceMesh,
memory_budget: float = -1.0,
overlap: bool = False,
save_solver_solution: bool = False,
load_solver_solution: bool = False,
solution_path: str = None,
......@@ -189,6 +196,8 @@ def initialize_model(model: nn.Module,
device_mesh: the device mesh to execute the model.
memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
the memory budget will be infinity.
overlap(optional): the overlap is used to specify whether to overlap gradient communication and
backward computing.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
......@@ -211,7 +220,7 @@ def initialize_model(model: nn.Module,
if save_solver_solution:
torch.save(solution, solution_path)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor, overlap)
model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)
if return_solution:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment