OpenDAS / ColossalAI · Commits

Commit bcc86550 (Unverified)
Authored Apr 24, 2022 by Ziyue Jiang; committed by GitHub on Apr 24, 2022

[Tensor ] Add 1Drow weight reshard by spec (#854)
parent d7e0303d
Showing 5 changed files with 41 additions and 11 deletions (+41 -11)
colossalai/tensor/_ops/linear.py        +6  -2
colossalai/tensor/colo_tensor.py        +26 -1
tests/components_to_test/simple_net.py  +4  -3
tests/test_tensor/test_linear_tp.py     +2  -2
tests/test_tensor/test_net_tp.py        +3  -3
colossalai/tensor/_ops/linear.py

@@ -6,6 +6,7 @@ from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward
 from colossalai.nn.layer.utils import divide
 from colossalai.core import global_context as gpc
 from packaging import version
+from colossalai.utils.cuda import get_current_device

 @colo_op_impl(torch.nn.functional.linear)
 def colo_linear(types, args, kwargs, pg):
...
@@ -39,12 +40,15 @@ def colo_linear(types, args, kwargs, pg):
         # Input:S[1]
         input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1)
         # Output:P
-        partial_output = torch.nn.functional.linear(input_per_partition, weight.torch_tensor())
+        device = get_current_device()  # TODO where to put to(deivce)?
+        weight_ = weight.torch_tensor().to(device)
+        partial_output = torch.nn.functional.linear(input_per_partition, weight_)
         # Reduce(Output)
         output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
         # Bias
         if bias is not None:
-            output = output + bias
+            bias_ = bias.to(device)
+            output = output + bias_
             return output
         else:
...
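For reference, the change above is part of a 1D row-parallel linear layer: the input is split along its last dimension, each rank multiplies its input slice by its (row-sharded) weight slice, and the partial outputs are summed by reduce_input's all-reduce. A minimal single-process sketch of why the partial products add up to the full result (plain PyTorch only; world_size and the explicit chunking below are illustrative stand-ins, not the Colossal-AI API):

import torch

torch.manual_seed(0)
world_size = 2                       # pretend tensor-parallel world size
x = torch.randn(16, 8)               # full input
w = torch.randn(4, 8)                # full weight, shape (out_features, in_features)
b = torch.randn(4)

# Shard input and weight along the in_features dimension (dim=-1), one chunk per "rank".
x_chunks = x.chunk(world_size, dim=-1)
w_chunks = w.chunk(world_size, dim=-1)

# Each rank computes a partial output; the sum stands in for reduce_input's all-reduce.
partial = [torch.nn.functional.linear(xc, wc) for xc, wc in zip(x_chunks, w_chunks)]
output = sum(partial) + b

assert torch.allclose(output, torch.nn.functional.linear(x, w, b), atol=1e-5)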
colossalai/tensor/colo_tensor.py

@@ -3,7 +3,10 @@ from .op_wrapper import _COLOSSAL_OPS
 import torch
 from typing import Tuple, Optional
 from numpy import product
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+from colossalai.nn.layer.utils import divide
+from colossalai.utils.cuda import get_current_device

 class ColoTensor(object):
     """ Data Structure for Tensor in Colossal-AI
...
@@ -85,6 +88,28 @@ class ColoTensor(object):
                                                device=self._device)
         return self._torch_tensor

+    def set_spec(self, spec: str, lazy_shard: bool = False) -> None:
+        self._shard_spec = spec
+        if lazy_shard == False:
+            self._shard()
+
+    def _shard(self):
+        assert self._shard_spec is not None, 'You should call set_spec() before _shard() ColoTensor.'
+        if self._shard_spec == "1Drow":  # TODO It actually represents the sharding layout for Linear-1Drow-weight, but we make it simpler now.
+            num_partition = gpc.get_world_size(ParallelMode.TENSOR)
+            local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+            dim = -1
+            chunk_size = divide(self._size[dim], num_partition)
+            device = get_current_device()
+            # Reshape to get shard for this rank and we don't want autograd
+            # recording here for the narrow op and 'local_shard' should be a
+            # leaf variable in the autograd graph.
+            self._torch_tensor = self._torch_tensor.narrow(dim, local_rank * chunk_size, chunk_size).detach().contiguous()  # TODO Shall we clone() here since detach() will point to the old tensor?
+            self._torch_tensor.requires_grad = self._requires_grad
+            self._size = self._torch_tensor.size()
+            self._device = device  # TODO A `fake` device now because torch_tensor.device always = cpu
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         global _COLOSSAL_OPS
...
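The new _shard() keeps only the local rank's slice of the tensor's last dimension: with num_partition ranks, rank r retains columns [r * chunk_size, (r + 1) * chunk_size). A standalone sketch of that narrow/detach step, assuming a hard-coded rank and world size instead of querying gpc:

import torch

full_weight = torch.randn(4, 8, requires_grad=True)    # e.g. (out_features, in_features)
num_partition, local_rank = 2, 1                        # stand-ins for gpc.get_world_size / get_local_rank
dim = -1
chunk_size = full_weight.size(dim) // num_partition     # divide() additionally checks divisibility

# Keep this rank's contiguous slice of the last dimension, detached so the
# local shard becomes a leaf tensor outside the original autograd graph.
local_shard = full_weight.narrow(dim, local_rank * chunk_size, chunk_size).detach().contiguous()
local_shard.requires_grad = True

print(local_shard.shape)    # torch.Size([4, 4]) -- columns 4..7 of the full weight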
tests/components_to_test/simple_net.py

+from zmq import device
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from colossalai.nn import CheckpointModule
 from .utils.dummy_data_generator import DummyDataGenerator
 from .registry import non_distributed_component_funcs
+from colossalai.utils.cuda import get_current_device

 class SimpleNet(CheckpointModule):
     """
...
@@ -25,8 +26,8 @@ class SimpleNet(CheckpointModule):
 class DummyDataLoader(DummyDataGenerator):

     def generate(self):
-        data = torch.rand(16, 4)
-        label = torch.randint(low=0, high=2, size=(16,))
+        data = torch.rand(16, 4, device=get_current_device())
+        label = torch.randint(low=0, high=2, size=(16,), device=get_current_device())
         return data, label
...
tests/test_tensor/test_linear_tp.py

@@ -35,7 +35,7 @@ def run_linear_tp1d_row_test():
     W_shape = (out_features, in_features)
     W_master = torch.randn(W_shape, dtype=dtype, device=device)
-    W = broadcast_tensor_chunk(W_master, chunk_size=DEPTH, local_rank=local_rank)
+    W = broadcast_tensor_chunk(W_master, chunk_size=1)
     W.requires_grad = True
     B_shape = (out_features)
...
@@ -45,7 +45,7 @@ def run_linear_tp1d_row_test():
     # replace the torch nn.Parameters with ColoTensor
     sharded_weight = ColoTensor.init_from_torch_tensor(W)
-    sharded_weight._shard_spec = "1Drow"
+    sharded_weight.set_spec(spec="1Drow")    # reshard
     sharded_bias = ColoTensor.init_from_torch_tensor(B)
     replace_parameter_add_grad(layer, sharded_weight, sharded_bias)
     out = layer(A)
...
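Note the behavioral difference in this test: the old line only attached the spec string to the ColoTensor, whereas set_spec() also performs the reshard right away unless lazy_shard=True. A short, hypothetical illustration of the two call styles (assumes an initialized tensor-parallel group, as in the test):

w = ColoTensor.init_from_torch_tensor(torch.randn(4, 8))
w.set_spec(spec="1Drow")                      # eager: _shard() runs now and narrows the last dim
# w.set_spec(spec="1Drow", lazy_shard=True)   # deferred: only records the spec; _shard() is not called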
tests/test_tensor/test_net_tp.py

@@ -23,9 +23,9 @@ def run_simple_net():
     with ColoInitContext():
         model = model_builder(checkpoint=True)

-    # TODO(jzy) we set the Specs for weight of each linear.
-    # model.proj1.weight.set_spec('1Drow')
-    # model.proj2.weight.set_spec('1Drow')
+    # we set the Specs for weight of each linear.
+    model.proj1.weight.set_spec('1Drow')
+    model.proj2.weight.set_spec('1Drow')

     for i, (data, label) in enumerate(train_dataloader):
         output = model(data)
...