Commit bd5a6e86 authored by tabuchixiangcai3

[DCU] fix handling of missing main_grad


Signed-off-by: Tangao <2205747538@qq.com>
parent 29271c40
@@ -153,13 +153,7 @@ class _BatchLinear(torch.autograd.Function):
         if cpu_offloading:
             if fuse_wgrad_accumulation:
                 for w in weights:
-                    if getattr(w, "main_grad", None) is not None:
-                        w.main_grad.weight_offloading = True
-                    else:
-                        # Optional: log a warning if fuse requested but buffer missing
-                        # logger = logging.getLogger("BatchLinear")
-                        # logger.debug("fuse_wgrad_accumulation=True but weight.main_grad is missing; skipping weight_offloading for this weight.")
-                        pass
+                    w.main_grad.weight_offloading = True
             for w in weights:
                 w.weight_offloading = True
             for t in saved_inputmats:
@@ -168,7 +162,7 @@ class _BatchLinear(torch.autograd.Function):
             for i in range(num_gemms):
                 weights[i].offloading_activation = False
-                if getattr(weights[i], "main_grad", None) is not None:
+                if fuse_wgrad_accumulation and hasattr(weights[i], 'main_grad'):
                     weights[i].main_grad.offloading_activation = False
                 if weights_fp8[i] is not None:
                     weights_fp8[i].offloading_activation = False
@@ -561,16 +555,7 @@ class BatchedLinear(TransformerEngineBaseModule):
         if self.primary_weights_in_fp8:
             self.init_fp8_metadata(num_gemms=self.num_gemms)
-        # Ensure main_grad buffers exist when fuse_wgrad_accumulation is enabled.
-        # Skip allocation under meta device (deferred init).
         self.reset_parameters(defer_init=(device == "meta"))
-        if self.fuse_wgrad_accumulation and device != "meta":
-            for i in range(int(self.num_gemms)):
-                w = getattr(self, f"weight{i}")
-                if getattr(w, "main_grad", None) is None:
-                    # use float32 buffer for main_grad (tests use float32)
-                    w.main_grad = torch.empty_like(w, dtype=torch.float32, device=w.device)
-                    w.main_grad.zero_()

         # For RPL, bias has to be added after TP collectives
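As context for reviewers, here is a minimal sketch of the guard pattern the second hunk converges on, assuming (as the removed third hunk suggests) that main_grad is attached to parameters externally, e.g. by a Megatron-style optimizer when gradient accumulation fusion is enabled, and so may legitimately be absent. mark_offloading and the toy parameters below are hypothetical names for illustration, not part of this codebase.

import torch

# Hypothetical sketch (not from this repo): touch main_grad only when
# fusion is enabled AND the buffer was actually attached by the owner
# of the fused wgrad accumulation, avoiding an AttributeError.
def mark_offloading(weights, fuse_wgrad_accumulation):
    for w in weights:
        w.weight_offloading = True
        if fuse_wgrad_accumulation and hasattr(w, "main_grad"):
            w.main_grad.weight_offloading = True

# Toy usage: one parameter with a main_grad buffer, one without.
w0 = torch.nn.Parameter(torch.zeros(4, 4))
w0.main_grad = torch.zeros(4, 4, dtype=torch.float32)  # attached by "optimizer"
w1 = torch.nn.Parameter(torch.zeros(4, 4))

mark_offloading([w0, w1], fuse_wgrad_accumulation=True)  # no AttributeError on w1
assert w0.main_grad.weight_offloading and not hasattr(w1, "main_grad")

This is consistent with the third hunk, which drops the module-side allocation of main_grad: the buffer is owned by whichever component fuses wgrad accumulation, so the module only consumes it when present.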