Commit 96a104d5 authored by wenjh
Browse files

Merge develop_v2.10


Signed-off-by: wenjh <wenjh@sugon.com>
parents abec28e8 0fce42f7
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
......
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
......
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
......
...@@ -10,7 +10,9 @@ ...@@ -10,7 +10,9 @@
#include <memory> #include <memory>
#include <random> #include <random>
#include <limits> #include <limits>
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
......
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
#endif #endif
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080) #define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif
#include <cuda_runtime_api.h> #include <cuda_runtime_api.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
......
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
************************************************************************/ ************************************************************************/
#include <assert.h> #include <assert.h>
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h> #include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
......
...@@ -7,6 +7,9 @@ ...@@ -7,6 +7,9 @@
#include <limits> #include <limits>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype. // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <assert.h> #include <assert.h>
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h> #include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
......
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
************************************************************************/ ************************************************************************/
#include <assert.h> #include <assert.h>
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif
#include <cuda_fp8.h> #include <cuda_fp8.h>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype. // Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <transformer_engine/multi_tensor.h> #include <transformer_engine/multi_tensor.h>
......
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
************************************************************************/ ************************************************************************/
#include <assert.h> #include <assert.h>
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#endif
#include <cuda_fp8.h> #include <cuda_fp8.h>
#include <transformer_engine/multi_tensor.h> #include <transformer_engine/multi_tensor.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
......
...@@ -95,7 +95,6 @@ class _BatchLinear(torch.autograd.Function): ...@@ -95,7 +95,6 @@ class _BatchLinear(torch.autograd.Function):
activation_dtype: torch.dtype, activation_dtype: torch.dtype,
parallel_mode: Union[str, None], parallel_mode: Union[str, None],
is_grad_enabled: bool, is_grad_enabled: bool,
fine_grained_activation_offloading,
*weights_and_biases: Union[Float8Tensor, torch.Tensor, None], *weights_and_biases: Union[Float8Tensor, torch.Tensor, None],
) -> torch.Tensor: ) -> torch.Tensor:
batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2")) batch_num = int(os.getenv("NVTE_MOE_BATCHCOUNT", "2"))
...@@ -160,33 +159,6 @@ class _BatchLinear(torch.autograd.Function): ...@@ -160,33 +159,6 @@ class _BatchLinear(torch.autograd.Function):
if t is not None: if t is not None:
t.activation_offloading = True t.activation_offloading = True
for i in range(num_gemms):
weights[i].offloading_activation = False
weights[i].main_grad.offloading_activation = False
if weights_fp8[i] is not None:
weights_fp8[i].offloading_activation = False
ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
)
if (
fine_grained_activation_offloading
and weights[0].requires_grad
and fuse_wgrad_accumulation
):
grad_added_to_main_grad_list = []
for weight in weights:
if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
weight.grad_added_to_main_grad = True
else:
grad_added_to_main_grad_list.append(None)
ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
ctx.save_for_backward( ctx.save_for_backward(
None, None,
*saved_inputmats, *saved_inputmats,
...@@ -194,7 +166,7 @@ class _BatchLinear(torch.autograd.Function): ...@@ -194,7 +166,7 @@ class _BatchLinear(torch.autograd.Function):
*weights, *weights,
*weights_fp8, *weights_fp8,
*[ *[
w.main_grad if (cpu_offloading or fine_grained_activation_offloading) and fuse_wgrad_accumulation else None w.main_grad if cpu_offloading and fuse_wgrad_accumulation else None
for w in weights for w in weights
], ],
) )
...@@ -233,13 +205,11 @@ class _BatchLinear(torch.autograd.Function): ...@@ -233,13 +205,11 @@ class _BatchLinear(torch.autograd.Function):
weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms] weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms]
weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms] weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
main_grads = saved_tensors[4 * ctx.num_gemms :] main_grads = saved_tensors[4 * ctx.num_gemms :]
if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation: if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms): for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i] w.main_grad = main_grads[i]
weights[i] = w weights[i] = w
if ctx.fine_grained_activation_offloading and weights[i].requires_grad:
weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
grad_output = grad_output.contiguous() grad_output = grad_output.contiguous()
...@@ -371,7 +341,6 @@ class _BatchLinear(torch.autograd.Function): ...@@ -371,7 +341,6 @@ class _BatchLinear(torch.autograd.Function):
None, # activation_dtype None, # activation_dtype
None, # parallel_mode None, # parallel_mode
None, # is_grad_enabled None, # is_grad_enabled
None, # fine_grained_activation_offloading
*wgrad_list, *wgrad_list,
*([None] * ctx.num_gemms), # weights_fp8 *([None] * ctx.num_gemms), # weights_fp8
*grad_biases, *grad_biases,
...@@ -462,7 +431,6 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -462,7 +431,6 @@ class BatchedLinear(TransformerEngineBaseModule):
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
ub_overlap_rs: bool = False, ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False, ub_overlap_ag: bool = False,
fine_grained_activation_offloading: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
delay_wgrad_compute: bool = False, delay_wgrad_compute: bool = False,
) -> None: ) -> None:
...@@ -486,8 +454,6 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -486,8 +454,6 @@ class BatchedLinear(TransformerEngineBaseModule):
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name self.rng_tracker_name = rng_tracker_name
self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute) self.wgrad_store = WeightGradStore(delay_wgrad_compute)
global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT global _GEMM_INPUT, _GEMM_WEIGHT, _GEMM_OUTPUT
...@@ -665,7 +631,6 @@ class BatchedLinear(TransformerEngineBaseModule): ...@@ -665,7 +631,6 @@ class BatchedLinear(TransformerEngineBaseModule):
self.activation_dtype, self.activation_dtype,
self.parallel_mode, self.parallel_mode,
torch.is_grad_enabled(), torch.is_grad_enabled(),
self.fine_grained_activation_offloading,
*weight_tensors, *weight_tensors,
*weight_tensors_fp8, *weight_tensors_fp8,
*bias_tensors, *bias_tensors,
......
...@@ -90,7 +90,6 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -90,7 +90,6 @@ class _GroupedLinear(torch.autograd.Function):
module, module,
skip_fp8_weight_update, skip_fp8_weight_update,
save_original_input, save_original_input,
fine_grained_activation_offloading,
) = non_tensor_args ) = non_tensor_args
num_gemms = len(m_splits) num_gemms = len(m_splits)
...@@ -225,16 +224,6 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -225,16 +224,6 @@ class _GroupedLinear(torch.autograd.Function):
else: else:
inputmats = [None] * num_gemms inputmats = [None] * num_gemms
for i in range(num_gemms):
weights[i].offloading_activation = False
weights_fp8[i].offloading_activation = False
biases[i].offloading_activation = False
ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
)
if cpu_offloading: if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad") ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
...@@ -247,21 +236,6 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -247,21 +236,6 @@ class _GroupedLinear(torch.autograd.Function):
ctx.weight_objects = [] ctx.weight_objects = []
for weight in weights: for weight in weights:
ctx.weight_objects.append(weight) ctx.weight_objects.append(weight)
if (
fine_grained_activation_offloading
and weights[0].requires_grad
and fuse_wgrad_accumulation
):
grad_added_to_main_grad_list = []
ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
for weight in weights:
if ctx.grad_added_to_main_grad:
grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
weight.grad_added_to_main_grad = True
ctx.weight_objects.append(weight)
else:
grad_added_to_main_grad_list.append(None)
ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats, *inputmats,
...@@ -325,15 +299,12 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -325,15 +299,12 @@ class _GroupedLinear(torch.autograd.Function):
biases = saved_tensors[3 * N : 4 * N] biases = saved_tensors[3 * N : 4 * N]
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
if ctx.cpu_offloading or ctx.fine_grained_activation_offloading: if ctx.cpu_offloading:
if ctx.grad_added_to_main_grad: if ctx.grad_added_to_main_grad:
for i, weight in enumerate(ctx.weight_objects): for i, weight in enumerate(ctx.weight_objects):
origin_weights[i] = ctx.weight_objects[i] origin_weights[i] = ctx.weight_objects[i]
ctx.weight_objects[i] = None ctx.weight_objects[i] = None
if ctx.fine_grained_activation_offloading:
origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
if ctx.fuse_wgrad_accumulation: if ctx.fuse_wgrad_accumulation:
for i in range(N): for i in range(N):
origin_weights[i].main_grad = main_grads[i] origin_weights[i].main_grad = main_grads[i]
...@@ -614,7 +585,6 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -614,7 +585,6 @@ class GroupedLinear(TransformerEngineBaseModule):
ub_overlap_rs: bool = False, ub_overlap_rs: bool = False,
ub_overlap_ag: bool = False, ub_overlap_ag: bool = False,
ub_name: Optional[str] = None, ub_name: Optional[str] = None,
fine_grained_activation_offloading: bool = False,
delay_wgrad_compute: bool = False, delay_wgrad_compute: bool = False,
save_original_input: bool = False, save_original_input: bool = False,
) -> None: ) -> None:
...@@ -637,7 +607,6 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -637,7 +607,6 @@ class GroupedLinear(TransformerEngineBaseModule):
), "GroupedLinear doesn't support Userbuffer overlap." ), "GroupedLinear doesn't support Userbuffer overlap."
self.get_rng_state_tracker = get_rng_state_tracker self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name self.rng_tracker_name = rng_tracker_name
self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute) self.wgrad_store = WeightGradStore(delay_wgrad_compute)
...@@ -850,7 +819,6 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -850,7 +819,6 @@ class GroupedLinear(TransformerEngineBaseModule):
self, self,
None, # skip_fp8_weight_update None, # skip_fp8_weight_update
self.save_original_input, self.save_original_input,
self.fine_grained_activation_offloading,
) )
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors) out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
......
...@@ -40,7 +40,6 @@ from ..utils import ( ...@@ -40,7 +40,6 @@ from ..utils import (
requires_grad, requires_grad,
needs_quantized_gemm, needs_quantized_gemm,
get_nvtx_range_context, get_nvtx_range_context,
get_activation_offloading,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -144,7 +143,6 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -144,7 +143,6 @@ class _LayerNormLinear(torch.autograd.Function):
ub_bulk_wgrad, ub_bulk_wgrad,
ub_bulk_dgrad, ub_bulk_dgrad,
ub_name, ub_name,
fine_grained_activation_offloading,
fsdp_group, fsdp_group,
module, module,
skip_fp8_weight_update, skip_fp8_weight_update,
...@@ -598,11 +596,10 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -598,11 +596,10 @@ class _LayerNormLinear(torch.autograd.Function):
# For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# we need to connect them into one. # we need to connect them into one.
if ctx.cpu_offloading or ctx.fine_grained_activation_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")): if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
if ctx.has_grad_added_to_main_grad: if ctx.grad_added_to_main_grad:
origin_weight = ctx.weight_object origin_weight = ctx.weight_object
if ctx.fine_grained_activation_offloading:
origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
origin_weight.main_grad = main_grad origin_weight.main_grad = main_grad
...@@ -1180,7 +1177,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1180,7 +1177,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
delay_wgrad_compute: bool = False, delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None, symmetric_ar_type: Optional[str] = None,
name: str = None, name: str = None,
fine_grained_activation_offloading: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1199,7 +1195,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1199,7 +1195,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
) )
self.zero_centered_gamma = zero_centered_gamma self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type self.symmetric_ar_type = symmetric_ar_type
self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
self.name = name self.name = name
...@@ -1600,7 +1595,6 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1600,7 +1595,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.ub_bulk_wgrad, self.ub_bulk_wgrad,
self.ub_bulk_dgrad, self.ub_bulk_dgrad,
self.ub_name, self.ub_name,
self.fine_grained_activation_offloading,
self.fsdp_group, self.fsdp_group,
self, self,
skip_fp8_weight_update, skip_fp8_weight_update,
......
...@@ -39,7 +39,6 @@ from ..utils import ( ...@@ -39,7 +39,6 @@ from ..utils import (
nvtx_range_pop, nvtx_range_pop,
nvtx_range_push, nvtx_range_push,
get_nvtx_range_context, get_nvtx_range_context,
get_activation_offloading,
) )
from ..distributed import ( from ..distributed import (
set_tensor_model_parallel_attributes, set_tensor_model_parallel_attributes,
...@@ -123,7 +122,6 @@ class _Linear(torch.autograd.Function): ...@@ -123,7 +122,6 @@ class _Linear(torch.autograd.Function):
ub_bulk_dgrad, ub_bulk_dgrad,
ub_bulk_wgrad, ub_bulk_wgrad,
ub_name, ub_name,
fine_grained_activation_offloading,
fp8_output, # pylint: disable=unused-variable fp8_output, # pylint: disable=unused-variable
fsdp_group, fsdp_group,
module, module,
...@@ -420,30 +418,10 @@ class _Linear(torch.autograd.Function): ...@@ -420,30 +418,10 @@ class _Linear(torch.autograd.Function):
) )
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")
ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
)
if (
fine_grained_activation_offloading
and weight.requires_grad
and fuse_wgrad_accumulation
):
if hasattr(weight, "grad_added_to_main_grad"):
ctx.has_grad_added_to_main_grad = True
ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad
weight.grad_added_to_main_grad = True
ctx.weight_object = weight
else:
ctx.has_grad_added_to_main_grad = False
if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")): if cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad")
if ctx.has_grad_added_to_main_grad: if ctx.grad_added_to_main_grad:
# If you are passing torch.nn.Parameter through the Torch hooks, you will # If you are passing torch.nn.Parameter through the Torch hooks, you will
# get back torch.Tensor. Torch rips off the Parameter wrapper. # get back torch.Tensor. Torch rips off the Parameter wrapper.
# You need to preserve the weight object to have all the attributes user # You need to preserve the weight object to have all the attributes user
...@@ -540,11 +518,10 @@ class _Linear(torch.autograd.Function): ...@@ -540,11 +518,10 @@ class _Linear(torch.autograd.Function):
else None else None
) )
if ctx.cpu_offloading or ctx.fine_grained_activation_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")): if ctx.cpu_offloading or int(os.getenv("NVTE_SWAP_OVERLAP_GRAD", "0")):
if ctx.has_grad_added_to_main_grad: if ctx.grad_added_to_main_grad:
weight = ctx.weight_object weight = ctx.weight_object
if ctx.fine_grained_activation_offloading:
weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad
if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation:
weight.main_grad = main_grad weight.main_grad = main_grad
...@@ -1124,7 +1101,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -1124,7 +1101,6 @@ class Linear(TransformerEngineBaseModule):
symmetric_ar_type: Optional[str] = None, symmetric_ar_type: Optional[str] = None,
save_original_input: bool = False, save_original_input: bool = False,
name: Optional[str] = None, name: Optional[str] = None,
fine_grained_activation_offloading: bool = False,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -1140,7 +1116,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -1140,7 +1116,6 @@ class Linear(TransformerEngineBaseModule):
self.symmetric_ar_type = symmetric_ar_type self.symmetric_ar_type = symmetric_ar_type
self.save_original_input = save_original_input self.save_original_input = save_original_input
self.name = name self.name = name
self.fine_grained_activation_offloading = fine_grained_activation_offloading
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
...@@ -1487,7 +1462,6 @@ class Linear(TransformerEngineBaseModule): ...@@ -1487,7 +1462,6 @@ class Linear(TransformerEngineBaseModule):
self.ub_bulk_dgrad, self.ub_bulk_dgrad,
self.ub_bulk_wgrad, self.ub_bulk_wgrad,
self.ub_name, self.ub_name,
self.fine_grained_activation_offloading,
fp8_output, fp8_output,
self.fsdp_group, self.fsdp_group,
self, self,
......
...@@ -824,30 +824,3 @@ def make_weak_ref(x): ...@@ -824,30 +824,3 @@ def make_weak_ref(x):
if x is None: if x is None:
return None return None
raise TypeError(f"Invalid type {type(x)} to make weak ref") raise TypeError(f"Invalid type {type(x)} to make weak ref")
# Module-level switch read by get_activation_offloading() and written by
# set_activation_offloading(). Starts disabled.
ActivationOffloadEnabled = False


