OpenDAS / ColossalAI

Commit aa7bef73 (unverified)
[Tensor] distributed view supports inter-process hybrid parallel (#1169)

Authored Jun 27, 2022 by Jiarui Fang; committed via GitHub on Jun 27, 2022.
Parent: 9e1daa63
Showing 13 changed files with 101 additions and 19 deletions (+101, -19).
colossalai/core.py                     +2   -0
colossalai/nn/_ops/addmm.py            +3   -3
colossalai/nn/_ops/embedding.py        +2   -2
colossalai/nn/_ops/embedding_bag.py    +1   -1
colossalai/nn/_ops/linear.py           +4   -4
colossalai/nn/_ops/loss.py             +1   -1
colossalai/tensor/chunk.py             +1   -1
colossalai/tensor/colo_parameter.py    +10  -0
colossalai/tensor/colo_tensor.py       +56  -1
colossalai/tensor/dist_spec_mgr.py     +1   -0
colossalai/tensor/tensor_spec.py       +1   -1
tests/test_tensor/test_gpt.py          +1   -1
tests/test_tensor/test_tensor.py       +18  -4
colossalai/core.py

@@ -2,3 +2,5 @@
 # -*- encoding: utf-8 -*-
 
 from colossalai.context.parallel_context import global_context
+
+__all__ = ['global_context']
\ No newline at end of file
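With `global_context` now exported explicitly, the usual access pattern seen elsewhere in this diff (`gpc.get_local_rank(...)`, `gpc.get_group(...)`) looks like the sketch below; the launch call is an assumption for illustration and is not part of this commit:

    import colossalai
    from colossalai.core import global_context as gpc
    from colossalai.context import ParallelMode

    colossalai.launch_from_torch(config={})                  # assumes a torchrun-style launch
    data_group = gpc.get_group(ParallelMode.DATA)            # as used in tests/test_tensor
    rank_1d = gpc.get_local_rank(ParallelMode.PARALLEL_1D)   # as used in embedding.py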
colossalai/nn/_ops/addmm.py

@@ -68,11 +68,11 @@ def colo_addmm(input_tensor: GeneralTensor,
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not mat2.has_compute_spec():    # No Model Parallel Applied
-        assert mat2.tensor_spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
-        assert input_tensor.tensor_spec.is_gathered(), 'Invalid input spec for native addmm op'
+        assert mat2.tensor_spec.is_replicate(), 'Invalid mat2 spec for native addmm op'
+        assert input_tensor.tensor_spec.is_replicate(), 'Invalid input spec for native addmm op'
         ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
     elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_gathered():
+        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_replicate():
             mode = 'row'
         elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col()
                                                or input_tensor.tensor_spec.is_1D_row()):
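The only functional change in this op is the rename of the placement predicate from `is_gathered()` to `is_replicate()` (defined in colossalai/tensor/tensor_spec.py further down); the dispatch itself is unchanged. A rough usage sketch, not part of this commit, assuming the default spec of `from_torch_tensor` is replicate and that `torch.addmm` on ColoTensors is routed to `colo_addmm` through the tensor's `__torch_function__` handler:

    import torch
    from colossalai.tensor import ColoTensor

    # With replicate placement on every operand, colo_addmm's first branch runs and
    # wraps the native torch.addmm result back into a ColoTensor.
    inp = ColoTensor.from_torch_tensor(torch.randn(4, 6))
    mat1 = ColoTensor.from_torch_tensor(torch.randn(4, 5))
    mat2 = ColoTensor.from_torch_tensor(torch.randn(5, 6))
    out = torch.addmm(inp, mat1, mat2)
    assert out.size() == torch.Size([4, 6])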
colossalai/nn/_ops/embedding.py

@@ -51,7 +51,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
 
     tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    num_embeddings_per_partition = weight.size(0)
+    num_embeddings_per_partition = weight.size_base(0)
     vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
     vocab_end_index = vocab_start_index + num_embeddings_per_partition

@@ -115,7 +115,7 @@ def colo_embedding(input_tensor: GeneralTensor,
     # Handle differen parallel actions.
     if not weight.has_compute_spec():    # No Model Parallel Applied
-        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
+        assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding(input_tensor,
                         weight,
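The swap from `weight.size(0)` to `weight.size_base(0)` matters because, after this commit, `ColoTensor.size()` reports the global (logical) shape while the row-parallel embedding needs the number of rows actually held by the current rank. A self-contained sketch of the per-rank arithmetic, assuming a vocabulary of 1000 split across two tensor-parallel ranks:

    # size(0) would report the global vocabulary; size_base(0) reports the local rows.
    global_vocab, world_size = 1000, 2
    num_embeddings_per_partition = global_vocab // world_size   # what size_base(0) returns: 500
    for tensor_parallel_rank in range(world_size):
        vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
        vocab_end_index = vocab_start_index + num_embeddings_per_partition
        print(tensor_parallel_rank, vocab_start_index, vocab_end_index)   # (0, 0, 500), (1, 500, 1000)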
colossalai/nn/_ops/embedding_bag.py

@@ -90,7 +90,7 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
     # Handle differen parallel actions.
     if not weight.has_compute_spec():    # No Model Parallel Applied
-        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
+        assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding_bag(input_tensor,
                             weight,
colossalai/nn/_ops/linear.py

@@ -67,17 +67,17 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_compute_spec():    # No Model Parallel Applied
-        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native Linear op'
-        assert bias is None or bias.tensor_spec.is_gathered(), 'Invalid bias spec for native Linear op'
+        assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native Linear op'
+        assert bias is None or bias.tensor_spec.is_replicate(), 'Invalid bias spec for native Linear op'
         ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_gathered()):
+        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_replicate()):
             mode = 'row'
         elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row()
                                                  or bias.tensor_spec.is_1D_col()):
             mode = 'col'
         else:
-            raise NotImplementedError
+            raise RuntimeError(f"the weight or bias tensor spec is not valid, weight {weight.tensor_spec}, bias {bias}")
         ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias)
     else:
         raise NotImplementedError
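Besides the predicate rename, the fallback now raises a `RuntimeError` that reports the offending specs instead of a bare `NotImplementedError`. An illustrative summary of the mode selection, not library code, with string layouts standing in for the `tensor_spec` checks:

    from typing import Optional

    def pick_linear_1d_mode(weight_layout: str, bias_layout: Optional[str]) -> str:
        # Mirrors the branch structure of colo_linear_imp after this commit.
        if weight_layout == '1D_col' and bias_layout in (None, 'replicate'):
            return 'row'
        if weight_layout == '1D_row' and bias_layout in (None, '1D_row', '1D_col'):
            return 'col'
        # mirrors the new, more informative error raised by this commit
        raise RuntimeError(f"invalid spec: weight {weight_layout}, bias {bias_layout}")

    assert pick_linear_1d_mode('1D_col', 'replicate') == 'row'
    assert pick_linear_1d_mode('1D_row', None) == 'col'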
colossalai/nn/_ops/loss.py

@@ -18,7 +18,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                        label_smoothing: float = 0.0):
     input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))
 
-    if input_tensor.tensor_spec.is_gathered():    # Input is gathered
+    if input_tensor.tensor_spec.is_replicate():    # Input is gathered
         output = F.cross_entropy(input_tensor,
                                  target,
                                  weight=weight,
colossalai/tensor/chunk.py

@@ -114,7 +114,7 @@ class Chunk:
         # if the process owns the rank, then copy the tensor to its chunk buffer
         # otherwise set its storage size to 0 to reduce memory consumption
         if self.is_src_rank:
-            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
+            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
             tensor_state = TensorState.HOLD
             tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
         else:
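A small note on the `tensor.view(-1)` to `tensor.flatten()` swap: for a contiguous tensor the two are equivalent, and `flatten()` additionally stays clear of the `view()` override introduced by this commit and works even when the tensor is not contiguous. A quick check in plain PyTorch:

    import torch

    t = torch.arange(6).reshape(2, 3)
    assert torch.equal(t.view(-1), t.flatten())

    nc = t.t()                     # non-contiguous transpose
    # nc.view(-1) would raise a RuntimeError here; flatten() copies instead
    assert torch.equal(nc.flatten(), torch.tensor([0, 3, 1, 4, 2, 5]))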
colossalai/tensor/colo_parameter.py

@@ -101,3 +101,13 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         # TODO(jzy) we don't support object reflection now.
         # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
         raise NotImplementedError
+
+    #### the ColoParameter should use the torch.Tensor's builtin methodes ###
+
+    def view(self, *args) -> 'ColoTensor':
+        return super().view_base(*args)
+
+    def size(self, *args, **kwargs) -> torch.Size:
+        # import inspect
+        # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
+        return super().size_base(*args, **kwargs)
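These overrides route `ColoParameter.view()` and `size()` back to the plain `torch.Tensor` behaviour through the `*_base` helpers added in colo_tensor.py below, so parameter shapes stay local and skip the distributed-view logic. A rough sketch of the expected behaviour; the constructor call is an assumption for illustration, not taken from this commit:

    import torch
    from colossalai.tensor import ColoParameter

    p = ColoParameter(torch.randn(8, 4))             # assumed default (replicate) spec
    assert p.size() == torch.Size([8, 4])            # size -> size_base: plain local shape
    assert p.view(32).shape == torch.Size([32])      # view -> view_base: plain torch view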
colossalai/tensor/colo_tensor.py

@@ -8,6 +8,7 @@ from colossalai.tensor import TensorSpec
 from colossalai.tensor import distspec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec
+from typing import Optional
 
 
 def _convert_output(output):

@@ -60,6 +61,12 @@ class ColoTensor(torch.Tensor):
     def tensor_spec(self) -> TensorSpec:
         return self._tensor_spec
 
+    @tensor_spec.setter
+    def tensor_spec(self, tenseor_spec: TensorSpec):
+        spec = copy(spec)
+        self._convert_to_dist_spec(spec.dist_spec)
+        self._tensor_spec = spec
+
     def set_tensor_spec(self, spec: TensorSpec) -> None:
         spec = copy(spec)
         self._convert_to_dist_spec(spec.dist_spec)

@@ -136,4 +143,52 @@ class ColoTensor(torch.Tensor):
         data = self.data.clone()
         tensor = ColoTensor(data, spec=copy(self.tensor_spec))
         memo[id(self)] = tensor
-        return tensor
\ No newline at end of file
+        return tensor
+
+    ##### override builtin functions which must use tensor in replicate placement ####
+    def view_base(self, *args) -> 'ColoTensor':
+        return super().view(*args)
+
+    def size_base(self, *args, **kwargs) -> torch.Size:
+        return super().size(*args, **kwargs)
+
+    def view(self, *args) -> 'ColoTensor':
+        """override the torch buildin view()
+        the args passed in must be in a replicate placement.
+        Returns:
+            ColoTensor: a tensor after viewed.
+        """
+        if self.tensor_spec.is_replicate():
+            return super().view(*args)
+        # TODO(jiaruifang) check why this not work
+        # self.data = self.to_replicate()
+        self.data = DistSpecManager.handle_trans_spec(self.data, self.tensor_spec.dist_spec, distspec.replicate())
+        self._tensor_spec.dist_spec = distspec.replicate()
+        return super().view(*args)
+
+    def size(self, args: Optional[int] = None):
+        """override the torch buildin size()
+        the shape passed in must be in a replicate placement.
+        Returns:
+            ColoTensor: a tensor after viewed.
+        """
+        if self.tensor_spec.is_replicate():
+            if args is not None:
+                return super().size(args)
+            else:
+                return super().size()
+
+        spec = self.tensor_spec.dist_spec
+        dims = spec.dims
+        num_partitions = spec.num_partitions
+        # import inspect
+        # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
+        size_list = list(super().size())
+        for dim, num_partition in zip(dims, num_partitions):
+            size_list[dim] *= num_partition
+
+        if args is not None:
+            return size_list[args]
+        else:
+            return torch.Size(size_list)
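This is the heart of the commit: `size()` now rebuilds the global shape by multiplying every dim listed in `dist_spec.dims` by its partition count, and `view()` first converts a non-replicate tensor back to replicate placement before delegating to the torch built-in. (One caveat visible above: the new `tensor_spec` setter names its argument `tenseor_spec` but copies `spec`, so `set_tensor_spec()` appears to remain the working entry point.) The shape arithmetic in isolation, runnable without the ColossalAI runtime:

    import torch

    def global_size(local_shape, dims, num_partitions):
        # Mirror of the loop in the new ColoTensor.size(): scale each sharded dim
        # by the number of partitions along it.
        size_list = list(local_shape)
        for dim, num_partition in zip(dims, num_partitions):
            size_list[dim] *= num_partition
        return torch.Size(size_list)

    # A (4, 5) local shard, sharded along dim 0 across 2 processes, reports (8, 5).
    assert global_size(torch.Size([4, 5]), dims=[0], num_partitions=[2]) == torch.Size([8, 5])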
colossalai/tensor/dist_spec_mgr.py

@@ -68,6 +68,7 @@ class DistSpecManager:
         num_parts = prod(dist_spec.num_partitions)
         for i, dim in enumerate(dist_spec.dims):
             num_parts //= dist_spec.num_partitions[i]
             chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i])
             chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
             idx %= num_parts
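For reference, the narrowing arithmetic shown in this hunk selects the shard owned by linear index `idx`, peeling off one sharded dim at a time. A standalone sketch with plain PyTorch; the helper is illustrative, not library code:

    import torch

    def local_shard(tensor, dims, num_partitions, idx):
        # Same index arithmetic as the DistSpecManager loop above.
        num_parts = 1
        for n in num_partitions:
            num_parts *= n
        chunk = tensor
        for i, dim in enumerate(dims):
            num_parts //= num_partitions[i]
            chunk_size = tensor.size(dim) // num_partitions[i]
            chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size)
            idx %= num_parts
        return chunk

    t = torch.arange(24).reshape(8, 3)
    # rank 1 of 2 along dim 0 keeps rows 4..7
    assert torch.equal(local_shard(t, dims=[0], num_partitions=[2], idx=1), t[4:])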
colossalai/tensor/tensor_spec.py

@@ -26,7 +26,7 @@ class TensorSpec(object):
     def get_placement(self):
         return self.dist_spec.placement
 
-    def is_gathered(self):
+    def is_replicate(self):
         return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
             or (len(self.dist_spec.num_partitions) == 1
                 and self.dist_spec.num_partitions[0] == 1) \
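The rename makes the intent explicit: a spec counts as replicated either by placement or because its only partition has size 1. A standalone sketch of the visible part of the predicate (the real method continues past the lines shown above); the stand-in classes are illustrative only:

    from dataclasses import dataclass
    from enum import Enum

    class Placement(Enum):               # stand-in for DistPlacementPattern
        REPLICATE = 'r'
        SHARD = 's'

    @dataclass
    class DistSpecSketch:                # stand-in for the real dist spec
        placement: Placement
        num_partitions: tuple

    def is_replicate(dist_spec):
        return dist_spec.placement == Placement.REPLICATE \
            or (len(dist_spec.num_partitions) == 1 and dist_spec.num_partitions[0] == 1)

    assert is_replicate(DistSpecSketch(Placement.REPLICATE, (1,)))
    assert is_replicate(DistSpecSketch(Placement.SHARD, (1,)))       # degenerate shard
    assert not is_replicate(DistSpecSketch(Placement.SHARD, (2,)))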
tests/test_tensor/test_gpt.py

@@ -101,4 +101,4 @@ def test_gpt(world_size, use_ddp):
 if __name__ == '__main__':
-    test_gpt(4, False)
+    test_gpt(4, True)
tests/test_tensor/test_tensor.py

@@ -60,6 +60,19 @@ def test_operand():
 #### Test Distributed init a Colotensor
 
+def _run_view(world_size):
+    t_ref = torch.randn(4, 5)
+    t = ColoTensor.from_torch_tensor(
+        t_ref, TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[2])))
+    assert t.size()[0] == 4 * world_size
+    assert t.size(1) == 5
+    assert t.size() == torch.Size([4 * world_size, 5])
+
+    t = t.view(4 * 5 * world_size)
+    assert t.shape == torch.Size([4 * 5 * world_size])
+
+
 def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
     print(gpc.get_group(ParallelMode.DATA).size())

@@ -77,20 +90,21 @@ def _run_tensor_replicated_init(world_size):
     assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"
 
 
-def run_tensor_init(rank, world_size, port):
+def run_dist_tests(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     _run_tensor_shard_init(world_size)
     _run_tensor_replicated_init(world_size)
+    _run_view(world_size)
 
 
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
-def _test_dist_init(world_size):
-    run_func = partial(run_tensor_init, world_size=world_size, port=free_port())
+def _test_dist_cases(world_size):
+    run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
 if __name__ == '__main__':
     # _test_dist_init(4)
-    test_new()
+    _test_dist_cases(2)
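The new `_run_view` case checks exactly the behaviour added in colo_tensor.py: each rank holds a (4, 5) shard, yet `size()` reports the global (4 * world_size, 5) shape, and `view()` works across the sharded placement. To run the renamed distributed entry point outside pytest, something along these lines should work (assuming at least 2 GPUs, a working NCCL setup, and that the test module is importable from the repository root):

    from functools import partial

    import torch.multiprocessing as mp
    from colossalai.utils import free_port

    from tests.test_tensor.test_tensor import run_dist_tests   # the renamed entry point

    if __name__ == '__main__':
        world_size = 2
        run_func = partial(run_dist_tests, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)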