Commit 3273bc20 authored by yuguo
parents 75e9ef24 521f8d3b
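Judging from the diff below, this merge adds an `NVTE_OVERLAP_GRAD_REDUCE` environment switch: when the weight-gradient GEMMs are deferred through `wgrad_store.delay_wgrad_compute()`, the patched modules pre-allocate empty wgrad buffers with `torch.empty` instead of leaving them as `None`, presumably so a data-parallel gradient reduce can be overlapped with the delayed wgrad compute (see the `# overlap_grad_reduce` comments). A minimal usage sketch follows; everything other than the `os.getenv` read is an assumption and not part of this commit.

```python
# Hypothetical usage sketch (assumption, not part of this commit): the switch is a
# plain environment variable, so it only needs to be set before the first backward
# pass of the patched modules.
import os

os.environ["NVTE_OVERLAP_GRAD_REDUCE"] = "1"

# The patched backward passes gate the pre-allocation exactly like this:
if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
    print("wgrad buffers will be pre-allocated with torch.empty")
else:
    print("wgrad entries stay None until the delayed wgrad GEMM runs")
```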
@@ -304,7 +304,14 @@ class _BatchLinear(torch.autograd.Function):
             wgrad_list = [None] * ctx.num_gemms
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                wgrad_list = [None] * ctx.num_gemms
+                # overlap_grad_reduce, dongcl
+                if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                    wgrad_list = [
+                        torch.empty(w.size(), dtype=ctx.activation_dtype, device=w.device)
+                        for w in weights
+                    ]
+                else:
+                    wgrad_list = [None] * ctx.num_gemms
                 if not ctx.use_bias or (
                     ctx.wgrad_store is not None
...
@@ -5,6 +5,7 @@
 """GroupedLinear API"""
 from typing import Union, Optional, Callable, Tuple, List
 import warnings
+import os
 import functools
 import torch
@@ -394,7 +395,11 @@ class _GroupedLinear(torch.autograd.Function):
             wgrad_list = [None] * ctx.num_gemms
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
-                wgrad_list = [None] * ctx.num_gemms
+                # overlap_grad_reduce, dongcl
+                if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                    wgrad_list = [torch.empty(w.size(), dtype=ctx.activation_dtype, device=ctx.device) for w in weights]
+                else:
+                    wgrad_list = [None] * ctx.num_gemms
                 if not ctx.use_bias or (
                     ctx.wgrad_store is not None
...
@@ -861,6 +861,9 @@ class _LayerNormLinear(torch.autograd.Function):
                         "with Userbuffers (tensor-parallel communication overlapping)"
                     )
                 ctx.wgrad_store.put([ln_out_total, grad_output], wgrad_gemm)
+                # overlap_grad_reduce, dongcl
+                if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                    wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
             else:
                 # Call wgrad GEMM now
...
@@ -913,6 +913,14 @@ class _LayerNormMLP(torch.autograd.Function):
             # Choose whether to call wgrad GEMM now or delay
             if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute():
                 ctx.wgrad_store.put([act_out, grad_output], fc2_wgrad_gemm)
+                # overlap_grad_reduce, dongcl
+                if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                    fc2_wgrad = torch.empty(
+                        origin_fc2_weight.shape,
+                        dtype=origin_fc2_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
             else:
                 # Call wgrad GEMM now
@@ -1166,7 +1174,16 @@ class _LayerNormMLP(torch.autograd.Function):
                         "with Userbuffers (tensor-parallel communication overlapping)"
                     )
                 ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm)
-                fc1_wgrad = None
+                # overlap_grad_reduce, dongcl
+                if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                    fc1_wgrad = torch.empty(
+                        origin_fc1_weight.shape,
+                        dtype=origin_fc1_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
+                else:
+                    fc1_wgrad = None
                 if fuse_gemm_and_bias_fc1_wgrad:
                     fc1_bias_grad = None
                 else:
...
@@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Tuple, Union
 from functools import reduce
 from operator import mul as multiply_op
 import warnings
+import os
 import torch
@@ -782,6 +783,9 @@ class _Linear(torch.autograd.Function):
                         "with Userbuffers (tensor-parallel communication overlapping)"
                     )
                 ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm)
+                # overlap_grad_reduce, dongcl
+                if int(os.getenv("NVTE_OVERLAP_GRAD_REDUCE", "0")):
+                    wgrad = torch.empty(weight.size(), dtype=ctx.activation_dtype, device=weight.device)
             else:
                 # Call wgrad GEMM now
...