def get_activation_offloading():
    """Return the current value of the module-level ``ActivationOffloadEnabled`` flag."""
    # Reading a module global needs no ``global`` declaration.
    return ActivationOffloadEnabled


def set_activation_offloading(activation_offloading):
    """Set the module-level ``ActivationOffloadEnabled`` flag.

    Args:
        activation_offloading: The new value stored in the flag. It is stored
            as given, without conversion.
    """
    global ActivationOffloadEnabled
    ActivationOffloadEnabled = activation_offloading


class ActivationOffloadContextManager:
    """A reusable context manager for switching ``ActivationOffloadEnabled``.

    On entry the current flag value is remembered and the flag is set to the
    value given at construction; on exit the remembered value is put back.

    The remembered values are kept on a per-instance stack, so one instance
    may be entered several times (sequentially or nested) and each exit
    restores the value seen by its matching entry.
    """

    def __init__(self, activation_offloading):
        """Store the flag value to apply while the context is active.

        Args:
            activation_offloading: Value assigned to the flag on ``__enter__``.
        """
        self.activation_offloading = activation_offloading
        # One saved flag value per active ``with`` on this instance. A single
        # attribute would be overwritten when the same instance is nested,
        # making the outermost exit restore the wrong value.
        self._saved_states = []

    def __enter__(self):
        """Save the current flag, apply the configured value, and return ``self``."""
        previous = get_activation_offloading()
        self._saved_states.append(previous)
        # Kept for backward compatibility: mirrors the most recently saved value.
        self.origin_cpu_offloading = previous
        set_activation_offloading(self.activation_offloading)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the flag saved by the matching ``__enter__``.

        Returns ``None``, so an exception raised inside the ``with`` body is
        not suppressed.
        """
        set_activation_offloading(self._saved_states.pop())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment