"mmgen/vscode:/vscode.git/clone" did not exist on "1401de15d079af4d9d9f995f2d57ddb6d930d7f0"
Commit 521f8d3b authored by yuguo
Browse files

[DCU] combine 1f1b needs NVTE_OVERLAP_GRAD_REDUCE

parent 291fcf52
...@@ -304,7 +304,14 @@ class _BatchLinear(torch.autograd.Function): ...@@ -304,7 +304,14 @@ class _BatchLinear(torch.autograd.Function):
wgrad_list = [None] * ctx.num_gemms wgrad_list = [None] * ctx.num_gemms
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
wgrad_list = [None] * ctx.num_gemms # overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
wgrad_list = [
torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
for w in weights
]
else:
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias or ( if not ctx.use_bias or (
ctx.wgrad_store is not None ctx.wgrad_store is not None
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""GroupedLinear API""" """GroupedLinear API"""
from typing import Union, Optional, Callable, Tuple, List from typing import Union, Optional, Callable, Tuple, List
import os
import functools import functools
import torch import torch
...@@ -393,7 +393,11 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -393,7 +393,11 @@ class _GroupedLinear(torch.autograd.Function):
wgrad_list = [None] * ctx.num_gemms wgrad_list = [None] * ctx.num_gemms
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
wgrad_list = [None] * ctx.num_gemms # overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
wgrad_list = [torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights]
else:
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias or ( if not ctx.use_bias or (
ctx.wgrad_store is not None ctx.wgrad_store is not None
......
...@@ -783,6 +783,9 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -783,6 +783,9 @@ class _LayerNormLinear(torch.autograd.Function):
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad) ctx.wgrad_store.put([ln_out_total, grad_output], general_gemm_wgrad)
# overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
else: else:
wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output) wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(ln_out_total, grad_output)
......
...@@ -843,7 +843,16 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -843,7 +843,16 @@ class _LayerNormMLP(torch.autograd.Function):
) )
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad) ctx.wgrad_store.put([act_out, grad_output], general_gemm_fc2_wgrad)
fc2_wgrad = None # overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
fc2_wgrad = torch.empty(
origin_fc2_weight.shape,
dtype=origin_fc2_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
fc2_wgrad = None
else: else:
fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad( fc2_wgrad, fc2_bias_grad_, *_ = general_gemm_fc2_wgrad(
act_out, act_out,
...@@ -1057,7 +1066,16 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -1057,7 +1066,16 @@ class _LayerNormMLP(torch.autograd.Function):
) )
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad) ctx.wgrad_store.put([ln_out_total, dact], general_gemm_fc1_wgrad)
fc1_wgrad = None # overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
fc1_wgrad = torch.empty(
origin_fc1_weight.shape,
dtype=origin_fc1_weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
)
else:
fc1_wgrad = None
if fuse_gemm_and_bias_fc1_wgrad: if fuse_gemm_and_bias_fc1_wgrad:
fc1_bias_grad = None fc1_bias_grad = None
else: else:
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
from typing import Callable, Dict, Optional, Tuple, Union from typing import Callable, Dict, Optional, Tuple, Union
from functools import reduce from functools import reduce
from operator import mul as multiply_op from operator import mul as multiply_op
import os
import functools import functools
import torch import torch
...@@ -699,6 +699,9 @@ class _Linear(torch.autograd.Function): ...@@ -699,6 +699,9 @@ class _Linear(torch.autograd.Function):
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad) ctx.wgrad_store.put([inputmat_total, grad_output], general_gemm_wgrad)
# overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
else: else:
wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output) wgrad, grad_bias_, _, rs_out = general_gemm_wgrad(inputmat_total, grad_output)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment