OpenDAS / ColossalAI, commit 83328329 (unverified)
Authored by ver217, Jul 28, 2022; committed by GitHub, Jul 28, 2022

[hotfix] fix zero ddp buffer cast (#1376)

* fix zero ddp buffer cast
* fix zero ddp ignore params

Parent: 5d5031e9
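Reading the diff below: the constructor used to halve the whole module up front (`super().__init__(module.half(), ...)`), so the fp32 master copies taken later (`fp32_p = p.float().detach()`) were rebuilt from already-halved values, and the `assert p.dtype == torch.half` forced callers onto that path. After this commit the fp32 copy is taken first and each parameter is cast to half afterwards; parameters tagged `_ddp_to_ignore` are still cast to half even though they skip chunk management; and a new `_cast_buffers()` moves buffers to CUDA and halves only the floating-point ones.

A minimal sketch of the ordering point in plain PyTorch (assuming the wrapped module arrives in fp32; the tensor names are illustrative, not from the source):

```python
import torch

w = torch.randn(4)                 # stands in for an fp32 module parameter

# Old order: module.half() ran before the fp32 copy was taken,
# so the master copy inherits fp16 rounding.
old_master = w.half().float()

# New order (this commit): take the fp32 copy first, cast afterwards.
new_master = w.float().detach()
w_fp16 = w.half()

print(torch.equal(new_master, w))  # True: full precision kept
print(torch.equal(old_master, w))  # almost surely False: fp16 round-trip
```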
Showing 1 changed file with 10 additions and 2 deletions:

colossalai/nn/parallel/data_parallel.py (+10, -2)
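The "ignore params" half of the fix concerns parameters tagged with `_ddp_to_ignore`: they skip chunk management entirely, and the added line makes sure they are still cast to fp16 like the rest of the module. Below is a minimal sketch of the loop's new behavior; the attribute name comes from the diff, but tagging it by direct assignment here is a hypothetical stand-in for however callers normally set it:

```python
import torch
import torch.nn as nn

module = nn.Linear(8, 8)
module.bias._ddp_to_ignore = True        # hypothetical direct tagging

fp32_master = []                         # stands in for self.fp32_params
for p in module.parameters():
    if getattr(p, '_ddp_to_ignore', False):
        p.data = p.half()                # added line: still halved,
        continue                         # but no fp32 copy, no chunk
    fp32_master.append(p.float().detach())   # copy before halving
    p.data = p.half()

print(module.weight.dtype, module.bias.dtype)  # torch.float16 torch.float16
print(len(fp32_master))                        # 1 (only the weight)
```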
colossalai/nn/parallel/data_parallel.py (view file @ 83328329)

@@ -192,7 +192,7 @@ class ZeroDDP(ColoDDP):
     """

     def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
-        super().__init__(module.half(), process_group=gemini_manager.chunk_manager.process_group)
+        super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
@@ -204,13 +204,15 @@ class ZeroDDP(ColoDDP):
         # TODO: get param order and filter unused params
         for p in module.parameters():
             if getattr(p, '_ddp_to_ignore', False):
+                p.data = p.half()
                 continue
-            assert p.dtype == torch.half
             fp32_p = p.float().detach()
+            p.data = p.half()
             self.chunk_manager.append_tensor(p, 'fp16_param')
             self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
+        self._cast_buffers()
         self._logger = get_dist_logger()

     def forward(self, *args, **kwargs):
@@ -481,3 +483,9 @@ class ZeroDDP(ColoDDP):
                 input_name = key[len(prefix):]
                 if input_name not in local_state:
                     unexpected_keys.append(key)
+
+    def _cast_buffers(self):
+        for buffer in self.module.buffers():
+            buffer.data = buffer.cuda()
+            if torch.is_floating_point(buffer):
+                buffer.data = buffer.half()
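As for the buffer cast itself, here is a minimal sketch of what the new `_cast_buffers()` does, assuming a CUDA device is available. `nn.BatchNorm1d` is a handy test case because it carries both floating-point buffers (`running_mean`, `running_var`) and an integer one (`num_batches_tracked`); only the floating-point buffers are halved, while all of them move to the GPU:

```python
import torch
import torch.nn as nn

module = nn.BatchNorm1d(8)

# Same logic as the added _cast_buffers() method.
for buf in module.buffers():
    buf.data = buf.cuda()
    if torch.is_floating_point(buf):
        buf.data = buf.half()

for name, buf in module.named_buffers():
    print(name, buf.dtype, buf.device)
# running_mean        torch.float16 cuda:0
# running_var         torch.float16 cuda:0
# num_batches_tracked torch.int64   cuda:0
```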