OpenDAS / ColossalAI / Commits / 7f76517a

Unverified commit 7f76517a, authored Apr 26, 2022 by Jiarui Fang, committed via GitHub on Apr 26, 2022

[Tensor] make a simple net works with 1D row TP (#879)

parent c4d903e6
Showing 2 changed files, with 36 additions and 5 deletions:

- colossalai/tensor/colo_tensor.py (+16, −0)
- tests/test_tensor/test_net_tp.py (+20, −5)
colossalai/tensor/colo_tensor.py
@@ -157,5 +157,21 @@ class ColoTensor(object):

    def backward(self, gradient: Optional[torch.Tensor] = None, retain_graph: bool = False):
        self._torch_tensor.backward(gradient=gradient, retain_graph=retain_graph)

    # TODO(fjr): reduce the redundancy of the following code.
    # Each op below unwraps the underlying torch.Tensor, applies the torch op,
    # and re-wraps the result as a new ColoTensor.
    def __add__(self, o) -> "ColoTensor":
        return ColoTensor.init_from_torch_tensor(self.torch_tensor() + o.torch_tensor())

    def __truediv__(self, o) -> "ColoTensor":
        return ColoTensor.init_from_torch_tensor(self.torch_tensor() / o)

    def view(self, *args: int) -> "ColoTensor":
        return ColoTensor.init_from_torch_tensor(self.torch_tensor().view(*args))

    def permute(self, *args) -> "ColoTensor":
        return ColoTensor.init_from_torch_tensor(self.torch_tensor().permute(*args))

    def transpose(self, *args) -> "ColoTensor":
        return ColoTensor.init_from_torch_tensor(self.torch_tensor().transpose(*args))

    def contiguous(self):
        return ColoTensor.init_from_torch_tensor(self.torch_tensor().contiguous())
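Taken together, these methods let a ColoTensor stand in for a plain torch.Tensor in simple arithmetic and shape manipulation: each call unwraps the payload, runs the torch op, and wraps the result back up. A minimal usage sketch against this commit's API (assuming init_from_torch_tensor simply wraps the given tensor, as the diff suggests, so the autograd graph is preserved across ops):

    import torch
    from colossalai.tensor import ColoTensor

    x = torch.randn(4, 8, requires_grad=True)
    t = ColoTensor.init_from_torch_tensor(x)       # wrap a plain tensor

    u = t.view(8, 4).transpose(0, 1).contiguous()  # each op returns a new ColoTensor
    s = (t + t) / 2.0                              # __add__ / __truediv__ delegate too

    s.backward(gradient=torch.ones(4, 8))          # forwarded to the wrapped tensor
    print(x.grad)                                  # gradients land on the raw tensor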
tests/test_tensor/test_net_tp.py
@@ -7,7 +7,8 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils import ColoInitContext
from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor
from colossalai.context import ParallelMode
from functools import partial
@@ -20,18 +21,32 @@ def run_simple_net():

    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)

    parallel_action_list = [
        ParallelAction(priority=1,
                       compute_pattern=ComputePattern.TP1DRow,
                       parallel_mode=ParallelMode.PARALLEL_1D)
    ]
    spec = TensorSpec(parallel_action_list)

    # A naive way to set the spec for all Linear weights: skip LayerNorm and
    # embedding parameters, shard everything else.
    for name, p in named_params_with_colotensor(model):
        if not isinstance(p, ColoTensor):
            continue
        if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
            p.set_spec(spec)

    model.cuda()

    for param in named_params_with_colotensor(model):
        print(param)

    # We set the spec for the weight of each Linear, e.g.:
    # model.proj1.weight.set_spec('1Drow')
    # model.proj2.weight.set_spec('1Drow')
    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output

        print(loss.torch_tensor())
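For context, ComputePattern.TP1DRow names 1D row tensor parallelism: a Linear weight is sharded along its input dimension, each rank multiplies its slice of the input by its weight shard, and an all-reduce over ParallelMode.PARALLEL_1D sums the partial outputs. A single-process illustration of that compute pattern (plain PyTorch, not ColossalAI internals; the function name and simulated ranks are hypothetical):

    import torch

    def row_parallel_linear(x, weight, rank, world_size):
        # weight: (out_features, in_features), sharded along in_features;
        # each rank computes a partial matmul that an all-reduce would sum
        # in a real multi-process run.
        w_shard = torch.chunk(weight, world_size, dim=1)[rank]  # this rank's weight shard
        x_shard = torch.chunk(x, world_size, dim=-1)[rank]      # matching input slice
        return x_shard @ w_shard.t()                            # partial output

    # Summing the per-rank partials recovers the full matmul.
    x, w = torch.randn(2, 8), torch.randn(4, 8)
    full = x @ w.t()
    partials = sum(row_parallel_linear(x, w, r, 2) for r in range(2))
    assert torch.allclose(full, partials, atol=1e-5)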