Unverified Commit 258b4331 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[hotfix] layout converting issue (#3188)

parent 80aed29c
...@@ -10,7 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ...@@ -10,7 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost,
from colossalai.context.singleton_meta import SingletonMeta from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import * from colossalai.tensor.d_tensor.comm_spec import *
from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.sharding_spec import ShardingSpecException from colossalai.tensor.d_tensor.misc import LayoutException
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from .sharding_spec import ShardingSpec from .sharding_spec import ShardingSpec
...@@ -145,7 +145,7 @@ class LayoutConverter(metaclass=SingletonMeta): ...@@ -145,7 +145,7 @@ class LayoutConverter(metaclass=SingletonMeta):
entire_shape=source_layout.entire_shape) entire_shape=source_layout.entire_shape)
valid_spec_dict[new_layout] = comm_spec valid_spec_dict[new_layout] = comm_spec
except ShardingSpecException: except LayoutException:
pass pass
return valid_spec_dict return valid_spec_dict
...@@ -255,7 +255,7 @@ class LayoutConverter(metaclass=SingletonMeta): ...@@ -255,7 +255,7 @@ class LayoutConverter(metaclass=SingletonMeta):
device_type=source_layout.device_type, device_type=source_layout.device_type,
entire_shape=source_layout.entire_shape) entire_shape=source_layout.entire_shape)
valid_spec_dict[new_layout] = comm_spec valid_spec_dict[new_layout] = comm_spec
except ShardingSpecException: except LayoutException:
pass pass
return valid_spec_dict return valid_spec_dict
...@@ -343,7 +343,7 @@ class LayoutConverter(metaclass=SingletonMeta): ...@@ -343,7 +343,7 @@ class LayoutConverter(metaclass=SingletonMeta):
device_type=source_layout.device_type, device_type=source_layout.device_type,
entire_shape=source_layout.entire_shape) entire_shape=source_layout.entire_shape)
valid_spec_dict[new_layout] = comm_spec valid_spec_dict[new_layout] = comm_spec
except ShardingSpecException: except LayoutException:
pass pass
return valid_spec_dict return valid_spec_dict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment