Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
85928d08
Unverified
Commit
85928d08
authored
Aug 02, 2023
by
Kirthi Shankar Sivamani
Committed by
GitHub
Aug 02, 2023
Browse files
Store FP8 checkpointing data in CPU (#351)
Signed-off-by:
Kirthi Shankar Sivamani
<
ksivamani@nvidia.com
>
parent
c8175d9e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
7 deletions
+13
-7
transformer_engine/pytorch/fp8.py
transformer_engine/pytorch/fp8.py
+7
-1
transformer_engine/pytorch/module/base.py
transformer_engine/pytorch/module/base.py
+6
-6
No files found.
transformer_engine/pytorch/fp8.py
View file @
85928d08
...
@@ -87,7 +87,13 @@ def get_amax_reduce_handle_fwd() -> Union[bool, None]:
...
@@ -87,7 +87,13 @@ def get_amax_reduce_handle_fwd() -> Union[bool, None]:
def
get_global_fp8_buffer
()
->
Dict
[
str
,
List
[
torch
.
Tensor
]]:
def
get_global_fp8_buffer
()
->
Dict
[
str
,
List
[
torch
.
Tensor
]]:
"""Returns global fp8 buffer."""
"""Returns global fp8 buffer."""
return
_global_fp8_buffer
buffer
=
{}
# Map all tensors to CPU.
for
k
,
v
in
_global_fp8_buffer
.
items
():
buffer
[
k
]
=
[
tensor
.
cpu
()
for
tensor
in
v
]
return
buffer
def
set_global_fp8_buffer
(
buffer
:
Dict
[
str
,
List
[
torch
.
Tensor
]])
->
None
:
def
set_global_fp8_buffer
(
buffer
:
Dict
[
str
,
List
[
torch
.
Tensor
]])
->
None
:
...
...
transformer_engine/pytorch/module/base.py
View file @
85928d08
...
@@ -349,12 +349,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
...
@@ -349,12 +349,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if
fp8_checkpoint
:
if
fp8_checkpoint
:
state
=
{}
state
=
{}
state
[
"scale_fwd"
]
=
self
.
fp8_meta
[
"scaling_fwd"
].
scale
state
[
"scale_fwd"
]
=
self
.
fp8_meta
[
"scaling_fwd"
].
scale
.
cpu
()
state
[
"scale_inv_fwd"
]
=
self
.
fp8_meta
[
"scaling_fwd"
].
scale_inv
state
[
"scale_inv_fwd"
]
=
self
.
fp8_meta
[
"scaling_fwd"
].
scale_inv
.
cpu
()
state
[
"amax_history_fwd"
]
=
self
.
fp8_meta
[
"scaling_fwd"
].
amax_history
state
[
"amax_history_fwd"
]
=
self
.
fp8_meta
[
"scaling_fwd"
].
amax_history
.
cpu
()
state
[
"scale_bwd"
]
=
self
.
fp8_meta
[
"scaling_bwd"
].
scale
state
[
"scale_bwd"
]
=
self
.
fp8_meta
[
"scaling_bwd"
].
scale
.
cpu
()
state
[
"scale_inv_bwd"
]
=
self
.
fp8_meta
[
"scaling_bwd"
].
scale_inv
state
[
"scale_inv_bwd"
]
=
self
.
fp8_meta
[
"scaling_bwd"
].
scale_inv
.
cpu
()
state
[
"amax_history_bwd"
]
=
self
.
fp8_meta
[
"scaling_bwd"
].
amax_history
state
[
"amax_history_bwd"
]
=
self
.
fp8_meta
[
"scaling_bwd"
].
amax_history
.
cpu
()
state
[
"global_fp8_buffer"
]
=
get_global_fp8_buffer
()
state
[
"global_fp8_buffer"
]
=
get_global_fp8_buffer
()
# Store other pickelable values.
# Store other pickelable values.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment