OpenDAS / ColossalAI · Commit 126cf180 (unverified)
Authored Nov 28, 2023 by アマデウス; committed via GitHub on Nov 28, 2023
Parent: 7b789f4d

[hotfix] fixed memory usage of shardformer module replacement (#5122)
Changes: 2 changed files with 6 additions and 6 deletions

  colossalai/shardformer/layer/_operation.py   +5 -5
  colossalai/tensor/d_tensor/comm_spec.py      +1 -1
colossalai/shardformer/layer/_operation.py

...
@@ -473,16 +473,17 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         return _split(grad_output, ctx.dim, ctx.process_group), None, None
 
 
 class HookParameter(torch.autograd.Function):
     """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
 
     @staticmethod
     def forward(ctx, input, weight, bias):
         ctx.save_for_backward(weight, bias)
         output = input
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
         weight, bias = ctx.saved_tensors
...
@@ -491,13 +492,12 @@ class HookParameter(torch.autograd.Function):
         if bias is not None:
             bias = bias.view(bias.shape)
         return grad_output, None, None
 
 
 def hook_paramter_in_backward(input, weight=None, bias=None):
     return HookParameter.apply(input, weight, bias)
 
 
 def _reduce(input_, process_group):
     # skip if only one rank involved
     if dist.get_world_size(process_group) == 1:
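For context, hook_paramter_in_backward() is a pass-through on its input: HookParameter.backward() re-views the saved weight and bias so that Gemini's '__torch_function__' can intercept them during the backward pass (as the docstring above notes, this is used in FusedLayerNorm). The sketch below is illustrative only and not part of the commit; MyFusedLayerNorm is a hypothetical stand-in for the real fused module, and plain F.layer_norm is used only to keep it runnable.

import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.shardformer.layer._operation import hook_paramter_in_backward


class MyFusedLayerNorm(nn.Module):
    """Hypothetical stand-in for a fused layer norm; illustrative only."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.bias = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pass-through in forward; in backward, HookParameter re-views weight
        # and bias so Gemini's '__torch_function__' sees them being used.
        x = hook_paramter_in_backward(x, self.weight, self.bias)
        # A real fused kernel would consume x, weight and bias here; plain
        # layer_norm keeps this sketch runnable.
        return F.layer_norm(x, (x.shape[-1],), self.weight, self.bias)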
...
@@ -522,7 +522,7 @@ def _split(input_, dim=-1, process_group=None):
     tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
     rank = dist.get_rank(process_group)
-    output = tensor_list[rank].contiguous()
+    output = tensor_list[rank].clone().contiguous()
     return output
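This changed line is the heart of the hotfix. torch.split() returns views of its input, and .contiguous() on a chunk that is already contiguous is a no-op, so the kept shard keeps referencing the full tensor's storage and the whole unsharded tensor stays resident. Adding .clone() materialises only the local shard, letting the parent buffer be freed. A minimal sketch showing the difference, not part of the commit and assuming a recent PyTorch (untyped_storage() is available from 2.0):

import torch

full = torch.randn(4, 1024)                    # stand-in for the unsharded tensor
chunk = torch.split(full, 1, dim=0)[0]         # "rank 0" shard, still a view of `full`

view_shard = chunk.contiguous()                # already contiguous: same view, same storage
copied_shard = chunk.clone().contiguous()      # owns its own, much smaller buffer

print(view_shard.data_ptr() == full.data_ptr())    # True: the shard pins the full tensor
print(copied_shard.data_ptr() == full.data_ptr())  # False: independent copy
print(view_shard.untyped_storage().nbytes())       # bytes of the whole 4x1024 tensor
print(copied_shard.untyped_storage().nbytes())     # bytes of a single 1x1024 shard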
...

colossalai/tensor/d_tensor/comm_spec.py
...
@@ -112,7 +112,7 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):
     dim = comm_spec.shard_dim
     length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
     start = length * dist.get_rank(process_group)
-    output = torch.narrow(tensor, dim, start, length).contiguous()
+    output = torch.narrow(tensor, dim, start, length).clone().contiguous()
     return output
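The same pattern is applied here: torch.narrow() always returns a view that shares storage with its source, so keeping only the narrowed shard still keeps the whole tensor alive until .clone() detaches it. An illustrative check, not part of the commit:

import torch

tensor = torch.arange(16.0).reshape(4, 4)
shard_view = torch.narrow(tensor, 0, 0, 1).contiguous()          # view into `tensor`
shard_copy = torch.narrow(tensor, 0, 0, 1).clone().contiguous()  # independent copy

print(shard_view.data_ptr() == tensor.data_ptr())  # True: shares the parent's storage
print(shard_copy.data_ptr() == tensor.data_ptr())  # False: parent can be released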
...