Unverified Commit f519e6e0 authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files

FP8 Param support for offloading (#1823)



* Lora spike
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* Added FP8 param support
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Pawel Gadzinski <pgadzinski@nvidia.com>
parent fab71571
...@@ -314,6 +314,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -314,6 +314,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# Data structure to hold the FP8/MXFP8 tensor objects # Data structure to hold the FP8/MXFP8 tensor objects
self.fp8_tensor_object_map = {} self.fp8_tensor_object_map = {}
self.float8_transpose_cache_valid = {} self.float8_transpose_cache_valid = {}
self.dereferencing_list = []
# Tracking the number of layers offloaded # Tracking the number of layers offloaded
self.offloaded_group_count = 0 self.offloaded_group_count = 0
# Core data structure that decides the window for offloading # Core data structure that decides the window for offloading
...@@ -360,6 +361,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -360,6 +361,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
self.tensor_tag_to_state[tensor_tag] = [] self.tensor_tag_to_state[tensor_tag] = []
self.tensor_tag_to_buf[tensor_tag] = [] self.tensor_tag_to_buf[tensor_tag] = []
# Added support for de-duplicating FP8 param tensors
for _, value in self.fp8_tensor_object_map.items():
if tensor is value:
self.dereferencing_list.append(tensor_tag)
break
self.fp8_tensor_object_map[tensor_tag] = tensor self.fp8_tensor_object_map[tensor_tag] = tensor
if isinstance(tensor, Float8Tensor): if isinstance(tensor, Float8Tensor):
self.float8_transpose_cache_valid[tensor_tag] = getattr( self.float8_transpose_cache_valid[tensor_tag] = getattr(
...@@ -398,7 +405,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -398,7 +405,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# Handling the quantized tensor case specially here # Handling the quantized tensor case specially here
if isinstance(tensor, list): if isinstance(tensor, list):
self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) # If it's a duplicated tensor, we don't need to locally
# write back a tensor as it would already be written
if tensor_tag in self.dereferencing_list:
self.dereferencing_list.remove(tensor_tag)
else:
self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
tensor = self.fp8_tensor_object_map.pop(tensor_tag) tensor = self.fp8_tensor_object_map.pop(tensor_tag)
self.tensor_tag_to_buf.pop(tensor_tag, None) self.tensor_tag_to_buf.pop(tensor_tag, None)
...@@ -511,11 +523,21 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -511,11 +523,21 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
) )
else: else:
tensor_list.append(state_tuple) tensor_list.append(state_tuple)
_ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(tensor_list)
                # No need to write back the duplicated tensor again
                # to the same location; this check ensures that
if tensor_label in self.dereferencing_list:
self.dereferencing_list.remove(tensor_label)
else:
_ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(
tensor_list
)
if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor):
self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( self.fp8_tensor_object_map[tensor_label]._transpose_invalid = (
self.float8_transpose_cache_valid.pop(tensor_label) self.float8_transpose_cache_valid.pop(tensor_label)
) )
self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop(
tensor_label tensor_label
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment