Commit 521f8d3b authored by yuguo's avatar yuguo
Browse files

[DCU] combine 1f1b needs NVTE_OVERLAP_GRAD_REDUCE

parent 291fcf52
......@@ -304,7 +304,14 @@ class _BatchLinear(torch.autograd.Function):
wgrad_list = [None] * ctx.num_gemms
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
wgrad_list = [None] * ctx.num_gemms
# overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
wgrad_list = [
torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
for w in weights
]
else:
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias or (
ctx.wgrad_store is not None
......
......@@ -4,7 +4,7 @@
"""GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List
import os
import functools
import torch
......@@ -393,7 +393,11 @@ class _GroupedLinear(torch.autograd.Function):
wgrad_list = [None] * ctx.num_gemms
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
wgrad_list = [None] * ctx.num_gemms
# overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
wgrad_list = [torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights]
else:
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias or (
ctx.wgrad_store is not None
......
......@@ -783,6 +783,9 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
# overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
else:
wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
......
......@@ -843,7 +843,16 @@ class _LayerNormMLP(torch.autograd.Function):
)
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad)
fc2_wgrad = None
# overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
fc2_wgrad = torch.empty(
origin_fc2_weight.shape,
dtype=origin_fc2_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
fc2_wgrad = None
else:
fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad(
act_out,
......@@ -1057,7 +1066,16 @@ class _LayerNormMLP(torch.autograd.Function):
)
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad)
fc1_wgrad = None
# overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
fc1_wgrad = torch.empty(
origin_fc1_weight.shape,
dtype=origin_fc1_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
fc1_wgrad = None
if fuse_gemm_and_bias_fc1_wgrad:
fc1_bias_grad = None
else:
......
......@@ -6,7 +6,7 @@
from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce
from operator import mul as multiply_op
import os
import functools
import torch
......@@ -699,6 +699,9 @@ class _Linear(torch.autograd.Function):
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad)
# overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
else:
wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment