Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
2a0a427e
Unverified
Commit
2a0a427e
authored
Apr 24, 2022
by
Ziyue Jiang
Committed by
GitHub
Apr 24, 2022
Browse files
[tensor]add assert for colo_tensor 1Drow (#846)
parent
05023ecf
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
9 additions
and
7 deletions
+9
-7
colossalai/tensor/_ops/linear.py
colossalai/tensor/_ops/linear.py
+7
-4
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+1
-2
tests/test_tensor/_utils/__init__.py
tests/test_tensor/_utils/__init__.py
+0
-0
tests/test_tensor/_utils/_util.py
tests/test_tensor/_utils/_util.py
+0
-0
tests/test_tensor/test_linear_tp.py
tests/test_tensor/test_linear_tp.py
+1
-1
No files found.
colossalai/tensor/_ops/linear.py
View file @
2a0a427e
...
@@ -3,6 +3,8 @@ from colossalai.tensor.op_wrapper import colo_op_impl
...
@@ -3,6 +3,8 @@ from colossalai.tensor.op_wrapper import colo_op_impl
from
colossalai.tensor.colo_tensor
import
ColoTensor
from
colossalai.tensor.colo_tensor
import
ColoTensor
from
colossalai.context
import
ParallelMode
from
colossalai.context
import
ParallelMode
from
colossalai.nn.layer.parallel_1d._utils
import
split_forward_gather_backward
,
reduce_input
from
colossalai.nn.layer.parallel_1d._utils
import
split_forward_gather_backward
,
reduce_input
from
colossalai.nn.layer.utils
import
divide
from
colossalai.core
import
global_context
as
gpc
from
packaging
import
version
from
packaging
import
version
@
colo_op_impl
(
torch
.
nn
.
functional
.
linear
)
@
colo_op_impl
(
torch
.
nn
.
functional
.
linear
)
...
@@ -29,10 +31,11 @@ def colo_linear(types, args, kwargs, pg):
...
@@ -29,10 +31,11 @@ def colo_linear(types, args, kwargs, pg):
if
weight
.
shard_spec
==
None
:
if
weight
.
shard_spec
==
None
:
return
torch
.
nn
.
functional
.
linear
(
input_tensor
,
weight
.
torch_tensor
(),
bias
)
return
torch
.
nn
.
functional
.
linear
(
input_tensor
,
weight
.
torch_tensor
(),
bias
)
elif
weight
.
shard_spec
==
'1Drow'
:
elif
weight
.
shard_spec
==
'1Drow'
:
"""
# Input:S[1] x Weight:S[0] = Output:P
Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
All-Reduce(Output) + bias = res
assert
divide
(
input_tensor
.
shape
[
-
1
],
gpc
.
tensor_parallel_size
)
==
weight
.
size
[
-
1
],
\
"""
'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'
.
format
(
input_tensor
.
shape
,
weight
.
size
,
weight
.
size
[
-
1
]
*
gpc
.
tensor_parallel_size
)
# Input:S[1]
# Input:S[1]
input_per_partition
=
split_forward_gather_backward
(
input_tensor
,
ParallelMode
.
PARALLEL_1D
,
dim
=-
1
)
input_per_partition
=
split_forward_gather_backward
(
input_tensor
,
ParallelMode
.
PARALLEL_1D
,
dim
=-
1
)
# Output:P
# Output:P
...
...
colossalai/tensor/colo_tensor.py
View file @
2a0a427e
from
numpy
import
product
from
numpy
import
product
import
torch
import
torch
from
typing
import
Tuple
from
typing
import
Tuple
,
Optional
import
numpy
from
.op_wrapper
import
_COLOSSAL_OPS
from
.op_wrapper
import
_COLOSSAL_OPS
class
ColoTensor
(
object
):
class
ColoTensor
(
object
):
...
...
tests/test_tensor/
test_tensor
_utils/__init__.py
→
tests/test_tensor/_utils/__init__.py
View file @
2a0a427e
File moved
tests/test_tensor/
test_tensor
_utils/_util.py
→
tests/test_tensor/_utils/_util.py
View file @
2a0a427e
File moved
tests/test_tensor/test_linear_tp.py
View file @
2a0a427e
...
@@ -14,7 +14,7 @@ from colossalai.utils import free_port
...
@@ -14,7 +14,7 @@ from colossalai.utils import free_port
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
test_tensor
_utils
import
check_equal
,
replace_parameter_add_grad
,
broadcast_tensor_chunk
from
_utils
import
check_equal
,
replace_parameter_add_grad
,
broadcast_tensor_chunk
def
run_linear_tp1d_row_test
():
def
run_linear_tp1d_row_test
():
device
=
get_current_device
()
device
=
get_current_device
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment