Unverified Commit 1e9f9c22 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[hotfix]change to fit latest p2p (#1100)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [hotfix]change to fit latest p2p

* polish

* polish
parent 72bd7c69
...@@ -86,7 +86,14 @@ class PipelineSchedule(BaseSchedule): ...@@ -86,7 +86,14 @@ class PipelineSchedule(BaseSchedule):
self.num_microbatches = num_microbatches self.num_microbatches = num_microbatches
self.dtype = torch.float self.dtype = torch.float
self.tensor_shape = tensor_shape assert not isinstance(tensor_shape,
int), "tensor_shape type should be one of Union[torch.Size, List[int], Tuple[int]]."
if tensor_shape is None:
self.tensor_shape = tensor_shape
elif isinstance(tensor_shape, torch.Size):
self.tensor_shape = tensor_shape
else:
self.tensor_shape = torch.Size(tensor_shape)
self.scatter_gather_tensors = False self.scatter_gather_tensors = False
if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1: if gpc.is_initialized(ParallelMode.PARALLEL_1D) and gpc.get_world_size(ParallelMode.PARALLEL_1D) > 1:
self.scatter_gather_tensors = scatter_gather_tensors self.scatter_gather_tensors = scatter_gather_tensors
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment