Commit 177291ac authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.9' into 'develop_v2.9'

grouped_linear supports activation offloading

See merge request dcutoolkit/deeplearing/TransformerEngine!62
parents 99e60246 db0ad945
......@@ -213,6 +213,16 @@ class _GroupedLinear(torch.autograd.Function):
if isinstance(weight, QuantizedTensorStorage):
weight.update_usage(columnwise_usage=True)
for i in range(num_gemms):
weights[i].offloading_activation = False
weights_fp8[i].offloading_activation = False
biases[i].offloading_activation = False
ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
if fine_grained_activation_offloading and cpu_offloading:
raise ValueError(
f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
)
if cpu_offloading:
ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
......@@ -225,6 +235,21 @@ class _GroupedLinear(torch.autograd.Function):
ctx.weight_objects = []
for weight in weights:
ctx.weight_objects.append(weight)
if (
fine_grained_activation_offloading
and weights[0].requires_grad
and fuse_wgrad_accumulation
):
grad_added_to_main_grad_list = []
ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
for weight in weights:
if ctx.grad_added_to_main_grad:
grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
weight.grad_added_to_main_grad = True
ctx.weight_objects.append(weight)
else:
grad_added_to_main_grad_list.append(None)
ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
tensors_to_save, tensor_objects = prepare_for_saving(
*inputmats,
......@@ -288,12 +313,15 @@ class _GroupedLinear(torch.autograd.Function):
biases = saved_tensors[3 * N : 4 * N]
main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
if ctx.cpu_offloading:
if ctx.cpu_offloading or ctx.fine_grained_activation_offloading:
if ctx.grad_added_to_main_grad:
for i, weight in enumerate(ctx.weight_objects):
origin_weights[i] = ctx.weight_objects[i]
ctx.weight_objects[i] = None
if ctx.fine_grained_activation_offloading:
origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
if ctx.fuse_wgrad_accumulation:
for i in range(N):
origin_weights[i].main_grad = main_grads[i]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment