"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "06eebf668b91336d535723f1f864a117e9af817f"
Commit 4b65dfa3 authored by yuguo's avatar yuguo
Browse files

Merge branch 'develop_v2.8' into 'develop_v2.8'

Update activation offload code to align with the official version

See merge request dcutoolkit/deeplearing/TransformerEngine!52
parents 712d526a 9711d439
...@@ -13,6 +13,7 @@ import transformer_engine_torch as tex ...@@ -13,6 +13,7 @@ import transformer_engine_torch as tex
from .base import ( from .base import (
get_multi_stream_cublas_batchgemm_workspace, get_multi_stream_cublas_batchgemm_workspace,
get_dummy_wgrad,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
...@@ -94,6 +95,7 @@ class _BatchLinear(torch.autograd.Function): ...@@ -94,6 +95,7 @@ class _BatchLinear(torch.autograd.Function):
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
parallel_mode: Union[str, None], parallel_mode: Union[str, None],
is_grad_enabled: bool, is_grad_enabled: bool,
fine_grained_activation_offloading,
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None], *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor: ) -> torch.Tensor:
batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2")) batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
...@@ -158,19 +160,24 @@ class _BatchLinear(torch.autograd.Function): ...@@ -158,19 +160,24 @@ class _BatchLinear(torch.autograd.Function):
if t is not None: if t is not None:
t.activation_offloading = True t.activation_offloading = True
offload_activation = False
if hasattr(inp, "offloading_activation"):
offload_activation = True
for i in range(num_gemms): for i in range(num_gemms):
saved_inputmats[i].offloading_activation = inp.offloading_activation weights[i].offloading_activation = False
ctx.offload_activation = offload_activation weights[i].main_grad.offloading_activation = False
if weights_fp8[i] is not None:
weights_fp8[i].offloading_activation = False
if offload_activation and cpu_offloading: ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
if fine_grained_activation_offloading and cpu_offloading:
raise ValueError( raise ValueError(
f"Do not use offload_activation and cpu_offloading at the same time." f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
) )
if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation: if (
fine_grained_activation_offloading
and weights[0].requires_grad
and fuse_wgrad_accumulation
):
grad_added_to_main_grad_list = [] grad_added_to_main_grad_list = []
for weight in weights: for weight in weights:
if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"): if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
...@@ -187,7 +194,7 @@ class _BatchLinear(torch.autograd.Function): ...@@ -187,7 +194,7 @@ class _BatchLinear(torch.autograd.Function):
*weights, *weights,
*weights_fp8, *weights_fp8,
*[ *[
w.main_grad if (cpu_offloading or offload_activation) and fuse_wgrad_accumulation else None w.main_grad if (cpu_offloading or fine_grained_activation_offloading) and fuse_wgrad_accumulation else None
for w in weights for w in weights
], ],
) )
...@@ -226,12 +233,12 @@ class _BatchLinear(torch.autograd.Function): ...@@ -226,12 +233,12 @@ class _BatchLinear(torch.autograd.Function):
weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms] weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms]
weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms] weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
main_grads = saved_tensors[4 * ctx.num_gemms :] main_grads = saved_tensors[4 * ctx.num_gemms :]
if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation: if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms): for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i] w.main_grad = main_grads[i]
weights[i] = w weights[i] = w
if ctx.offload_activation and weights[i].requires_grad: if ctx.fine_grained_activation_offloading and weights[i].requires_grad:
weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i] weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
...@@ -304,18 +311,15 @@ class _BatchLinear(torch.autograd.Function): ...@@ -304,18 +311,15 @@ class _BatchLinear(torch.autograd.Function):
if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"): if ctx.fuse_wgrad_accumulation and hasattr(w, "grad_added_to_main_grad"):
w.grad_added_to_main_grad = True w.grad_added_to_main_grad = True
if getattr(w, "zero_out_wgrad", False): if getattr(w, "zero_out_wgrad", False):
wgrad = torch.zeros( wgrad = get_dummy_wgrad(
w.main_grad.shape, list(w.main_grad.shape),
dtype=w.dtype, w.dtype,
device=torch.cuda.current_device(), zero=True,
requires_grad=False,
) )
else: else:
wgrad = torch.empty( wgrad = get_dummy_wgrad(
w.main_grad.shape, list(w.main_grad.shape),
dtype=w.dtype, w.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
) )
elif ctx.fuse_wgrad_accumulation: elif ctx.fuse_wgrad_accumulation:
wgrad = None wgrad = None
...@@ -367,6 +371,7 @@ class _BatchLinear(torch.autograd.Function): ...@@ -367,6 +371,7 @@ class _BatchLinear(torch.autograd.Function):
None, # activation_dtype None, # activation_dtype
None, # parallel_mode None, # parallel_mode
None, # is_grad_enabled None, # is_grad_enabled
None, # fine_grained_activation_offloading
*wgrad_list, *wgrad_list,
*([None] * ctx.num_gemms), # weights_fp8 *([None] * ctx.num_gemms), # weights_fp8
*grad_biases, *grad_biases,
...@@ -457,6 +462,7 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -457,6 +462,7 @@ class BatchedLinear(TransformerEngineBaseModule):
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
ub_overlap_rs: bool = False, ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False, ub_overlap_ag: bool = False,
fine_grained_activation_offloading: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False, delay_wgrad_compute: bool = False,
) -> None: ) -> None:
...@@ -480,6 +486,8 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -480,6 +486,8 @@ class BatchedLinear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name self.rng_tracker_name = rng_tracker_name
self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute) self.wgrad_store = WeightGradStore(delay_wgrad_compute)
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
...@@ -657,6 +665,7 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -657,6 +665,7 @@ class BatchedLinear(TransformerEngineBaseModule):
self.activation_dtype, self.activation_dtype,
self.parallel_mode, self.parallel_mode,
torch.is_grad_enabled(), torch.is_grad_enabled(),
self.fine_grained_activation_offloading,
*weight_tensors, *weight_tensors,
*weight_tensors_fp8, *weight_tensors_fp8,
*bias_tensors, *bias_tensors,
......
...@@ -15,6 +15,7 @@ import transformer_engine_torch as tex ...@@ -15,6 +15,7 @@ import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe from transformer_engine.common.recipe import Recipe
from .base import ( from .base import (
get_multi_stream_cublas_workspace, get_multi_stream_cublas_workspace,
get_dummy_wgrad,
TransformerEngineBaseModule, TransformerEngineBaseModule,
_2X_ACC_FPROP, _2X_ACC_FPROP,
_2X_ACC_DGRAD, _2X_ACC_DGRAD,
...@@ -82,6 +83,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -82,6 +83,7 @@ class _GroupedLinear(torch.autograd.Function):
module, module,
skip_fp8_weight_update, skip_fp8_weight_update,
save_original_input, save_original_input,
fine_grained_activation_offloading,
*weights_and_biases, *weights_and_biases,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
...@@ -211,19 +213,22 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -211,19 +213,22 @@ class _GroupedLinear(torch.autograd.Function):
if isinstance(weight, QuantizedTensorBase): if isinstance(weight, QuantizedTensorBase):
weight.update_usage(columnwise_usage=True) weight.update_usage(columnwise_usage=True)
offload_activation = False
if hasattr(inp, "offloading_activation"):
offload_activation = True
for i in range(num_gemms): for i in range(num_gemms):
inputmats[i].offloading_activation = inp.offloading_activation weights[i].offloading_activation = False
ctx.offload_activation = offload_activation weights_fp8[i].offloading_activation = False
biases[i].offloading_activation = False
ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
if offload_activation and cpu_offloading: if fine_grained_activation_offloading and cpu_offloading:
raise ValueError( raise ValueError(
f"Do not use offload_activation and cpu_offloading at the same time." f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
) )
if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation: if (
fine_grained_activation_offloading
and weights[0].requires_grad
and fuse_wgrad_accumulation
):
grad_added_to_main_grad_list = [] grad_added_to_main_grad_list = []
for weight in weights: for weight in weights:
if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"): if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
...@@ -295,12 +300,12 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -295,12 +300,12 @@ class _GroupedLinear(torch.autograd.Function):
biases = saved_tensors[3 * N : 4 * N] biases = saved_tensors[3 * N : 4 * N]
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation: if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms): for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i] w.main_grad = main_grads[i]
weights[i] = w weights[i] = w
if ctx.offload_activation and weights[0].requires_grad: if ctx.fine_grained_activation_offloading and weights[0].requires_grad:
weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i] weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
# Preprocess grad output # Preprocess grad output
...@@ -452,18 +457,15 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -452,18 +457,15 @@ class _GroupedLinear(torch.autograd.Function):
): ):
weight.grad_added_to_main_grad = True weight.grad_added_to_main_grad = True
if getattr(weight, "zero_out_wgrad", False): if getattr(weight, "zero_out_wgrad", False):
wgrad = torch.zeros( wgrad = get_dummy_wgrad(
weight.main_grad.shape, list(weight.main_grad.shape),
dtype=weight.dtype, weight.dtype,
device=torch.cuda.current_device(), zero=True,
requires_grad=False,
) )
else: else:
wgrad = torch.empty( wgrad = get_dummy_wgrad(
weight.main_grad.shape, list(weight.main_grad.shape),
dtype=weight.dtype, weight.dtype,
device=torch.cuda.current_device(),
requires_grad=False,
) )
elif ctx.fuse_wgrad_accumulation: elif ctx.fuse_wgrad_accumulation:
wgrad = None wgrad = None
...@@ -514,6 +516,7 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -514,6 +516,7 @@ class _GroupedLinear(torch.autograd.Function):
None, None,
None, None,
None, None,
None,
*wgrad_list, *wgrad_list,
*grad_biases, *grad_biases,
) )
...@@ -595,6 +598,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -595,6 +598,7 @@ class GroupedLinear(TransformerEngineBaseModule):
ub_overlap_rs: bool = False, ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False, ub_overlap_ag: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
fine_grained_activation_offloading: bool = False,
delay_wgrad_compute: bool = False, delay_wgrad_compute: bool = False,
save_original_input: bool = False, save_original_input: bool = False,
) -> None: ) -> None:
...@@ -617,6 +621,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -617,6 +621,7 @@ class GroupedLinear(TransformerEngineBaseModule):
), "GroupedLinear doesn't support Userbuffer overlap." ), "GroupedLinear doesn't support Userbuffer overlap."
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name self.rng_tracker_name = rng_tracker_name
self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute) self.wgrad_store = WeightGradStore(delay_wgrad_compute)
...@@ -836,6 +841,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -836,6 +841,7 @@ class GroupedLinear(TransformerEngineBaseModule):
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
self.save_original_input, self.save_original_input,
self.fine_grained_activation_offloading,
*weight_tensors, *weight_tensors,
*bias_tensors, *bias_tensors,
) )
......
...@@ -130,6 +130,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -130,6 +130,7 @@ class _LayerNormLinear(torch.autograd.Function):
ub_bulk_wgrad: bool, ub_bulk_wgrad: bool,
ub_bulk_dgrad: bool, ub_bulk_dgrad: bool,
ub_name: str, ub_name: str,
fine_grained_activation_offloading: bool,
fsdp_group: Union[dist_group_type, None], fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module, module: torch.nn.Module,
skip_fp8_weight_update: bool, skip_fp8_weight_update: bool,
...@@ -435,21 +436,25 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -435,21 +436,25 @@ class _LayerNormLinear(torch.autograd.Function):
) )
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
offload_activation = False # Do not offload weights and biases
if get_activation_offloading(): weight.offloading_activation = False
offload_activation = True weightmat.offloading_activation = False
if not inputmat.is_contiguous(): if bias is not None:
inputmat = inputmat.contiguous() bias.offloading_activation = False
inputmat.offloading_activation = True ln_weight.offloading_activation = False
ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
ctx.offload_activation = offload_activation
if offload_activation and cpu_offloading: if fine_grained_activation_offloading and cpu_offloading:
raise ValueError( raise ValueError(
f"Do not use offload_activation and cpu_offloading at the same time." f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
) )
if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: if (
fine_grained_activation_offloading
and weight.requires_grad
and fuse_wgrad_accumulation
):
if hasattr(weight, "grad_added_to_main_grad"): if hasattr(weight, "grad_added_to_main_grad"):
ctx.has_grad_added_to_main_grad = True ctx.has_grad_added_to_main_grad = True
ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
...@@ -594,10 +599,10 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -594,10 +599,10 @@ class _LayerNormLinear(torch.autograd.Function):
# For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# we need to connect them into one. # we need to connect them into one.
if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")): if ctx.cpu_offloading or ctx.fine_grained_activation_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
if ctx.has_grad_added_to_main_grad: if ctx.has_grad_added_to_main_grad:
origin_weight = ctx.weight_object origin_weight = ctx.weight_object
if ctx.offload_activation: if ctx.fine_grained_activation_offloading:
origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
origin_weight.main_grad = main_grad origin_weight.main_grad = main_grad
...@@ -1074,6 +1079,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -1074,6 +1079,7 @@ class _LayerNormLinear(torch.autograd.Function):
None, # ub_bulk_dgrad None, # ub_bulk_dgrad
None, # ub_bulk_wgrad None, # ub_bulk_wgrad
None, # ub_name None, # ub_name
None, # fine_grained_activation_offloading
None, # fsdp_group None, # fsdp_group
None, # debug None, # debug
None, # module None, # module
...@@ -1209,6 +1215,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1209,6 +1215,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
delay_wgrad_compute: bool = False, delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None, symmetric_ar_type: Optional[str] = None,
name: str = None, name: str = None,
fine_grained_activation_offloading: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1227,6 +1234,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1227,6 +1234,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
) )
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type self.symmetric_ar_type = symmetric_ar_type
self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
self.name = name self.name = name
...@@ -1630,6 +1638,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1630,6 +1638,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_bulk_wgrad, self.ub_bulk_wgrad,
self.ub_bulk_dgrad, self.ub_bulk_dgrad,
self.ub_name, self.ub_name,
self.fine_grained_activation_offloading,
self.fsdp_group, self.fsdp_group,
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
......
...@@ -111,6 +111,7 @@ class _Linear(torch.autograd.Function): ...@@ -111,6 +111,7 @@ class _Linear(torch.autograd.Function):
ub_bulk_dgrad: bool, ub_bulk_dgrad: bool,
ub_bulk_wgrad: bool, ub_bulk_wgrad: bool,
ub_name: str, ub_name: str,
fine_grained_activation_offloading: bool,
fp8_output: bool, # pylint: disable=unused-argument fp8_output: bool, # pylint: disable=unused-argument
fsdp_group: Union[dist_group_type, None], fsdp_group: Union[dist_group_type, None],
module: torch.nn.Module, module: torch.nn.Module,
...@@ -404,25 +405,25 @@ class _Linear(torch.autograd.Function): ...@@ -404,25 +405,25 @@ class _Linear(torch.autograd.Function):
) )
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
offload_activation = False ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
if get_activation_offloading():
offload_activation = True
if not saved_inputmat.is_contiguous():
saved_inputmat = saved_inputmat.contiguous()
saved_inputmat.offloading_activation = True
ctx.offload_activation = offload_activation
if offload_activation and cpu_offloading: if fine_grained_activation_offloading and cpu_offloading:
raise ValueError( raise ValueError(
f"Do not use offload_activation and cpu_offloading at the same time." f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
) )
if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: if (
ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") fine_grained_activation_offloading
if ctx.has_grad_added_to_main_grad: and weight.requires_grad
and fuse_wgrad_accumulation
):
if hasattr(weight, "grad_added_to_main_grad"):
ctx.has_grad_added_to_main_grad = True
ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
ctx.weight_object = weight
weight.grad_added_to_main_grad = True weight.grad_added_to_main_grad = True
ctx.weight_object = weight
else:
ctx.has_grad_added_to_main_grad = False
if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")): if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
...@@ -435,6 +436,12 @@ class _Linear(torch.autograd.Function): ...@@ -435,6 +436,12 @@ class _Linear(torch.autograd.Function):
# weights if weights are externally touched outside this module # weights if weights are externally touched outside this module
ctx.weight_object = weight ctx.weight_object = weight
# Do not offload weights and biases
weight.offloading_activation = False
weightmat.offloading_activation = False
if bias is not None:
bias.offloading_activation = False
# TODO(ksivamani): Check memory usage # TODO(ksivamani): Check memory usage
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(
saved_inputmat, saved_inputmat,
...@@ -522,10 +529,10 @@ class _Linear(torch.autograd.Function): ...@@ -522,10 +529,10 @@ class _Linear(torch.autograd.Function):
else None else None
) )
if ctx.cpu_offloading or ctx.offload_activation or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")): if ctx.cpu_offloading or ctx.fine_grained_activation_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
if ctx.has_grad_added_to_main_grad: if ctx.has_grad_added_to_main_grad:
weight = ctx.weight_object weight = ctx.weight_object
if ctx.offload_activation: if ctx.fine_grained_activation_offloading:
weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
weight.main_grad = main_grad weight.main_grad = main_grad
...@@ -1009,6 +1016,7 @@ class _Linear(torch.autograd.Function): ...@@ -1009,6 +1016,7 @@ class _Linear(torch.autograd.Function):
None, # ub_bulk_dgrad None, # ub_bulk_dgrad
None, # ub_bulk_wgrad None, # ub_bulk_wgrad
None, # ub_name None, # ub_name
None, # fine_grained_activation_offloading
None, # fp8_output None, # fp8_output
None, # fsdp_group None, # fsdp_group
None, # module None, # module
...@@ -1131,6 +1139,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1131,6 +1139,7 @@ class Linear(TransformerEngineBaseModule):
symmetric_ar_type: Optional[str] = None, symmetric_ar_type: Optional[str] = None,
save_original_input: bool = False, save_original_input: bool = False,
name: Optional[str] = None, name: Optional[str] = None,
fine_grained_activation_offloading: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1146,6 +1155,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1146,6 +1155,7 @@ class Linear(TransformerEngineBaseModule):
self.symmetric_ar_type = symmetric_ar_type self.symmetric_ar_type = symmetric_ar_type
self.save_original_input = save_original_input self.save_original_input = save_original_input
self.name = name self.name = name
self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
...@@ -1493,6 +1503,7 @@ class Linear(TransformerEngineBaseModule): ...@@ -1493,6 +1503,7 @@ class Linear(TransformerEngineBaseModule):
self.ub_bulk_dgrad, self.ub_bulk_dgrad,
self.ub_bulk_wgrad, self.ub_bulk_wgrad,
self.ub_name, self.ub_name,
self.fine_grained_activation_offloading,
fp8_output, fp8_output,
self.fsdp_group, self.fsdp_group,
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment