OpenDAS / ColossalAI · Commit 4b9bba81 (unverified)

[ColoTensor] rename APIs and add output_replicate to ComputeSpec (#1168)

Authored by Jiarui Fang on Jun 24, 2022; committed via GitHub on Jun 24, 2022.
Parent: f4ef2243
Showing 3 changed files with 6 additions and 6 deletions.
tests/test_tensor/test_op.py          +3 -3
tests/test_tensor/test_tensor.py      +1 -1
tests/test_tensor/test_zero_optim.py  +2 -2
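The rename applied across these tests is mechanical: the spec property on ColoTensor becomes tensor_spec, and the set_spec method becomes set_tensor_spec. Distilled from the hunks below:

    # Before this commit
    dist = tensor.spec.dist_spec
    tensor.set_spec(TensorSpec(dist_spec=distspec.replicate()))

    # After this commit
    dist = tensor.tensor_spec.dist_spec
    tensor.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))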
tests/test_tensor/test_op.py

@@ -36,10 +36,10 @@ def test_layernorm():
 def check_spec_eq(tensor, other):
     assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
-    for k in dir(tensor.spec.dist_spec):
+    for k in dir(tensor.tensor_spec.dist_spec):
         if not k.startswith('__'):
-            assert hasattr(other.spec.dist_spec, k)
-            assert getattr(tensor.spec.dist_spec, k) == getattr(other.spec.dist_spec, k)
+            assert hasattr(other.tensor_spec.dist_spec, k)
+            assert getattr(tensor.tensor_spec.dist_spec, k) == getattr(other.tensor_spec.dist_spec, k)


 def check_element_wise_ops():
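check_spec_eq compares two dist_specs by reflecting over their public attributes rather than relying on an __eq__ implementation. The idiom is generic; a minimal self-contained sketch in plain Python, where the P class is a hypothetical stand-in used only for illustration:

    def attrs_equal(a, b):
        # Mirror check_spec_eq: walk every non-dunder attribute of `a`
        # and require `b` to expose an equal value under the same name.
        for k in dir(a):
            if k.startswith('__'):
                continue
            if not hasattr(b, k) or getattr(a, k) != getattr(b, k):
                return False
        return True

    class P:
        # Hypothetical stand-in for a dist_spec-like value object.
        def __init__(self, dims, num_partitions):
            self.dims = dims
            self.num_partitions = num_partitions

    assert attrs_equal(P([0], [2]), P([0], [2]))
    assert not attrs_equal(P([0], [2]), P([1], [2]))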
tests/test_tensor/test_tensor.py

@@ -66,7 +66,7 @@ def _run_tensor_shard_init(world_size):
     shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
     tensor_spec = TensorSpec(shard_spec)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
-    t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
+    t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
     assert t.shape == torch.Size((4 * world_size, 5))
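In context, the renamed setter is the redistribution entry point: the test builds a row-sharded tensor, then replicates it and checks that the full shape is restored. A sketch of that flow under the new name. The import paths are assumptions about this era of the codebase (the hunk shows only the call sites), and an initialized ColossalAI distributed context with world_size ranks in the DATA group is assumed:

    import torch
    # Assumed import locations; the diff shows only the call sites.
    from colossalai.tensor import ColoTensor, TensorSpec, distspec
    from colossalai.core import global_context as gpc
    from colossalai.context import ParallelMode

    def shard_then_replicate(t_ref: torch.Tensor, world_size: int) -> ColoTensor:
        # Shard dim 0 into world_size partitions across the data-parallel group.
        shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
                                    dims=[0],
                                    num_partitions=[world_size])
        t = ColoTensor.from_torch_tensor(t_ref.clone(), TensorSpec(shard_spec))
        # Renamed API: set_tensor_spec (was set_spec) redistributes the payload;
        # afterwards t.shape reports the full, replicated shape.
        t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
        return t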
tests/test_tensor/test_zero_optim.py

@@ -51,7 +51,7 @@ def init_1d_row_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)


 def init_1d_col_spec(model):

@@ -61,7 +61,7 @@ def init_1d_col_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)


 @parameterize('use_chunk', [False, True])
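The parameter-side call sites follow the same mechanical rename: inside DistSpecManager.no_grad(), each selected parameter receives its distribution spec through set_tensor_spec. A condensed sketch of the two initializers; the explicit spec argument and the import location are assumptions (in the test itself, spec comes from the enclosing scope):

    # Assumed import location for DistSpecManager.
    from colossalai.tensor import DistSpecManager

    def init_1d_row_spec(model, spec):
        # Row-parallel 1D: shard weights, but leave LayerNorm parameters alone.
        with DistSpecManager.no_grad():
            for n, p in model.named_parameters():
                if 'weight' in n and 'ln' not in n:
                    p.set_tensor_spec(spec)  # renamed from p.set_spec(spec)

    def init_1d_col_spec(model, spec):
        # Column-parallel 1D: shard weights and biases, still skipping LayerNorm.
        with DistSpecManager.no_grad():
            for n, p in model.named_parameters():
                if 'ln' not in n and ('weight' in n or 'bias' in n):
                    p.set_tensor_spec(spec)  # renamed from p.set_spec(spec)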