Unverified commit 6f4310d7 authored by Selvaraj Anandaraj, committed by GitHub

Added MCore FSDP support for TE (#1890)



* Added MCore FSDP support for TE
Signed-off-by: Selvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Refactored based on new MCore FSDP
Signed-off-by: Selvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* Code cleanup and extended across modules
Signed-off-by: Selvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Added FSDP support
Signed-off-by: Selvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>
Co-authored-by: Selvaraj Anandaraj <selvaraja@cw-dfw-cs-001-login-01.cm.cluster>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 21b780cc
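
The diff below applies one pattern to every affected module: instead of caching `weight.main_grad` during the forward pass, the forward stores a zero-argument callable and the backward invokes it. MCore FSDP parameters (detected in this patch via a `__fsdp_param__` attribute) allocate `main_grad` lazily and expose it through `get_main_grad`, so reading `main_grad` eagerly in the forward would force the buffer into existence too early. A minimal sketch of the idea, using hypothetical stand-in helpers rather than the real Megatron-Core or TransformerEngine types:

```python
import torch

def make_plain_param(shape):
    """Stand-in for a regular parameter: the main_grad buffer exists up front."""
    p = torch.nn.Parameter(torch.zeros(shape))
    p.main_grad = torch.zeros(shape)  # eagerly allocated grad-accumulation buffer
    return p

def make_lazy_fsdp_param(shape):
    """Hypothetical stand-in for an MCore FSDP param: main_grad appears on first access."""
    p = torch.nn.Parameter(torch.zeros(shape))
    p.__fsdp_param__ = True  # marker attribute, as checked in the diff

    def get_main_grad():
        # Allocate the buffer only when the backward pass actually asks for it.
        if not hasattr(p, "main_grad"):
            p.main_grad = torch.zeros(shape)
        return p.main_grad

    p.get_main_grad = get_main_grad
    return p

def forward_capture(ctx, weight):
    """Forward pass: store a zero-argument callable on ctx, never the buffer itself."""
    if hasattr(weight, "__fsdp_param__"):
        ctx["main_grad_func"] = weight.get_main_grad  # resolved lazily in backward
    else:
        ctx["main_grad_func"] = lambda: weight.main_grad

def backward_accumulate(ctx, wgrad):
    """Backward pass: only now is main_grad materialized and written to."""
    main_grad = ctx["main_grad_func"]()
    main_grad.add_(wgrad)
    return main_grad

# Usage: with the FSDP-style param, main_grad is created inside backward, not forward.
ctx = {}
w = make_lazy_fsdp_param((4, 4))
forward_capture(ctx, w)
assert not hasattr(w, "main_grad")
backward_accumulate(ctx, torch.ones(4, 4))
assert hasattr(w, "main_grad")
```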
@@ -208,9 +208,18 @@ class _GroupedLinear(torch.autograd.Function):
         ctx.weights_requires_grad = weights[0].requires_grad
         if fuse_wgrad_accumulation and ctx.weights_requires_grad:
-            ctx.main_grads = [weights[i].main_grad for i in range(num_gemms)]
+            # This check is needed to ensure that main_grad is not created
+            # during the forward pass when using MCore FSDP as it creates
+            # the main_grad buffer lazily before backprop
+            if hasattr(weights[0], "__fsdp_param__"):
+                # MCore FSDP creates main_grad lazily before backward
+                ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)]
+            else:
+                ctx.main_grad_funcs = [
+                    lambda j=i: weights[j].main_grad for i in range(num_gemms)
+                ]
         else:
-            ctx.main_grads = [None] * num_gemms
+            ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)]
         ctx.device = device
         ctx.grad_output_quantizers = grad_output_quantizers
         ctx.m_splits = m_splits
@@ -246,7 +255,7 @@ class _GroupedLinear(torch.autograd.Function):
         weights = saved_tensors[N : 2 * N]
         origin_weights = saved_tensors[2 * N : 3 * N]
         biases = saved_tensors[3 * N : 4 * N]
-        main_grads = ctx.main_grads
+        main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
         if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
             for i in range(ctx.num_gemms):
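
One detail in the grouped case worth noting: the non-FSDP branch binds the loop index as a default argument (`lambda j=i: weights[j].main_grad`). Without that, every closure would capture the variable `i` itself and, under Python's late-binding rules, all of them would return the last weight's `main_grad` when called in the backward pass. A small standalone illustration of the difference (not TE code):

```python
values = [10, 20, 30]

# Late binding: every closure sees the final value of i.
wrong = [lambda: values[i] for i in range(3)]
print([f() for f in wrong])   # [30, 30, 30]

# Default-argument binding freezes i per closure, as in the diff.
right = [lambda j=i: values[j] for i in range(3)]
print([f() for f in right])   # [10, 20, 30]
```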
@@ -447,7 +447,14 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.requires_wgrad = weight.requires_grad
         ctx.quantized_weight = quantized_weight
         if fuse_wgrad_accumulation and weight.requires_grad:
-            ctx.main_grad = weight.main_grad
+            # This check is needed to ensure that main_grad is not created
+            # during the forward pass when using MCore FSDP as it creates
+            # the main_grad buffer lazily before backprop
+            if hasattr(weight, "__fsdp_param__"):
+                # MCore FSDP creates main_grad lazily before backward
+                ctx.main_grad_func = weight.get_main_grad
+            else:
+                ctx.main_grad_func = lambda: weight.main_grad
         ctx.grad_input_quantizer = grad_input_quantizer
         ctx.grad_weight_quantizer = grad_weight_quantizer
         ctx.grad_output_quantizer = grad_output_quantizer
@@ -528,7 +535,7 @@ class _LayerNormLinear(torch.autograd.Function):
         # Since main_grad can be modified inplace, it should not be a part of saved_tensors
         main_grad = (
-            ctx.main_grad
+            ctx.main_grad_func()
             if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
             else None
         )
@@ -552,8 +552,20 @@ class _LayerNormMLP(torch.autograd.Function):
         )
         if fuse_wgrad_accumulation:
-            ctx.fc1_main_grad = fc1_weight.main_grad if fc1_weight.requires_grad else None
-            ctx.fc2_main_grad = fc2_weight.main_grad if fc2_weight.requires_grad else None
+            # This check is needed to ensure that main_grad is not created
+            # during the forward pass when using MCore FSDP as it creates
+            # the main_grad buffer lazily before backprop
+            if hasattr(fc1_weight, "__fsdp_param__") and hasattr(fc2_weight, "__fsdp_param__"):
+                # MCore FSDP creates main_grad lazily before backward
+                ctx.fc1_main_grad_func = (
+                    fc1_weight.get_main_grad if fc1_weight.requires_grad else lambda: None
+                )
+                ctx.fc2_main_grad_func = (
+                    fc2_weight.get_main_grad if fc2_weight.requires_grad else lambda: None
+                )
+            else:
+                ctx.fc1_main_grad_func = lambda: fc1_weight.main_grad
+                ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad
         ctx.save_for_backward(*tensors_to_save)
         ctx.tensor_objects = tensor_objects
@@ -653,14 +665,14 @@ class _LayerNormMLP(torch.autograd.Function):
         # Since main_grad can be modified inplace, it should not be a part of saved_tensors
         fc1_weight_main_grad = (
-            ctx.fc1_main_grad
+            ctx.fc1_main_grad_func()
             if fc1_weight is not None
             and ctx.fuse_wgrad_accumulation
             and ctx.fc1_weight_requires_grad
             else None
         )
         fc2_weight_main_grad = (
-            ctx.fc2_main_grad
+            ctx.fc2_main_grad_func()
             if origin_fc2_weight is not None
             and ctx.fuse_wgrad_accumulation
             and ctx.fc2_weight_requires_grad
@@ -397,7 +397,14 @@ class _Linear(torch.autograd.Function):
         ctx.grad_output_quantizer = grad_output_quantizer
         ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation
         if fuse_wgrad_accumulation and weight.requires_grad:
-            ctx.main_grad = weight.main_grad
+            # This check is needed to ensure that main_grad is not created
+            # during the forward pass when using MCore FSDP as it creates
+            # the main_grad buffer lazily before backprop
+            if hasattr(weight, "__fsdp_param__"):
+                # MCore FSDP creates main_grad lazily before backward
+                ctx.main_grad_func = weight.get_main_grad
+            else:
+                ctx.main_grad_func = lambda: weight.main_grad
         ctx.debug = debug
         ctx.cpu_offloading = cpu_offloading
@@ -453,7 +460,7 @@ class _Linear(torch.autograd.Function):
         # Since main_grad can be modified inplace, it should not be a part of saved_tensors
         main_grad = (
-            ctx.main_grad
+            ctx.main_grad_func()
             if weight is not None and ctx.fuse_wgrad_accumulation and ctx.requires_wgrad
             else None
         )
@@ -964,6 +964,9 @@ class BasicLinear(BasicOperation):
         accumulate_into_main_grad = self._accumulate_into_main_grad
         grad_weight = None
         if ctx.weight_requires_grad and accumulate_into_main_grad:
+            if hasattr(self.weight, "__fsdp_param__"):
+                self.weight.main_grad = self.weight.get_main_grad()
+
             if not hasattr(self.weight, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
@@ -57,6 +57,9 @@ class BackwardLinearAdd(FusedOperation):
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
+            if hasattr(linear_op.weight, "__fsdp_param__"):
+                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
+
             if not hasattr(linear_op.weight, "main_grad"):
                 raise RuntimeError(
                     "BasicLinear op is configured with "
@@ -508,6 +508,9 @@ class UserbuffersBackwardLinear(FusedOperation):
         accumulate_into_main_grad = linear_op._accumulate_into_main_grad
         grad_weight = None
         if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad:
+            if hasattr(linear_op.weight, "__fsdp_param__"):
+                linear_op.weight.main_grad = linear_op.weight.get_main_grad()
+
            if not hasattr(linear_op.weight, "main_grad"):
                raise RuntimeError(
                    "BasicLinear op is configured with "
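
The operation-based modules (`BasicLinear` and the fused `BackwardLinearAdd` / `UserbuffersBackwardLinear` paths) take a slightly different route: they keep their existing `hasattr(weight, "main_grad")` error check and simply materialize the lazy buffer right before it. The effective control flow, sketched with the same hypothetical stand-ins as above (`get_main_grad` and `__fsdp_param__` come from the diff, everything else is illustrative):

```python
def resolve_main_grad_for_ops_path(weight):
    # MCore FSDP params expose get_main_grad and create the buffer lazily;
    # assigning the result back keeps the existing main_grad-based code working.
    if hasattr(weight, "__fsdp_param__"):
        weight.main_grad = weight.get_main_grad()
    if not hasattr(weight, "main_grad"):
        raise RuntimeError(
            "op is configured to accumulate into main_grad, "
            "but the weight parameter has no main_grad buffer"
        )
    return weight.main_grad
```

The end effect matches the autograd-function modules: `main_grad` is only touched once backprop has actually started.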