OpenDAS / ColossalAI / Commits / db1bef90

Unverified commit db1bef90, authored Jul 07, 2022 by Jiarui Fang; committed via GitHub on Jul 07, 2022.
[hotfix] fx shard 1d pass bug fixing (#1220)
Parent: 11973d89

Showing 1 changed file with 8 additions and 13 deletions (+8, -13).
colossalai/fx/passes/shard_1d_pass.py @ db1bef90
 import torch
 from torch.fx.node import map_arg
 from torch.fx.node import Node
 from torch.fx.passes.split_module import split_module
-import colossalai
-from colossalai.tensor import ColoTensor, TensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
+from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern

-def weight_split(weight: torch.nn.parameter.Parameter, dim: int) -> torch.nn.parameter.Parameter:
+def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
     """weight_split
     split a nn.Parameter
...
@@ -18,22 +14,20 @@ def weight_split(weight: torch.nn.parameter.Parameter, dim: int) -> torch.nn.par
     Returns:
         _type_: _description_
     """
     #TODO: This func temporarily works with no materialization
     # Append a Tensor spec to target_module.weight.shard
     # Convert to ColoTensor: colo_tensor = ColoTensor.from_torch_tensor(tensor, spec)
-    # assert isinstance(weight, torch.nn.parameter.Parameter), \
-    #     f'The type of the input tensor should be torch.nn.parameter' \
-    #     f'Your Input tensor is {type(weight)}'
+    assert isinstance(weight, torch.Tensor), \
+        f'The type of the input tensor should be torch.nn.parameter' \
+        f'Your Input tensor is {type(weight)}'
     # FIXME() I initialized a PG for this tensor. Only has TP comm group.
     # we only consider the TP-only caes.
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
-    spec = TensorSpec(distspec.shard(pg, [dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    # As you has constructed a Spec, why not directly convert the tensor to ColoTensor.
-    # setattr(weight, "fx_attr", spec)
-    weight.data = ColoTensor(data=weight.data, spec=spec)
+    spec = ColoTensorSpec(pg, distspec.shard([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    setattr(weight, "fx_attr", spec)
     return weight
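The core of the fix is visible in this hunk: instead of eagerly rewrapping weight.data as a ColoTensor, weight_split now only attaches a ColoTensorSpec to the parameter under the fx_attr attribute and defers materialization (per the "works with no materialization" TODO above). Below is a minimal, self-contained sketch of that annotate-then-materialize pattern in plain PyTorch; ShardSpec, weight_split_sketch, and materialize are hypothetical stand-ins invented for illustration, not ColossalAI APIs, since the real ColoTensorSpec/ProcessGroup path requires an initialized torch.distributed process group.

    import torch

    # Hypothetical stand-in for colossalai.tensor.ColoTensorSpec (assumption:
    # invented here to keep the sketch free of distributed initialization).
    class ShardSpec:
        def __init__(self, dim: int, num_partitions: int):
            self.dim = dim
            self.num_partitions = num_partitions

    def weight_split_sketch(weight: torch.Tensor, dim: int) -> torch.Tensor:
        # Mirrors the fixed pass: annotate the tensor with a spec rather than
        # rewrapping weight.data (the behavior this commit removes).
        assert isinstance(weight, torch.Tensor)
        weight.fx_attr = ShardSpec(dim, num_partitions=2)
        return weight

    def materialize(weight: torch.Tensor, rank: int) -> torch.Tensor:
        # A later (hypothetical) pass reads fx_attr and takes this rank's slice.
        spec = weight.fx_attr
        return torch.chunk(weight, spec.num_partitions, dim=spec.dim)[rank]

    w = torch.nn.Linear(8, 4).weight      # shape (4, 8)
    w = weight_split_sketch(w, dim=-1)
    print(materialize(w, rank=0).shape)   # torch.Size([4, 4])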
...
@@ -58,6 +52,7 @@ def row_shard_linear_pass(gm: torch.fx.GraphModule):
         target_module = node.graph.owning_module.get_submodule(node.target)
         if isinstance(target_module, torch.nn.Linear):
             target_module.weight = weight_split(target_module.weight, dim=-1)
     gm.recompile()
     return gm
...
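For context, the row_shard_linear_pass hunk above shows the shape of the pass: walk the traced graph, resolve each call_module node to its submodule, tag nn.Linear weights for sharding along the last dimension, and recompile. A hedged usage sketch with plain torch.fx follows; the TwoLayer module and the dict standing in for a real ColoTensorSpec are assumptions made for illustration, not part of the commit.

    import torch
    import torch.fx

    class TwoLayer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(8, 16)
            self.fc2 = torch.nn.Linear(16, 4)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    gm = torch.fx.symbolic_trace(TwoLayer())

    # Walk the traced graph the same way the pass does: find call_module
    # nodes whose target resolves to an nn.Linear and annotate their weights.
    for node in gm.graph.nodes:
        if node.op == 'call_module':
            submodule = gm.get_submodule(node.target)
            if isinstance(submodule, torch.nn.Linear):
                # Stand-in for weight_split: record the shard dim instead of a
                # real ColoTensorSpec (which needs torch.distributed initialized).
                setattr(submodule.weight, 'fx_attr', {'shard_dim': -1})

    gm.recompile()
    x = torch.randn(2, 8)
    print(gm(x).shape)            # forward is unchanged; weights are only annotated
    print(gm.fc1.weight.fx_attr)  # the tag travels with the parameter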