OpenDAS / ColossalAI · Commits

Commit 29159d9b (unverified)
Authored Apr 25, 2022 by Jiarui Fang; committed via GitHub, Apr 25, 2022
Parent: 1258af71

hotfix tensor unittest bugs (#862)

Showing 1 changed file with 4 additions and 3 deletions.
colossalai/tensor/_ops/linear.py (+4, -3)

@@ -8,6 +8,7 @@ from colossalai.core import global_context as gpc
 from packaging import version
 from colossalai.utils.cuda import get_current_device
 
+
 @colo_op_impl(torch.nn.functional.linear)
 def colo_linear(types, args, kwargs, pg):
     """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
@@ -34,13 +35,13 @@ def colo_linear(types, args, kwargs, pg):
     elif weight.shard_spec == '1Drow':
         # Input:S[1] x Weight:S[0] = Output:P
         # All-Reduce(Output) + bias = res
-        assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size[-1], \
+        assert divide(input_tensor.shape[-1], gpc.tensor_parallel_size) == weight.size(-1), \
             'Invalid shapes in 1Drow forward: input={}, weight={}. Expected last dim of input {}.'.format(
             input_tensor.shape, weight.size, weight.size[-1] * gpc.tensor_parallel_size)
         # Input:S[1]
         input_per_partition = split_forward_gather_backward(input_tensor, ParallelMode.PARALLEL_1D, dim=-1)
         # Output:P
         device = get_current_device()  # TODO where to put to(deivce)?
         weight_ = weight.torch_tensor().to(device)
         partial_output = torch.nn.functional.linear(input_per_partition, weight_)
         # Reduce(Output)
@@ -50,7 +51,7 @@ def colo_linear(types, args, kwargs, pg):
             bias_ = bias.to(device)
             output = output + bias_
             return output
         else:
             raise NotImplementedError
     else:
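The substantive fix is the change from weight.size[-1] to weight.size(-1). In PyTorch, Tensor.size is a bound method that takes an optional dimension index, while Tensor.shape is a torch.Size that supports subscripting; subscripting the method itself raises a TypeError at runtime, which is what the failing unit tests hit. A minimal standalone illustration in plain PyTorch:

    import torch

    w = torch.randn(5, 8)

    print(w.size(-1))   # 8 -- Tensor.size is a method taking an optional dim argument
    print(w.shape[-1])  # 8 -- Tensor.shape is a torch.Size and supports indexing

    # The pre-fix spelling fails at runtime:
    #   w.size[-1]  ->  TypeError: 'builtin_function_or_method' object is not subscriptable

Note that the assertion's message still formats weight.size and weight.size[-1]; that expression is only evaluated when the assertion condition is already false (Python evaluates the second operand of assert lazily), so it does not affect passing tests, though it would raise a TypeError instead of an AssertionError on a genuine shape mismatch.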
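For context on what the 1Drow branch computes: the comments "Input:S[1] x Weight:S[0] = Output:P" and "All-Reduce(Output) + bias = res" describe a row-parallel linear layer. The input is split along its last (feature) dimension, each rank holds the matching slice of the weight, every rank computes a partial product, and an all-reduce sums the partials into the full output before the bias is added. The fixed assertion checks exactly this: the full input's last dimension, divided by the tensor-parallel size, must equal the local weight shard's last dimension. A single-process sketch of that arithmetic, with chunk standing in for the sharding and a plain sum standing in for the all-reduce (the names below are illustrative, not ColossalAI API):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    tp_size = 4                # hypothetical tensor-parallel world size
    x = torch.randn(2, 8)      # full input; last dim divisible by tp_size
    w = torch.randn(5, 8)      # full weight, shape (out_features, in_features)

    # Sharding the input along its last dim and the weight along its
    # in_features dim both split the contracted dimension, so rank k
    # would hold x_shards[k] and w_shards[k].
    x_shards = x.chunk(tp_size, dim=-1)
    w_shards = w.chunk(tp_size, dim=-1)

    # Each rank computes a partial output (Output:P in the comments)...
    partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
    # ...and the all-reduce sums the partials into the replicated result.
    output = torch.stack(partials).sum(dim=0)

    assert torch.allclose(output, F.linear(x, w), atol=1e-5)

Because the partial products are summed across ranks, each rank's slice of the contracted dimension must line up with the others, which is why the shape assertion fixed in this commit matters in the first place.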