Unverified Commit 1b416864 authored by HELSON's avatar HELSON Committed by GitHub
Browse files

[hotfix] fix unit test test_module_spec (#1321)

parent 9e4c6449
...@@ -88,7 +88,7 @@ def init_colo_module(module: torch.nn.Module, ...@@ -88,7 +88,7 @@ def init_colo_module(module: torch.nn.Module,
compute_pattern = compute_spec.compute_pattern compute_pattern = compute_spec.compute_pattern
if is_colo_module(module): if is_colo_module(module):
# for each param # for each param
# set DistSpec and ComputeSpec # set its process_group, dist_spec and compute_spec
colo_module = get_colo_module(module) colo_module = get_colo_module(module)
colo_module.register(compute_pattern, pg) colo_module.register(compute_pattern, pg)
if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode): if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
...@@ -101,6 +101,7 @@ def init_colo_module(module: torch.nn.Module, ...@@ -101,6 +101,7 @@ def init_colo_module(module: torch.nn.Module,
continue continue
param = module.get_parameter(param_name) param = module.get_parameter(param_name)
if isinstance(param, ColoParameter): if isinstance(param, ColoParameter):
param.set_process_group(pg)
param.set_dist_spec(dist_spec) param.set_dist_spec(dist_spec)
param.compute_spec = compute_spec param.compute_spec = compute_spec
for mod in param.shared_param_modules: for mod in param.shared_param_modules:
......
...@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]: ...@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
Tensor._base.__get__, Tensor._base.__get__,
Tensor.grad.__get__, Tensor.grad.__get__,
Tensor._grad.__get__, Tensor._grad.__get__,
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
} }
...@@ -121,11 +121,13 @@ class ColoTensor(torch.Tensor): ...@@ -121,11 +121,13 @@ class ColoTensor(torch.Tensor):
RuntimeError: RuntimeError:
""" """
assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid" assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
if self.process_group.tp_world_size() != 1: # if the new pg is the same as the old pg, just returns
raise RuntimeError("can not set_process_group on a ColoTensor whose process_group has tp world group") if self.process_group == pg:
return
if self.dist_spec.placement.value != 'r': assert self.process_group.tp_world_size() == 1, \
raise RuntimeError("can not set_process_group on a ColoTensor whose dist spec is not REPLICATE") "Can not set_process_group on a ColoTensor whose process_group has tp world group"
assert self.dist_spec.placement.value == 'r', \
"Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
self.process_group = pg self.process_group = pg
...@@ -290,17 +292,17 @@ class ColoTensor(torch.Tensor): ...@@ -290,17 +292,17 @@ class ColoTensor(torch.Tensor):
def is_replicate(self): def is_replicate(self):
return self.dist_spec.placement == DistPlacementPattern.REPLICATE \ return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
or (len(self.dist_spec.num_partitions) == 1 or (len(self.dist_spec.num_partitions) == 1
and self.dist_spec.num_partitions[0] == 1) \ and self.dist_spec.num_partitions[0] == 1) \
or (self.process_group.tp_world_size() == 1) or (self.process_group.tp_world_size() == 1)
def is_shard_1dcol(self): def is_shard_1dcol(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \ return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1 and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
def is_shard_1drow(self): def is_shard_1drow(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD \ return self.dist_spec.placement == DistPlacementPattern.SHARD \
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0 and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
def is_sharded(self): def is_sharded(self):
return self.dist_spec.placement == DistPlacementPattern.SHARD return self.dist_spec.placement == DistPlacementPattern.SHARD
from copy import copy from copy import deepcopy
import pytest import pytest
from functools import partial from functools import partial
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec
from colossalai.nn.parallel.layers import init_colo_module, check_colo_module from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
from _utils import tensor_equal, tensor_shard_equal, set_seed from _utils import tensor_equal, tensor_shard_equal, set_seed
...@@ -112,21 +112,25 @@ def run_linear_with_spec(mode): ...@@ -112,21 +112,25 @@ def run_linear_with_spec(mode):
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
model = torch.nn.Linear(4, 8) model = torch.nn.Linear(4, 8)
model_handy = copy(model) model_handy = deepcopy(model)
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size) pg = ProcessGroup(tp_degree=world_size)
compute_spec = ComputeSpec(ComputePattern.TP1D) compute_spec = ComputeSpec(ComputePattern.TP1D)
init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode) init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
x = torch.rand(2, 4).cuda() x = torch.rand(2, 4).cuda()
colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg))
out = model(x) out = model(x)
colo_out = model_handy(x) colo_out = model_handy(colo_x)
assert tensor_equal(out, colo_out) assert tensor_equal(out, colo_out)
grad = torch.rand_like(out) grad = torch.rand_like(out)
out.backward(grad) out.backward(grad)
colo_out.backward(grad) colo_out.backward(grad)
assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad, pg.tp_local_rank(), pg.tp_world_size()) assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
def run_check_shared_param(): def run_check_shared_param():
...@@ -196,7 +200,7 @@ def run_dist_check(rank, world_size, port): ...@@ -196,7 +200,7 @@ def run_dist_check(rank, world_size, port):
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("under development lazy init ColoParameter in Context") @pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_module_linear_1d(world_size): def test_module_linear_1d(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port()) run_func = partial(run_dist, world_size=world_size, port=free_port())
...@@ -205,7 +209,7 @@ def test_module_linear_1d(world_size): ...@@ -205,7 +209,7 @@ def test_module_linear_1d(world_size):
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.skip("under development lazy init ColoParameter in Context") @pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_module_model(world_size): def test_module_model(world_size):
run_func = partial(run_dist_model, world_size=world_size, port=free_port()) run_func = partial(run_dist_model, world_size=world_size, port=free_port())
...@@ -214,7 +218,7 @@ def test_module_model(world_size): ...@@ -214,7 +218,7 @@ def test_module_model(world_size):
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2]) @pytest.mark.parametrize('world_size', [1, 2])
@pytest.mark.skip("under development lazy init ColoParameter in Context") @pytest.mark.skip("for higher testing speed")
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_module_check(world_size): def test_module_check(world_size):
run_func = partial(run_dist_check, world_size=world_size, port=free_port()) run_func = partial(run_dist_check, world_size=world_size, port=free_port())
...@@ -222,4 +226,4 @@ def test_module_check(world_size): ...@@ -222,4 +226,4 @@ def test_module_check(world_size):
if __name__ == '__main__': if __name__ == '__main__':
test_module_check(2) test_module_linear_1d(4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment