OpenDAS / ColossalAI / Commits / cb5a4778

Unverified commit cb5a4778, authored Apr 22, 2022 by Jiarui Fang, committed by GitHub on Apr 22, 2022.
Revert "[WIP] Applying ColoTensor on TP-1D-row Linear. (#831)" (#835)
This reverts commit ac88de6d.

Parent: 5e00e6cf
Showing 3 changed files, with 8 additions and 101 deletions (+8 −101):

    colossalai/tensor/_ops/linear.py        +2   −8
    colossalai/tensor/colo_tensor.py        +6   −19
    tests/test_tensor/test_linear_tp.py     +0   −74
colossalai/tensor/_ops/linear.py
@@ -19,18 +19,12 @@ def colo_linear(types, args, kwargs, pg):
             bias = None
     else:
         bias = kwargs.get('bias', None)

     if isinstance(bias, ColoTensor):
         bias = bias.torch_tensor()

     # Add communication logic before and after linear call.
     if isinstance(weight, ColoTensor):
-        if weight.shard_spec == None:
-            return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
-        elif weight.shard_spec == '1Drow':
-            # TODO(jzy): implement 1Drow TP linear here.
-            raise NotImplementedError
-        else:
-            raise NotImplementedError
+        return torch.nn.functional.linear(input_tensor, weight.torch_tensor(), bias)
     else:
         return torch.nn.functional.linear(input_tensor, weight, bias)
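For context, the removed '1Drow' branch pointed at a row-parallel tensor-parallel linear that had not been implemented yet. The following is a minimal sketch of that standard scheme, not ColossalAI's code and not part of this commit: each rank holds a slice of the input features and the matching slice of the weight, computes a partial matmul, and all-reduces the partial outputs. It assumes torch.distributed has already been initialised with a single tensor-parallel group.

    # Illustrative sketch of a 1D row-parallel linear (assumption: not the implementation
    # the removed TODO referred to). Weight follows torch's (out_features, in_features) layout.
    import torch
    import torch.distributed as dist


    def row_parallel_linear(input_shard: torch.Tensor,
                            weight_shard: torch.Tensor,
                            bias: torch.Tensor = None) -> torch.Tensor:
        # input_shard:  (batch, in_features // world_size), this rank's slice of the input
        # weight_shard: (out_features, in_features // world_size), this rank's slice of the weight
        partial = torch.nn.functional.linear(input_shard, weight_shard)
        # Sum the partial outputs across the tensor-parallel group to recover the full result.
        dist.all_reduce(partial, op=dist.ReduceOp.SUM)
        # Add the bias once, after the reduction, so it is not counted world_size times.
        if bias is not None:
            partial = partial + bias
        return partial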
colossalai/tensor/colo_tensor.py
 import torch
 from .op_wrapper import _COLOSSAL_OPS
-from typing import Tuple, Optional
+from typing import Tuple


 class ColoTensor(object):
@@ -21,35 +21,20 @@ class ColoTensor(object):
                  requires_grad=False,
                  pin_memory=False,
                  torch_tensor=torch.empty(0),
-                 shard_spec: str = None,
                  ):
         self._size = size
         self._dtype = dtype
         self._requires_grad = requires_grad
         self._pin_memory = pin_memory
         self._torch_tensor = torch_tensor
-        self._shard_spec = shard_spec
-
-    @property
-    def shard_spec(self) -> Optional[str]:
-        return self._shard_spec

     @property
     def data(self):
         return self._torch_tensor.data

     @property
     def grad(self):
         return self._torch_tensor.grad

     @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor, shard_spec: str = None) -> 'ColoTensor':
+    def init_from_torch_tensor(tensor: torch.Tensor):
         colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
                             requires_grad=tensor.requires_grad,
                             pin_memory=tensor.pin_memory,
-                            torch_tensor=tensor,
-                            shard_spec=shard_spec)
+                            torch_tensor=tensor)
         return colo_t

     def del_torch_tensor(self) -> None:
@@ -82,5 +67,7 @@ class ColoTensor(object):
         if kwargs is None:
             kwargs = {}

-        kwargs = {k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()}
+        kwargs = {
+            k: v.torch_tensor() if isinstance(v, ColoTensor) else v for k, v in kwargs.items()
+        }
         return func(*args, **kwargs)
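After the revert, ColoTensor goes back to being a thin wrapper around a torch.Tensor with no shard_spec. A small usage sketch, assuming only what the diffs above show (init_from_torch_tensor, torch_tensor(), and the dispatch through _COLOSSAL_OPS that routes torch.nn.functional.linear to colo_linear); the dispatch path is an assumption based on the first file, not something this commit demonstrates:

    # Usage sketch under the restored API (assumes the __torch_function__-style dispatch
    # registered in _COLOSSAL_OPS handles torch.nn.functional.linear, as colo_linear suggests).
    import torch
    from colossalai.tensor import ColoTensor

    weight = torch.randn(5, 4, requires_grad=True)
    colo_weight = ColoTensor.init_from_torch_tensor(weight)  # wrap, keep a handle to the raw tensor

    x = torch.randn(2, 4)
    # colo_linear unwraps the wrapper via weight.torch_tensor() before calling the dense op,
    # so with the shard_spec logic reverted this behaves like a plain F.linear call.
    y = torch.nn.functional.linear(x, colo_weight)
    assert y.shape == (2, 5)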
tests/test_tensor/test_linear_tp.py (deleted, 100644 → 0)
from joblib import Parallel
from numpy import allclose, require
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor
from copy import deepcopy
from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc


def run_linear_tp1d_row_test():
    in_dim = 4
    out_dim = 5

    fc = torch.nn.Linear(in_dim, out_dim, bias=True)
    fc_ref = deepcopy(fc)

    input_ref = torch.randn(1, in_dim)
    input_tensor = input_ref.clone()

    # sharded_weight = ColoTensor.init_from_torch_tensor(fc_ref.weight, "1Drow")
    # shard weight at begiin
    world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
    sharded_weight = ColoTensor(in_dim / world_size, out_dim, shard_spec="1Drow")
    sharded_bias = ColoTensor.init_from_torch_tensor(fc_ref.bias)

    # replace the torch nn.Parameters with ShardedTensor
    delattr(fc, 'weight')
    setattr(fc, 'weight', sharded_weight)
    delattr(fc, 'bias')
    setattr(fc, 'bias', sharded_bias)

    fc.weight.requires_grad = True
    fc.bias.requires_grad = True

    # torch.nn.functional.linear(torch.randn(1, in_dim), sharded_weight, sharded_bias)
    out = fc(input_tensor)
    loss = out.sum()
    loss.backward()

    out_ref = fc_ref(input_ref)
    loss_ref = out_ref.sum()
    loss_ref.backward()

    assert (loss_ref == loss)
    assert allclose(fc_ref.weight.grad, fc.weight.torch_tensor().grad)


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_linear_tp1d_row_test()


@pytest.mark.dist
@pytest.mark.parametrize("world_size", [4])
@rerun_if_address_is_in_use()
def test_linear_1d(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_linear_1d(4)
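One detail of the deleted test worth noting is the delattr/setattr pair used to swap the module's parameters for ColoTensor wrappers: nn.Module.__setattr__ refuses to overwrite a registered nn.Parameter with a non-Parameter object, so the attribute has to be unregistered first. A standalone sketch of the same pattern, using a plain tensor instead of a ColoTensor so it runs with stock PyTorch:

    # Illustration of the parameter-replacement pattern from the deleted test.
    import torch

    fc = torch.nn.Linear(4, 5, bias=True)
    plain_weight = fc.weight.detach().clone()

    delattr(fc, 'weight')                 # unregister the nn.Parameter
    setattr(fc, 'weight', plain_weight)   # re-attach an arbitrary object under the same name

    # 'weight' is now an ordinary attribute, no longer tracked by the module,
    # but the forward pass still works because F.linear accepts a plain tensor.
    assert 'weight' not in dict(fc.named_parameters())
    assert fc(torch.randn(2, 4)).shape == (2, 5)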