Unverified Commit f519e6e0 authored by Selvaraj Anandaraj, committed by GitHub

FP8 Param support for offloading (#1823)



* Lora spike
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* Added FP8 param support
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* lint fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-preos02.a51.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Pawel Gadzinski <pgadzinski@nvidia.com>
parent fab71571
@@ -314,6 +314,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         # Data structure to hold the FP8/MXFP8 tensor objects
         self.fp8_tensor_object_map = {}
         self.float8_transpose_cache_valid = {}
+        self.dereferencing_list = []
         # Tracking the number of layers offloaded
         self.offloaded_group_count = 0
         # Core data structure that decides the window for offloading
@@ -360,6 +361,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
                 self.tensor_tag_to_state[tensor_tag] = []
                 self.tensor_tag_to_buf[tensor_tag] = []
+                # Added support for de-duplicating FP8 param tensors
+                for _, value in self.fp8_tensor_object_map.items():
+                    if tensor is value:
+                        self.dereferencing_list.append(tensor_tag)
+                        break
                 self.fp8_tensor_object_map[tensor_tag] = tensor
                 if isinstance(tensor, Float8Tensor):
                     self.float8_transpose_cache_valid[tensor_tag] = getattr(
@@ -398,7 +405,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         # Handling the quantized tensor case specially here
         if isinstance(tensor, list):
-            self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
+            # If it's a duplicated tensor, we don't need to write it back
+            # locally, as it would already have been written
+            if tensor_tag in self.dereferencing_list:
+                self.dereferencing_list.remove(tensor_tag)
+            else:
+                self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
             tensor = self.fp8_tensor_object_map.pop(tensor_tag)
         self.tensor_tag_to_buf.pop(tensor_tag, None)
@@ -511,11 +523,21 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
                                 )
                             else:
                                 tensor_list.append(state_tuple)
-                        _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(tensor_list)
+                        # No need to write back the duplicated tensor again
+                        # to the same location; this check ensures that
+                        if tensor_label in self.dereferencing_list:
+                            self.dereferencing_list.remove(tensor_label)
+                        else:
+                            _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(
+                                tensor_list
+                            )
                         if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor):
                             self.fp8_tensor_object_map[tensor_label]._transpose_invalid = (
                                 self.float8_transpose_cache_valid.pop(tensor_label)
                             )
                         self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop(
                             tensor_label
                         )
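
For readers skimming the diff, here is a minimal standalone sketch of the de-duplication idea this change introduces (the class name DedupOffloadSketch and its register/restore helpers are illustrative stand-ins, not part of Transformer Engine's offload handler): when the same FP8 parameter object is pushed under more than one tensor tag, only the first tag writes data back on reload; the extra tags are recorded in a dereferencing list and skipped.

    class DedupOffloadSketch:
        def __init__(self):
            # tag -> quantized tensor object (mirrors fp8_tensor_object_map)
            self.tensor_object_map = {}
            # tags whose tensor object duplicates one registered earlier
            self.dereferencing_list = []

        def register(self, tag, tensor):
            # Identity check: if this exact object is already tracked under
            # another tag, remember the new tag so reload can skip it.
            if any(tensor is existing for existing in self.tensor_object_map.values()):
                self.dereferencing_list.append(tag)
            self.tensor_object_map[tag] = tensor

        def restore(self, tag, saved_state, restore_fn):
            # Duplicated tags are consumed from the list without a second
            # write-back; the first tag already restored the shared object.
            if tag in self.dereferencing_list:
                self.dereferencing_list.remove(tag)
            else:
                restore_fn(self.tensor_object_map[tag], saved_state)
            return self.tensor_object_map.pop(tag)

The hunks above apply this same pattern: duplicates are recorded when tensors are pushed, and the redundant restore_from_saved call is skipped on both restore paths.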