Unverified Commit 4b9bba81 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[ColoTensor] rename APIs and add output_replicate to ComputeSpec (#1168)

parent f4ef2243
...@@ -36,10 +36,10 @@ def test_layernorm(): ...@@ -36,10 +36,10 @@ def test_layernorm():
def check_spec_eq(tensor, other): def check_spec_eq(tensor, other):
assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor) assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
for k in dir(tensor.spec.dist_spec): for k in dir(tensor.tensor_spec.dist_spec):
if not k.startswith('__'): if not k.startswith('__'):
assert hasattr(other.spec.dist_spec, k) assert hasattr(other.tensor_spec.dist_spec, k)
assert getattr(tensor.spec.dist_spec, k) == getattr(other.spec.dist_spec, k) assert getattr(tensor.tensor_spec.dist_spec, k) == getattr(other.tensor_spec.dist_spec, k)
def check_element_wise_ops(): def check_element_wise_ops():
......
...@@ -66,7 +66,7 @@ def _run_tensor_shard_init(world_size): ...@@ -66,7 +66,7 @@ def _run_tensor_shard_init(world_size):
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size]) shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
tensor_spec = TensorSpec(shard_spec) tensor_spec = TensorSpec(shard_spec)
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec) t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
t.set_spec(TensorSpec(dist_spec=distspec.replicate())) t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
assert t.shape == torch.Size((4 * world_size, 5)) assert t.shape == torch.Size((4 * world_size, 5))
......
...@@ -51,7 +51,7 @@ def init_1d_row_spec(model): ...@@ -51,7 +51,7 @@ def init_1d_row_spec(model):
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n: if 'weight' in n and 'ln' not in n:
p.set_spec(spec) p.set_tensor_spec(spec)
def init_1d_col_spec(model): def init_1d_col_spec(model):
...@@ -61,7 +61,7 @@ def init_1d_col_spec(model): ...@@ -61,7 +61,7 @@ def init_1d_col_spec(model):
with DistSpecManager.no_grad(): with DistSpecManager.no_grad():
for n, p in model.named_parameters(): for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n): if 'ln' not in n and ('weight' in n or 'bias' in n):
p.set_spec(spec) p.set_tensor_spec(spec)
@parameterize('use_chunk', [False, True]) @parameterize('use_chunk', [False, True])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment