Unverified Commit e6ec99d3 authored by Frank Lee, committed by GitHub

[utils] fixed lazy init context (#1867)

parent 50c4cb01
colossalai/utils/model/lazy_init_context.py

 #!/usr/bin/env python
 # coding: utf-8
+import inspect
+import types
+from typing import Callable, List
 import torch
 import torch.nn as nn
-from colossalai.tensor import ColoParameter, ColoTensor
-import types
-import inspect
-from typing import List, Callable
+from colossalai.tensor import ColoParameter, ColoTensor
 from colossalai.utils.model.utils import substitute_init_recursively
@@ -35,14 +36,15 @@ class LazyInitContext():
         assert not model.weight.is_meta and torch.all(model.weight == 0)

     Args:
-        to_meta (bool): optional, whether to initialize the model with meta tensors, default is False.
+        to_meta (bool): optional, whether to initialize the model with meta tensors, default is True. This
+            argument exists for now because some corner cases such as self.weight = torch.zeros(...) cannot be captured yet.
         extra_torch_tensor_func (List[str]): extra torch tensor functions related
             to value setting, such as `zero_` and `triu_`. `zero_` is pre-added by default.
     """

     tensor_set_value_func = ['zero_', 'fill_']

-    def __init__(self, to_meta: bool = False, extra_torch_tensor_func: List[str] = None):
+    def __init__(self, to_meta: bool = True, extra_torch_tensor_func: List[str] = None):
         # TODO: hijack the torch constructor functions as well
         self._to_meta = to_meta
         self._intercepted_nn_init_func_cache = {}
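The updated docstring above also spells out the intended workflow: build the model inside LazyInitContext so that parameters are created as meta tensors (now the default behaviour, since to_meta=True), then materialize them later with ctx.lazy_init_parameters(...). Below is a minimal usage sketch mirroring the test further down in this commit; it assumes a Colossal-AI installation at this commit, and passing a plain nn.Module (rather than the traced GraphModule used in the test) to lazy_init_parameters is an assumption.

import torch
import torch.nn as nn

from colossalai.utils.model.lazy_init_context import LazyInitContext

# modules built under the context get meta-tensor parameters, and their
# init functions are recorded instead of executed
with LazyInitContext() as ctx:
    model = nn.Linear(16, 16)

assert all(p.is_meta for p in model.parameters())

# materialize the recorded parameters (in the test this happens after an FX sharding pass)
ctx.lazy_init_parameters(model)
assert not any(p.is_meta for p in model.parameters())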
@@ -212,18 +214,19 @@ class LazyInitContext():
                 materialized_tensor = torch.empty_like(tensor, device=device)
                 # if this tensor is a meta tensor, it must have an init function
                 assert tensor in self._intercepted_nn_init_func_cache
-                tensor = materialized_tensor
+            else:
+                materialized_tensor = tensor

             # apply init function
             if tensor in self._intercepted_nn_init_func_cache:
                 init_func, args, kwargs = self._intercepted_nn_init_func_cache[tensor][-1]
-                init_func(tensor, *args, **kwargs)
+                init_func(materialized_tensor, *args, **kwargs)

             # convert it to ColoTensor or ColoParameter
             if is_param:
-                tensor = ColoParameter.from_torch_tensor(tensor, requires_grad=tensor.requires_grad)
+                tensor = ColoParameter.from_torch_tensor(materialized_tensor, requires_grad=tensor.requires_grad)
             else:
-                tensor = ColoTensor.from_torch_tensor(tensor)
+                tensor = ColoTensor.from_torch_tensor(materialized_tensor)

             # override the original tensor
             with torch.no_grad():
...
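The hunk above is the core fix: previously the meta tensor was rebound to the freshly allocated tensor before the cache lookup, so the recorded init function could no longer be found under its original key and the new storage stayed uninitialized; now the original tensor remains the cache key while the init function and the ColoTensor/ColoParameter conversion always act on materialized_tensor (a tensor that was never meta is simply aliased to it). Stripped of the Colossal-AI specifics, this follows the usual meta-tensor materialization idiom; here is a small plain-PyTorch sketch, with kaiming_uniform_ standing in for whatever init function was cached:

import torch
import torch.nn as nn

# a meta tensor carries shape and dtype but has no storage
meta_weight = torch.empty(16, 16, device='meta')
assert meta_weight.is_meta

# materialize: allocate real storage of the same shape/dtype, then replay
# the init function on the materialized tensor, never on the meta one
materialized = torch.empty_like(meta_weight, device='cpu')
nn.init.kaiming_uniform_(materialized)    # stand-in for the cached init_func
assert not materialized.is_meta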
-import colossalai
-import torch
-import torch.nn as nn
+from functools import partial
+
 import pytest
-import torch.multiprocessing as mp
+import torch
 import torch.distributed as dist
-from colossalai.testing import rerun_if_address_is_in_use
-from functools import partial
+import torch.multiprocessing as mp
+import torch.nn as nn
+
+import colossalai
 from colossalai.fx import ColoTracer
-from colossalai.utils.model.lazy_init_context import LazyInitContext
 from colossalai.fx.passes.shard_1d_pass import transformer_mlp_pass
-from colossalai.utils import free_port
 from colossalai.tensor import ProcessGroup
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.model.lazy_init_context import LazyInitContext


 class MLP(torch.nn.Module):
@@ -35,6 +37,9 @@ def run_workflow(world_size):
     with LazyInitContext() as ctx:
         model = MLP(16)

+    for param in model.parameters():
+        assert param.is_meta
+
     # tracing
     tracer = ColoTracer()
     graph = tracer.trace(model)
@@ -46,6 +51,8 @@ def run_workflow(world_size):

     # materialization and sharding
     ctx.lazy_init_parameters(annotated_gm)
+    for param in model.parameters():
+        assert not param.is_meta

     # # check sharding
     assert list(model.linear1.weight.shape) == [16 // world_size, 16]
@@ -57,7 +64,7 @@ def run_workflow(world_size):
     data = torch.rand(4, 16)
     non_fx_out = model(data)
     fx_out = annotated_gm(data)
-    assert torch.equal(non_fx_out, fx_out)
+    assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}'


 def run_dist(rank, world_size, port):
@@ -74,4 +81,4 @@ def test_complete_workflow(world_size):


 if __name__ == '__main__':
-    test_complete_workflow(2)
+    test_complete_workflow(1)
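The hunks above elide the launcher between run_workflow and the pytest entry point. For context only, the multi-process pattern implied by the imports (partial, free_port, mp.spawn, rerun_if_address_is_in_use) commonly looks roughly like the sketch below; this is an illustration of the usual Colossal-AI test harness of that era, not the elided contents of this file, and the parametrized world sizes and the nccl backend are assumptions.

def run_dist(rank, world_size, port):
    # each spawned worker joins the process group before running the workflow
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_workflow(world_size)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_complete_workflow(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)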