OpenDAS / ColossalAI

Commit 45d93843, authored Jun 16, 2023 by Frank Lee
[shardformer] removed inplace tensor sharding (#4018)
parent 3893fa1a
Showing 2 changed files with 40 additions and 4 deletions:

    colossalai/shardformer/layer/layers.py    +4   -0
    colossalai/tensor/d_tensor/api.py         +36  -4
colossalai/shardformer/layer/layers.py

@@ -329,7 +329,11 @@ class Linear1D_Row(ParallelModule):
             src_rank = 0
         else:
             src_rank = dist.distributed_c10d._get_global_rank(self.process_group, 0)
+
+        origin_device = self.bias.device
+        self.bias = self.bias.cuda()
         dist.broadcast(self.bias, src=src_rank, group=self.process_group)
+        self.bias = self.bias.to(origin_device)
 
     def forward(self, input_: Tensor) -> Tensor:
         # Set up backprop all-reduce.
colossalai/tensor/d_tensor/api.py
@@ -10,9 +10,21 @@ from .d_tensor import DTensor
 from .sharding_spec import ShardingSpec
 
-def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
+def shard_rowwise(tensor: torch.Tensor,
+                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
+                  inplace: bool = False) -> DTensor:
     """
-    Shard the first dim of the given tensor
+    Shard the first dim of the given tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be sharded.
+        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
+            If None, the tensor will be sharded with respect to the global process group.
+            Defaults to None.
+        inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
+
+    Returns:
+        DTensor: The sharded tensor.
     """
     # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
     if group_or_device_mesh is None:
@@ -24,12 +36,28 @@ def shard_rowwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
         assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
         device_mesh = group_or_device_mesh
     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
+
+    if not inplace:
+        tensor = tensor.detach().clone()
+
     return DTensor(tensor, device_mesh, sharding_spec)
 
 
-def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> DTensor:
+def shard_colwise(tensor: torch.Tensor,
+                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
+                  inplace: bool = False) -> DTensor:
     """
-    Shard the first dim of the given tensor
+    Shard the last dim of the given tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be sharded.
+        group_or_device_mesh (Union[ProcessGroup, DeviceMesh], optional): The group or device mesh to shard the tensor.
+            If None, the tensor will be sharded with respect to the global process group.
+            Defaults to None.
+        inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
+
+    Returns:
+        DTensor: The sharded tensor.
     """
     # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
     if group_or_device_mesh is None:
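With the new inplace flag defaulting to False, shard_rowwise no longer hands the caller's tensor straight to DTensor; it detaches and clones first, so the DTensor owns fresh storage and the original tensor is left untouched. A usage sketch of the new call surface, assuming an initialized distributed context; the shapes and variable names are illustrative, not from the commit.

    import torch
    from colossalai.tensor.d_tensor.api import shard_rowwise

    weight = torch.randn(8, 4)

    # Default path (inplace=False): the function shards a detached clone,
    # so `weight` itself is not consumed by the DTensor.
    d_weight = shard_rowwise(weight)

    # Old behavior, now opt-in: wrap the original tensor's storage directly.
    d_weight_inplace = shard_rowwise(weight, inplace=True)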
@@ -41,4 +69,8 @@ def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup
         assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for column-wise sharding.'
         device_mesh = group_or_device_mesh
     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
+
+    if not inplace:
+        tensor = tensor.detach().clone()
+
     return DTensor(tensor, device_mesh, sharding_spec)
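The only functional difference between the two helpers is the dim_partition_dict handed to ShardingSpec: {0: [0]} splits dim 0 across mesh axis 0 (row-wise), while {-1: [0]} splits the last dim (column-wise). A single-process sketch of the resulting local shard shapes, using torch.chunk as a stand-in for the mesh partitioning:

    import torch

    world_size = 4
    weight = torch.randn(8, 4)

    # shard_rowwise: dim_partition_dict={0: [0]} -> split dim 0 across the mesh.
    row_shards = torch.chunk(weight, world_size, dim=0)
    print([tuple(s.shape) for s in row_shards])  # [(2, 4), (2, 4), (2, 4), (2, 4)]

    # shard_colwise: dim_partition_dict={-1: [0]} -> split the last dim instead.
    col_shards = torch.chunk(weight, world_size, dim=-1)
    print([tuple(s.shape) for s in col_shards])  # [(8, 1), (8, 1), (8, 1), (8, 1)]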