Unverified commit 71124c31, authored by Selvaraj Anandaraj, committed by GitHub

Remove unwanted Memory Copies/Fix weight parameters (#1034)



* Removed unwanted memcpyDtoD copies; fixed weight parametrization
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 931b44fe
@@ -284,11 +284,8 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
             debug=debug,
         )
         self.num_prefetch_group = num_prefetch_group
-        # prepare for tensor buffer
-        self.tensor_id_to_tensor_buf_double_bufs = []
-        for _ in range(2):
-            self.tensor_id_to_tensor_buf_double_bufs.append({})
+        # Data structure to maintain reference to activation tensors
+        self.tensor_tag_to_buf = {}
         # allocate streams and events for synchronization
         self.d2h_stream = torch.cuda.Stream()
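Note: with this change the handler no longer pre-allocates ping-pong staging buffers; it only records a reference per tensor tag. A minimal sketch of the new bookkeeping (illustrative standalone code with hypothetical names, not the handler itself):

```python
# (group_id, tensor_id) -> live GPU tensor; holding the reference keeps the
# activation alive until its async D2H copy has completed.
tensor_tag_to_buf = {}

def track(tensor, group_id, tensor_id, num_offload_group):
    # Only tensors in groups that will actually be offloaded are tracked.
    if group_id < num_offload_group:
        tensor_tag_to_buf[(group_id, tensor_id)] = tensor
```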
@@ -300,37 +297,6 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
             self.compute_stream_bwd_start_events.append(torch.cuda.Event())
         self.d2h_final_event = torch.cuda.Event()

-    def get_tensor_buf_for_offloaded_tensor(self, tensor, tensor_tag):
-        """Get tensor buffer for offloaded tensor."""
-        group_id, tensor_id = tensor_tag
-        # obtain ping-pong buffer
-        id_buf_map = self.tensor_id_to_tensor_buf_double_bufs[(group_id % 2)]
-        if not tensor_id in id_buf_map:
-            allocate_new_buf = True
-        else:
-            tensor_buf = id_buf_map[tensor_id]
-            allocate_new_buf = (
-                tensor_buf.size() != tensor.size() or tensor_buf.dtype != tensor.dtype
-            )
-        if allocate_new_buf:
-            # supposed to only execute once
-            fp8_offload = isinstance(tensor, Float8Tensor)
-            buffer = torch.empty(
-                tensor.size(),
-                dtype=torch.uint8 if fp8_offload else tensor.dtype,
-                layout=tensor.layout,
-                device=tensor.device,
-            )
-            if isinstance(tensor, Float8Tensor):
-                id_buf_map[tensor_id] = Float8Tensor.make_like(tensor, data=buffer)
-            else:
-                id_buf_map[tensor_id] = buffer
-        return id_buf_map[tensor_id]
-
     def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
         torch_stray_tensor = isinstance(
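The deleted helper is the source of the "unwanted memcpyDtoD" in the commit title: every pushed tensor was first staged into a persistent ping-pong buffer (indexed by `group_id % 2`) via `tensor_buf.copy_(tensor)`, a device-to-device copy, before the host offload even started. A sketch of the two data paths, assuming a CUDA device is available:

```python
import torch

x = torch.randn(4096, 4096, device="cuda")  # stand-in for an activation

# Old path (removed): stage into a reusable buffer, then offload the buffer.
staging = torch.empty_like(x)   # persistent ping-pong buffer
staging.copy_(x)                # extra cudaMemcpyDtoD on every iteration
cpu_copy = staging.to("cpu")    # D2H copy

# New path: offload directly from the live activation; the handler just
# keeps a reference to it so it cannot be freed mid-copy.
cpu_copy = x.to("cpu")          # single D2H copy, no staging
```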
@@ -347,21 +313,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
             self.tensor_count_current_group += 1
             assert tensor_tag not in self.tensor_tag_to_state
+            self.tensor_tag_to_state[tensor_tag] = tensor

             if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker(
                 tensor
             ):
-                # first copy the tensor to tensorbuf,
-                # so that the original tensor will not be deleted
-                tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag)
-                tensor_buf.copy_(tensor)
-                if hasattr(tensor, "weight_offloading"):
-                    tensor_buf.weight_offloading = True
-                if hasattr(tensor, "activation_offloading"):
-                    tensor_buf.activation_offloading = True
-                # Here we just save it, and at commit, bulk_offload_group will handle it
-                self.tensor_tag_to_state[tensor_tag] = tensor_buf
-            else:
-                self.tensor_tag_to_state[tensor_tag] = tensor
+                self.tensor_tag_to_buf[tensor_tag] = tensor
         else:
             tensor_tag = (-1, self.torch_tensor_count)
             self.torch_tensor_count += 1
@@ -373,6 +330,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         """Tensor pop."""
         assert tensor_tag in self.tensor_tag_to_state
         tensor = self.tensor_tag_to_state.pop(tensor_tag)
+        self.tensor_tag_to_buf.pop(tensor_tag, None)
         # the tensor should have been copied back in on_group_commit_backward()
         # which invokes bulk_reload_group.
         assert not isinstance(tensor, tuple)
@@ -389,10 +347,6 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
                     # if offload, return the reference to cpu copy
                     if self.tensor_need_offloading_checker(tensor_on_device):
-                        if hasattr(tensor_on_device, "weight_offloading"):
-                            delattr(tensor_on_device, "weight_offloading")
-                        if hasattr(tensor_on_device, "activation_offloading"):
-                            delattr(tensor_on_device, "activation_offloading")
                         state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
                         self.tensor_tag_to_state[tensor_tag] = state
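For reference, `SynchronizedGroupOffloadHandler.offload` produces the CPU copy that replaces the device tensor in `tensor_tag_to_state`. Roughly (a sketch of the base-class behavior, not quoted from the source):

```python
import torch

def offload(src_tensor, pin_memory=True):
    # Pinned host memory lets copy_ run asynchronously on the d2h stream.
    cpu_backup = torch.empty(
        src_tensor.size(),
        dtype=src_tensor.dtype,
        layout=src_tensor.layout,
        device="cpu",
        pin_memory=pin_memory,
    )
    cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
    return (src_tensor.device, cpu_backup)  # keep the device for the reload
```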
@@ -403,12 +357,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
         previous_group = current_group - 1
         if previous_group < self.num_offload_group:
             torch.cuda.synchronize()
-            # TODO (guyueh): this part is originally designed to reduce the peak memory usage. # pylint: disable=fixme
-            # however, uncommenting this part will cause illegal access, have not figured out why.
-            # if previous_group + 2 >= self.num_offload_group:
-            #     # this buffer is no longer required
-            #     self.tensor_id_to_tensor_buf_double_bufs[(previous_group % 2)] = {}
+            # Have to release the memory held by activations of the previous layer
+            if previous_group >= 0:
+                for tensor_tag, _ in self.tensor_tag_to_buf.items():
+                    if tensor_tag[0] == previous_group:
+                        self.tensor_tag_to_buf[tensor_tag] = None

         # the copying of this group should wait for the computation stream event
         if current_group < self.num_offload_group:
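The added block is what actually returns activation memory to the allocator: once `torch.cuda.synchronize()` guarantees the previous group's D2H copies have landed, dropping the references lets the GPU allocator reuse that memory. Overwriting values with `None` (instead of deleting keys) keeps the dict's key set stable while iterating; a tiny standalone illustration:

```python
# Stand-ins for offloaded activations keyed by (group_id, tensor_id).
tensor_tag_to_buf = {(0, 0): "actA", (0, 1): "actB", (1, 0): "actC"}
previous_group = 0

for tensor_tag in tensor_tag_to_buf:
    if tensor_tag[0] == previous_group:
        # Assigning to an existing key does not resize the dict, so this is
        # safe during iteration; `del` here would raise a RuntimeError.
        tensor_tag_to_buf[tensor_tag] = None
```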
@@ -237,9 +237,6 @@ class _GroupedLinear(torch.autograd.Function):
                 saved_inputmats = inputmats_no_fp8

             if cpu_offloading:
-                if fuse_wgrad_accumulation:
-                    for w in weights:
-                        w.main_grad.weight_offloading = True
                 if fp8:
                     for w in weights_fp8:
                         if w is not None:
@@ -303,7 +300,7 @@ class _GroupedLinear(torch.autograd.Function):
         main_grads = saved_tensors[4 * ctx.num_gemms :]
         if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
             for i in ctx.num_gemms:
-                w = torch.nn.Parameter(weights[i], False)
+                w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
                 w.main_grad = main_grads[i]
                 weights[i] = w
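Passing a hard-coded `False` rebuilt each weight as a non-trainable `Parameter`, silently breaking gradient flow whenever the original weight required grad; forwarding `weights[i].requires_grad` preserves the flag. A quick standalone check (illustrative, taking the flag from the original weight):

```python
import torch

w = torch.randn(8, 8, requires_grad=True)          # original trainable weight
saved = w.detach()                                 # tensor as saved for backward

old = torch.nn.Parameter(saved, False)             # old call: always frozen
new = torch.nn.Parameter(saved, w.requires_grad)   # fixed: original flag kept

assert not old.requires_grad
assert new.requires_grad
```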
@@ -409,7 +409,7 @@ class _LayerNormLinear(torch.autograd.Function):
         )
         if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            weight = torch.nn.Parameter(weight.requires_grad)
+            weight = torch.nn.Parameter(weight, weight.requires_grad)
             weight.main_grad = main_grad

         if ctx.ub_overlap_rs_dgrad:
@@ -567,8 +567,8 @@ class _LayerNormMLP(torch.autograd.Function):
         )
         if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            fc1_weight = Parameter(fc1_weight.requires_grad)
-            fc2_weight = Parameter(fc2_weight.requires_grad)
+            fc1_weight = Parameter(fc1_weight, fc1_weight.requires_grad)
+            fc2_weight = Parameter(fc2_weight, fc2_weight.requires_grad)
             fc1_weight.main_grad = fc1_weight_main_grad
             fc2_weight.main_grad = fc2_weight_main_grad
@@ -401,7 +401,7 @@ class _Linear(torch.autograd.Function):
         )
         if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
-            weight = torch.nn.Parameter(weight.requires_grad)
+            weight = torch.nn.Parameter(weight, weight.requires_grad)
             weight.main_grad = main_grad

         tp_world_size = get_distributed_world_size(ctx.tp_group)
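In the three hunks above, the removed call passed only the flag: `torch.nn.Parameter(weight.requires_grad)` hands a Python bool to the constructor's `data` argument and fails at runtime, so the `cpu_offloading` + `fuse_wgrad_accumulation` path could never complete backward. The fix passes the tensor first, then the flag. A minimal reproduction of the signature mismatch:

```python
import torch

weight = torch.randn(4, 4, requires_grad=True)

try:
    torch.nn.Parameter(weight.requires_grad)  # bool where a Tensor is expected
except TypeError as err:
    print("old call fails:", err)

fixed = torch.nn.Parameter(weight, weight.requires_grad)  # data, then flag
```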