OpenDAS / ColossalAI / Commits / 4a76084d

Unverified commit 4a76084d, authored Jul 08, 2022 by Jiarui Fang, committed by GitHub on Jul 08, 2022.

[tensor] add zero_like colo op, important for Optimizer (#1236)
parent 3b500984

Showing 4 changed files with 16 additions and 6 deletions (+16 -6):
colossalai/nn/_ops/element_wise.py    +1 -0
colossalai/tensor/colo_parameter.py   +1 -1
colossalai/tensor/colo_tensor.py      +3 -0
tests/test_tensor/test_tensor.py      +11 -5
colossalai/nn/_ops/element_wise.py

@@ -195,6 +195,7 @@ register_elementwise_op(torch.tan)
 register_elementwise_op(torch.tanh)
 register_elementwise_op(torch.atanh)
 register_elementwise_op(torch.arctanh)
+register_elementwise_op(torch.zeros_like)

 # nn.functional OP
 register_elementwise_op(F.threshold)
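For context, here is a minimal sketch of the pattern register_elementwise_op follows, to make clear why adding torch.zeros_like to the registration list makes it ColoTensor-aware. This is a hypothetical reconstruction, not the ColossalAI source (the real helper is defined earlier in element_wise.py and is not part of this diff); the import path, the registry dict, and the attribute names (dist_spec, get_process_group) are assumptions.

import torch
from colossalai.tensor import ColoTensor, ColoTensorSpec  # assumed import path

_COLO_OPS = {}   # hypothetical registry consulted by ColoTensor dispatch

def register_elementwise_op(op):
    # Sketch: run the plain torch op on the underlying data, then rewrap
    # the result as a ColoTensor carrying the input's process group and
    # distribution spec, so the output keeps the input's sharding.
    def elementwise_fn(input_tensor, *args, **kwargs):
        output = op(input_tensor.data, *args, **kwargs)
        spec = ColoTensorSpec(input_tensor.get_process_group(),  # assumed API
                              input_tensor.dist_spec)            # assumed API
        return ColoTensor.from_torch_tensor(output, spec)
    _COLO_OPS[op] = elementwise_fn

Once torch.zeros_like is registered this way, calling it on a ColoTensor presumably dispatches through ColoTensor.__torch_function__ to the wrapped op instead of falling back to a plain torch.Tensor result.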
colossalai/tensor/colo_parameter.py

@@ -55,7 +55,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         return tensor

     def __repr__(self):
-        return f'ColoParameter: {torch.Tensor.__repr__(self)}'
+        return f'ColoParameter: {ColoTensor.__repr__(self)}'

     @classmethod
     def __torch_function__(cls, func, types, args=..., kwargs=None):
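A note on the one-line fix above: ColoParameter's repr now delegates to ColoTensor.__repr__ rather than torch.Tensor.__repr__, so a printed parameter presumably shows whatever distribution information the colo tensor repr carries instead of only the raw data.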
colossalai/tensor/colo_tensor.py

@@ -271,3 +271,6 @@ class ColoTensor(torch.Tensor):
     def is_shard_1drow(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
+
+    def is_sharded(self):
+        return self.dist_spec.placement == DistPlacementPattern.SHARD
\ No newline at end of file
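The commit title calls zeros_like "important for Optimizer"; a minimal sketch of why, using a hand-rolled momentum-state init that is illustrative rather than ColossalAI code:

import torch

def init_momentum_state(params):
    # With torch.zeros_like registered as a colo op, state buffers created
    # from ColoTensor parameters stay ColoTensors with the parameter's
    # sharding, instead of materializing as dense replicated torch.Tensors.
    state = {}
    for p in params:
        buf = torch.zeros_like(p)
        if hasattr(p, 'is_sharded') and p.is_sharded():
            assert buf.is_sharded()   # the predicate added in this commit
        state[p] = buf
    return state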
tests/test_tensor/test_tensor.py

@@ -42,7 +42,7 @@ def _run_wrapped_tensor_func():
     assert isinstance(t_split1, ColoTensor) and isinstance(t_split2, ColoTensor), f"{type(t_split1)} {type(t_split2)}"

-def _run_operand():
+def _run_operand(world_size):
     pg = ProcessGroup()
     t_ref = torch.randn(4, 5)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))

@@ -53,6 +53,13 @@ def _run_operand():
     assert isinstance(t_res, ColoTensor)
     assert torch.allclose(t_ref_res, t_res)

+    pg = ProcessGroup(tp_degree=world_size)
+    t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
+    t.set_dist_spec(distspec.shard([0], [world_size]))
+    t_new = torch.zeros_like(t)
+    assert isinstance(t_new, ColoTensor)
+    assert t_new.is_sharded()

 #### Test Distributed init a Colotensor

@@ -105,9 +112,8 @@ def run_dist_tests(rank, world_size, port):
     _run_view(world_size)
     _run_process_group(world_size)
     _run_tensor_indexing()
-    _run_operand()
-    # TODO not passed
-    _run_wrapped_tensor_func()
+    _run_operand(world_size)
+    # _run_wrapped_tensor_func()

 @pytest.mark.dist

@@ -119,4 +125,4 @@ def test_dist_cases(world_size):
 if __name__ == '__main__':
-    test_dist_cases(2)
+    test_dist_cases(1)