Commit 3273bc20 authored by yuguo
parents 75e9ef24 521f8d3b
@@ -304,7 +304,14 @@ class _BatchLinear(torch.autograd.Function):
         wgrad_list = [None] * ctx.num_gemms
         if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-            wgrad_list = [None] * ctx.num_gemms
+            # overlap_grad_reduce, dongcl
+            if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                wgrad_list = [
+                    torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
+                    for w in weights
+                ]
+            else:
+                wgrad_list = [None] * ctx.num_gemms
         if not ctx.use_bias or (
             ctx.wgrad_store is not None
......
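Review note: every hunk in this commit applies the same pattern. When the weight-gradient GEMM is deferred through the wgrad store, the gradient slots are filled with preallocated `torch.empty` buffers instead of `None` whenever `NVTE_OVERLAP_GRAD_REDUCE` is set. A minimal standalone sketch of that gating pattern follows; the helper name and the free-standing `weights` list are illustrative only and are not part of the patch.

```python
import os
from typing import List, Optional

import torch


def make_wgrad_placeholders(
    weights: List[torch.Tensor], activation_dtype: torch.dtype
) -> List[Optional[torch.Tensor]]:
    """Illustrative helper (not in the patch) mirroring the gating added here.

    When the wgrad GEMM is deferred, hand back preallocated, uninitialized
    buffers shaped like the weights if NVTE_OVERLAP_GRAD_REDUCE is set, so a
    gradient reducer has real tensors to work with; otherwise keep the old
    behaviour of returning None placeholders.
    """
    if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
        return [
            torch.empty(w.size(), dtype=activation_dtype, device=w.device)
            for w in weights
        ]
    return [None] * len(weights)
```

The buffers are deliberately uninitialized: presumably only their shape, dtype, and device need to be right up front, and the deferred wgrad computation (or grad-accumulation fusion) fills them in later.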
@@ -5,6 +5,7 @@
 """GroupedLinear API"""
 from typing import Union, Optional, Callable, Tuple, List
 import warnings
+import os
 import functools
 import torch
......@@ -394,7 +395,11 @@ class _GroupedLinear(torch.autograd.Function):
wgrad_list = [None] * ctx.num_gemms
if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
wgrad_list = [None] * ctx.num_gemms
# overlap_grad_reduce, dongcl
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
wgrad_list = [torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights]
else:
wgrad_list = [None] * ctx.num_gemms
if not ctx.use_bias or (
ctx.wgrad_store is not None
......
@@ -861,6 +861,9 @@ class _LayerNormLinear(torch.autograd.Function):
                         "with Userbuffers (tensor-parallel communication overlapping)"
                     )
                 ctx.wgrad_store.put([ln_out_total, grad_output], wgrad_gemm)
+                # overlap_grad_reduce, dongcl
+                if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                    wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
             else:
                 # Call wgrad GEMM now
......
@@ -913,6 +913,14 @@ class _LayerNormMLP(torch.autograd.Function):
             # Choose whether to call wgrad GEMM now or delay
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
                 ctx.wgrad_store.put([act_out, grad_output], fc2_wgrad_gemm)
+                # overlap_grad_reduce, dongcl
+                if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                    fc2_wgrad = torch.empty(
+                        origin_fc2_weight.shape,
+                        dtype=origin_fc2_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
             else:
                 # Call wgrad GEMM now
@@ -1166,7 +1174,16 @@ class _LayerNormMLP(torch.autograd.Function):
                         "with Userbuffers (tensor-parallel communication overlapping)"
                     )
                 ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm)
-                fc1_wgrad = None
+                # overlap_grad_reduce, dongcl
+                if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                    fc1_wgrad = torch.empty(
+                        origin_fc1_weight.shape,
+                        dtype=origin_fc1_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
+                else:
+                    fc1_wgrad = None
                 if fuse_gemm_and_bias_fc1_wgrad:
                     fc1_bias_grad = None
                 else:
......
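A hedged reading of why an empty tensor is returned rather than `None`: with overlapped gradient reduction (for example a Megatron-style distributed-data-parallel reducer), each parameter must expose a real gradient tensor so it can be bucketed and all-reduced asynchronously while backward continues, and a `None` grad gives the collective nothing to register. A minimal, hypothetical sketch of such a reducer is shown below; it is not TE or Megatron-LM code and assumes an initialized `torch.distributed` process group and an iterable of parameters.

```python
import torch.distributed as dist


def overlap_grad_reduce(params):
    """Hypothetical reducer sketch (not TE or Megatron-LM code).

    Kicks off asynchronous all-reduces on gradients so communication overlaps
    with the remaining backward work.  A parameter whose grad is still None
    has nothing to hand to the collective, which is the situation the
    preallocated torch.empty placeholders above are meant to avoid.
    """
    handles = []
    for p in params:
        if p.grad is None:
            continue  # nothing to reduce yet
        handles.append(dist.all_reduce(p.grad, async_op=True))
    # In a real overlapped setup the waits would be deferred until just before
    # the optimizer step; here we block immediately for illustration.
    for h in handles:
        h.wait()
```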
@@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
 import warnings
+import os
 import torch
@@ -782,6 +783,9 @@ class _Linear(torch.autograd.Function):
                     "with Userbuffers (tensor-parallel communication overlapping)"
                 )
             ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)
+            # overlap_grad_reduce, dongcl
+            if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
         else:
             # Call wgrad GEMM now
......
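Since each changed site reads the flag with `os.getenv()` during backward, it only needs to be present in the environment before the backward pass runs. A minimal usage sketch follows; the surrounding training setup is assumed and not part of this commit.

```python
import os

# Enable the preallocated-wgrad path added in this commit; equivalent to
# `export NVTE_OVERLAP_GRAD_REDUCE=1` in the launching shell.
os.environ["NVTE_OVERLAP_GRAD_REDUCE"] = "1"

# Note: the new branch is only reached when the module's wgrad store defers
# the wgrad GEMM (ctx.wgrad_store.delay_wgrad_compute() is True); the default,
# non-delayed backward path is unaffected.
```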