"test/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "2d8791324db0a764ff7088ba63741ad4f801abb4"
Commit cd4cdf80 authored by dongcl's avatar dongcl
Browse files

support activation offloading

parent 79fa3eba
...@@ -157,6 +157,29 @@ class _BatchLinear(torch.autograd.Function): ...@@ -157,6 +157,29 @@ class _BatchLinear(torch.autograd.Function):
for t in saved_inputmats: for t in saved_inputmats:
if t is not None: if t is not None:
t.activation_offloading = True t.activation_offloading = True
offload_activation = False
if hasattr(inp, "offloading_activation"):
offload_activation = True
for i in range(num_gemms):
saved_inputmats[i].offloading_activation = inp.offloading_activation
ctx.offload_activation = offload_activation
if offload_activation and cpu_offloading:
raise ValueError(
f"Do not use offload_activation and cpu_offloading at the same time."
)
if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation:
grad_added_to_main_grad_list = []
for weight in weights:
if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
weight.grad_added_to_main_grad = True
else:
grad_added_to_main_grad_list.append(None)
ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
ctx.save_for_backward( ctx.save_for_backward(
None, None,
*saved_inputmats, *saved_inputmats,
...@@ -164,7 +187,7 @@ class _BatchLinear(torch.autograd.Function): ...@@ -164,7 +187,7 @@ class _BatchLinear(torch.autograd.Function):
*weights, *weights,
*weights_fp8, *weights_fp8,
*[ *[
w.main_grad if cpu_offloading and fuse_wgrad_accumulation else None w.main_grad if (cpu_offloading or offload_activation) and fuse_wgrad_accumulation else None
for w in weights for w in weights
], ],
) )
...@@ -203,11 +226,14 @@ class _BatchLinear(torch.autograd.Function): ...@@ -203,11 +226,14 @@ class _BatchLinear(torch.autograd.Function):
weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms] weights = saved_tensors[2 * ctx.num_gemms : 3 * ctx.num_gemms]
weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms] weights_fp8 = saved_tensors[3 * ctx.num_gemms : 4 * ctx.num_gemms]
main_grads = saved_tensors[4 * ctx.num_gemms :] main_grads = saved_tensors[4 * ctx.num_gemms :]
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms): for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], False) w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i] w.main_grad = main_grads[i]
weights[i] = w weights[i] = w
if ctx.offload_activation and weights[i].requires_grad:
weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT global _GEMM_INPUT, _GEMM_WEIGHT, _GRAD_OUTPUT
grad_output = grad_output.contiguous() grad_output = grad_output.contiguous()
grad_output_mats = torch.split( grad_output_mats = torch.split(
......
...@@ -211,6 +211,28 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -211,6 +211,28 @@ class _GroupedLinear(torch.autograd.Function):
if isinstance(weight, QuantizedTensorBase): if isinstance(weight, QuantizedTensorBase):
weight.update_usage(columnwise_usage=True) weight.update_usage(columnwise_usage=True)
offload_activation = False
if hasattr(inp, "offloading_activation"):
offload_activation = True
for i in range(num_gemms):
inputmats[i].offloading_activation = inp.offloading_activation
ctx.offload_activation = offload_activation
if offload_activation and cpu_offloading:
raise ValueError(
f"Do not use offload_activation and cpu_offloading at the same time."
)
if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation:
grad_added_to_main_grad_list = []
for weight in weights:
if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"):
grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
weight.grad_added_to_main_grad = True
else:
grad_added_to_main_grad_list.append(None)
ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
tensors_to_save, tensor_objects = prepare_for_saving( tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats, *inputmats,
*weights_fp8, *weights_fp8,
...@@ -273,11 +295,13 @@ class _GroupedLinear(torch.autograd.Function): ...@@ -273,11 +295,13 @@ class _GroupedLinear(torch.autograd.Function):
biases = saved_tensors[3 * N : 4 * N] biases = saved_tensors[3 * N : 4 * N]
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation:
for i in range(ctx.num_gemms): for i in range(ctx.num_gemms):
w = torch.nn.Parameter(weights[i], weights[i].requires_grad) w = torch.nn.Parameter(weights[i], weights[i].requires_grad)
w.main_grad = main_grads[i]
weights[i] = w weights[i] = w
weights[i].main_grad = main_grads[i]
if ctx.offload_activation and weights[0].requires_grad:
weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
# Preprocess grad output # Preprocess grad output
grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment