OpenDAS / ColossalAI · Commits · a98319f0
Unverified commit a98319f0, authored Jul 07, 2022 by Jiarui Fang, committed by GitHub on Jul 07, 2022
[tensor] torch function return colotensor (#1229)
parent 55811708
Showing 8 changed files with 42 additions and 21 deletions (+42 -21):
colossalai/nn/_ops/element_wise.py    +2  -4
colossalai/nn/_ops/linear.py          +3  -2
colossalai/nn/_ops/loss.py            +1  -1
colossalai/tensor/colo_tensor.py      +27 -9
colossalai/tensor/process_group.py    +4  -4
tests/test_tensor/test_model.py       +1  -0
tests/test_tensor/test_op.py          +2  -1
tests/test_tensor/test_tensor.py      +2  -0
colossalai/nn/_ops/element_wise.py

@@ -17,14 +17,12 @@ def register_elementwise_op(op):
         """
         output = op(input_tensor, *args, **kwargs)
         if isinstance(input_tensor, ColoTensor):
             if not isinstance(output, torch.Tensor):
                 raise NotImplementedError
             return ColoTensor.from_torch_tensor(output,
-                                                spec=ColoTensorSpec(input_tensor.process_group,
-                                                                    dist_attr=input_tensor.dist_spec,
-                                                                    compute_attr=input_tensor.compute_spec))
+                                                spec=ColoTensorSpec(input_tensor.get_process_group(),
+                                                                    dist_attr=input_tensor.dist_spec))

     # Tensor op
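The pattern this hunk adjusts — run the plain torch op, then re-attach the input's distribution metadata to the result — can be shown in isolation. A minimal, self-contained sketch; MetaTensor and its spec attribute are illustrative stand-ins for ColoTensor and ColoTensorSpec, not the ColossalAI API:

import torch

class MetaTensor(torch.Tensor):
    # carries one extra attribute, the way ColoTensor carries its tensor spec

    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor, spec=None) -> "MetaTensor":
        t = tensor.as_subclass(MetaTensor)
        t.spec = spec
        return t

def elementwise_op(op, x, *args, **kwargs):
    output = op(x, *args, **kwargs)    # plain torch computation
    # re-wrap so the sharding/replication spec survives the op
    return MetaTensor.from_torch_tensor(output, spec=getattr(x, 'spec', None))

x = MetaTensor.from_torch_tensor(torch.rand(2, 2), spec='replicated')
y = elementwise_op(torch.abs, x)
assert isinstance(y, MetaTensor) and y.spec == 'replicated'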
colossalai/nn/_ops/linear.py

@@ -22,7 +22,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
         assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
         output = output + bias

-    pg = input_tensor.get_process_group()
+    pg = weight.get_process_group()
     output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, distspec.replicate()))
     return output

@@ -61,6 +61,7 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     """
     assert isinstance(weight, ColoTensor)
     pg = weight.get_process_group()
+    assert pg
     input_tensor = convert_to_colo_tensor(input_tensor, pg)
     bias = convert_to_colo_tensor(bias, pg)
     # input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

@@ -70,7 +71,7 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     if not weight.has_compute_spec():    # No Model Parallel Applied
         assert weight.is_replicate(), 'Invalid weight spec for native Linear op'
         assert bias is None or bias.is_replicate(), 'Invalid bias spec for native Linear op'
-        ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
+        ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias), spec=ColoTensorSpec(pg))
     elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         if weight.is_shard_1dcol() and (bias is None or bias.is_replicate()):
             mode = 'row'
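The first hunk takes the process group from the weight rather than the input: the input can arrive as a plain torch.Tensor and only becomes a ColoTensor afterwards via convert_to_colo_tensor, so the weight is the operand guaranteed to carry one. As for why the 1Drow result carries distspec.replicate(): each rank computes a partial F.linear over its shard of the inner dimension, and the partials sum (all-reduce in the real op) to the full, replicated result. A quick check of that identity with plain torch and two fake ranks:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(4, 6)    # (batch, in_features)
w = torch.rand(5, 6)    # (out_features, in_features)

full = F.linear(x, w)   # x @ w.T
# shard the in_features dimension across two "ranks" and sum the partials,
# as the all-reduce in the real 1Drow op would
partial = sum(F.linear(x[:, s], w[:, s]) for s in (slice(0, 3), slice(3, 6)))
assert torch.allclose(full, partial, atol=1e-5)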
colossalai/nn/_ops/loss.py

@@ -35,7 +35,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
     elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
         if input_tensor.is_shard_1dcol():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
-            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
+            return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg)).to_replicate()
         else:
             raise NotImplementedError
     else:
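One plausible reading of this one-liner: VocabParallelCrossEntropyLoss1D consumes vocab-sharded logits, and the loss it returns should look the same on every rank, so the wrapped result is explicitly converted back to a replicated distribution before leaving the op. The effect, simulated with plain tensors for two fake ranks (fake_all_reduce is a stand-in for the real collective, not a ColossalAI function):

import torch

def fake_all_reduce(partials):
    # stand-in for dist.all_reduce(SUM): every rank ends up with the total
    total = sum(partials)
    return [total.clone() for _ in partials]

rank_local = [torch.tensor(0.7), torch.tensor(1.1)]   # per-rank partial losses
replicated = fake_all_reduce(rank_local)
assert all(torch.equal(r, replicated[0]) for r in replicated)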
colossalai/tensor/colo_tensor.py

@@ -11,12 +11,30 @@ from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
 from typing import Optional


-def _check_output(output):
-    if not isinstance(output, torch.Tensor):
-        raise RuntimeError
+def _convert_output(output, pg: ProcessGroup):
+    if type(output) == torch.Tensor:
+        return ColoTensor.from_torch_tensor(output, ColoTensorSpec(pg))
     elif isinstance(output, (list, tuple)):
-        output = type(output)(_check_output(o) for o in output)
-        return output
+        return type(output)(_convert_output(o, pg) for o in output)
     else:
         return output


+def _scan_for_pg_from_args(args, kwargs) -> ProcessGroup:
+    for elem in args:
+        if isinstance(elem, ColoTensor):
+            pg = elem.get_process_group()
+            return pg
+        elif isinstance(elem, (list, tuple)):
+            pg = _scan_for_pg_from_args(elem, {})
+            if pg is not None:
+                return pg
+        print(type(elem), elem, isinstance(elem, (list, tuple)))
+    for k, v in kwargs:
+        if isinstance(v, ColoTensor):
+            pg = v.get_process_group()
+            return pg
+    return None
+
+
 class ColoTensor(torch.Tensor):

@@ -108,6 +126,7 @@ class ColoTensor(torch.Tensor):
             dist_spec (_DistSpec): target dist spec.
         """
         assert isinstance(dist_spec, _DistSpec)
+        assert self.process_group
         self._convert_to_dist_spec(dist_spec)

     def set_tensor_spec(self, dist_spec, compute_spec):

@@ -136,12 +155,11 @@ class ColoTensor(torch.Tensor):
             if func in get_default_nowrap_functions():
                 return ret
             else:
-                # TODO(jiaruifang) its parallel Op's duty to convert output activations
-                return ret
-                # return _check_output(ret)
+                pg = _scan_for_pg_from_args(args, kwargs)
+                return _convert_output(ret, pg)

     def __repr__(self):
-        return f'ColoTensor: {super().__repr__()}'
+        return f'ColoTensor: {super().__repr__()}\ndist spec: {self.dist_spec}\nprocess group: {self.process_group}'

     def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
         """_convert_to_dist_spec
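This file is the change the commit title refers to: instead of returning whatever the underlying torch function produced, __torch_function__ now scans the arguments for a process group and re-wraps plain-tensor results, including results nested inside lists and tuples, into ColoTensor. As committed, the helper still carries a debug print(...) and iterates for k, v in kwargs:, which walks dict keys rather than kwargs.items(). A self-contained sketch of the same mechanism with a well-behaved helper; WrapTensor is an illustrative stand-in, not the ColossalAI class:

import torch

def _convert_output(output):
    if type(output) is torch.Tensor:
        return output.as_subclass(WrapTensor)
    elif isinstance(output, (list, tuple)):
        return type(output)(_convert_output(o) for o in output)
    return output

class WrapTensor(torch.Tensor):

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # run the op with torch-function dispatch disabled, so results come
        # back as plain torch.Tensor objects ...
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
        if func in torch.overrides.get_default_nowrap_functions():
            return ret
        # ... then re-wrap them, which is what _convert_output does above
        return _convert_output(ret)

t = torch.rand(2, 2).as_subclass(WrapTensor)
assert isinstance(t + t, WrapTensor)                              # plain result wrapped
assert all(isinstance(c, WrapTensor) for c in (t + t).chunk(2))   # nested results too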
colossalai/tensor/process_group.py

@@ -19,6 +19,10 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
         pg_key = (backend, rank_tuple)

         if pg_key not in self.dict:
+            self.logger = get_dist_logger('ProcessGroup')
+            self.logger.info(f'NCCL initialize TP group on {rank_list}', ranks=[0])
             self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
         return self.dict[pg_key]

@@ -92,10 +96,6 @@ class ProcessGroup:
         self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
         self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')

-        self.logger = get_dist_logger('ProcessGroup')
-        self.logger.info(
-            f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
-
         self._has_cpu_groups = False
         self._cpu_dp_process_group = None
         self._cpu_tp_process_group = None
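The net effect of these two hunks: the "NCCL initialize TP group" message now comes from PyTorchProcessGroupDict.get, so it fires once per newly created (backend, ranks) group on rank 0, rather than on every ProcessGroup construction. A minimal sketch of that create-once-and-cache singleton pattern, with stand-ins for the logger and torch.distributed.new_group:

class ProcessGroupDict:
    _instance = None

    def __new__(cls):
        # singleton: every construction returns the same cache
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.cache = {}
        return cls._instance

    def get(self, rank_list, backend='nccl'):
        pg_key = (backend, tuple(rank_list))
        if pg_key not in self.cache:
            print(f'{backend}: initialize group on {rank_list}')    # fires once
            self.cache[pg_key] = object()    # stand-in for torch.distributed.new_group
        return self.cache[pg_key]

a = ProcessGroupDict().get([0, 1])
b = ProcessGroupDict().get([0, 1])    # cache hit: no second log line
assert a is b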
tests/test_tensor/test_model.py

@@ -113,6 +113,7 @@ def run_1d_hybrid_tp(model_name):
         torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
+        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

         # Bcast rank0 data to all processes
         if criterion:
             output = model(data)
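The added line closes a gap in the test: under tensor parallelism every rank in the TP group must step on identical batches, and previously only data was broadcast from rank 0, leaving label rank-local. The broadcast semantics, simulated without an initialized process group (fake_broadcast is a stand-in for torch.distributed.broadcast):

import torch

def fake_broadcast(per_rank, src=0):
    # every "rank" adopts the src rank's tensor
    return [per_rank[src].clone() for _ in per_rank]

labels = [torch.tensor([1, 2]), torch.tensor([9, 9])]   # ranks disagree
labels = fake_broadcast(labels, src=0)
assert all(torch.equal(l, labels[0]) for l in labels)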
tests/test_tensor/test_op.py

@@ -39,7 +39,7 @@ def check_spec_eq(tensor, other):
     assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
     for k in dir(tensor.dist_spec):
         if not k.startswith('__'):
-            assert hasattr(other.dist_spec, k)
+            assert hasattr(other.dist_spec, k), f"{k}"
             assert getattr(tensor.dist_spec, k) == getattr(other.dist_spec, k)

@@ -48,6 +48,7 @@ def check_element_wise_ops():
     pg = ProcessGroup(tp_degree=world_size)
     t = torch.rand(2, 2)
     x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))
+    check_spec_eq(x, x.cuda())
     assert torch.equal(x.cuda(), t.cuda())
     check_spec_eq(x, torch.abs(x))
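check_spec_eq compares two dist specs by reflecting over their public attributes; the new , f"{k}" suffix just names the attribute that was missing when the assertion fires. The comparison idiom in isolation, with FakeDistSpec as a plain stand-in for a ColossalAI dist spec:

class FakeDistSpec:
    # stand-in: a couple of public attributes, nothing more
    def __init__(self, dims, num_partitions):
        self.dims = dims
        self.num_partitions = num_partitions

a, b = FakeDistSpec([0], [2]), FakeDistSpec([0], [2])
for k in dir(a):
    if not k.startswith('__'):
        assert hasattr(b, k), f"{k}"
        assert getattr(a, k) == getattr(b, k)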
tests/test_tensor/test_tensor.py

@@ -49,6 +49,8 @@ def _run_operand():
     t_ref_res = t_ref + t_ref
     t_res = t + t
+
+    assert isinstance(t_res, ColoTensor)
     assert torch.allclose(t_ref_res, t_res)
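The new assertion pins down exactly what the commit title promises: an arithmetic op on a ColoTensor now yields a ColoTensor, while torch.allclose still compares it against a plain-tensor reference by value. The same shape of test against a generic tensor subclass (WrapTensor is illustrative; for t + t, default subclass dispatch already preserves the type):

import torch

class WrapTensor(torch.Tensor):
    pass

t_ref = torch.ones(2, 3)
t = t_ref.clone().as_subclass(WrapTensor)

t_ref_res = t_ref + t_ref
t_res = t + t

assert isinstance(t_res, WrapTensor)      # the subclass survives the op
assert torch.allclose(t_ref_res, t_res)   # value comparison ignores the class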