Unverified Commit e3fde4ee authored by ver217's avatar ver217 Committed by GitHub
Browse files

fix import error in sharded model v2 (#1053)

parent e1922ea4
...@@ -2,7 +2,6 @@ import functools ...@@ -2,7 +2,6 @@ import functools
from collections import OrderedDict from collections import OrderedDict
from typing import Any, Optional, Iterator, Tuple from typing import Any, Optional, Iterator, Tuple
from copy import deepcopy from copy import deepcopy
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
import itertools import itertools
import torch import torch
import torch.distributed as dist import torch.distributed as dist
...@@ -28,6 +27,11 @@ from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFacto ...@@ -28,6 +27,11 @@ from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFacto
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage, from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor) get_gradient_predivide_factor)
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
class ShardedModelV2(nn.Module): class ShardedModelV2(nn.Module):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment