Commit 177291ac authored by wenjh

Merge branch 'develop_v2.9' into 'develop_v2.9'

grouped_linear supports activation offloading

See merge request dcutoolkit/deeplearing/TransformerEngine!62
parents 99e60246 db0ad945
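
For context on what the change enables: activation offloading keeps tensors saved for backward in host memory instead of device memory, and brings them back when gradients are computed. The snippet below is a minimal, generic sketch of that idea for a single linear layer; it is purely illustrative (the class name and the `inp_cpu` attribute are hypothetical) and is not how TransformerEngine implements it.

```python
import torch

class OffloadedLinear(torch.autograd.Function):
    """Minimal sketch: the activation saved for the weight gradient
    lives on the CPU between forward and backward."""

    @staticmethod
    def forward(ctx, inp, weight):
        out = inp @ weight.t()
        # Offload the saved activation to host memory instead of keeping it on the GPU.
        ctx.inp_cpu = inp.detach().to("cpu", non_blocking=True)
        ctx.save_for_backward(weight)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        (weight,) = ctx.saved_tensors
        # Reload the activation onto the device before computing gradients.
        inp = ctx.inp_cpu.to(grad_out.device, non_blocking=True)
        grad_inp = grad_out @ weight
        grad_weight = grad_out.t() @ inp
        return grad_inp, grad_weight
```

Usage would be `y = OffloadedLinear.apply(x, w)`. The hunks below wire the analogous bookkeeping into `_GroupedLinear`, once per GEMM, and coordinate it with fused weight-gradient accumulation.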
@@ -213,6 +213,16 @@ class _GroupedLinear(torch.autograd.Function):
                 if isinstance(weight, QuantizedTensorStorage):
                     weight.update_usage(columnwise_usage=True)
+            for i in range(num_gemms):
+                weights[i].offloading_activation = False
+                weights_fp8[i].offloading_activation = False
+                biases[i].offloading_activation = False
+            ctx.fine_grained_activation_offloading = fine_grained_activation_offloading
+            if fine_grained_activation_offloading and cpu_offloading:
+                raise ValueError(
+                    f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time."
+                )
             if cpu_offloading:
                 ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
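
The hunk above tags weights and biases with `offloading_activation = False` so that only true activations are offloaded, and it rejects enabling both offloading paths at once. A hypothetical saved-tensor hook consuming that flag might look like the sketch below, under the assumption that saved tensors are offloaded unless explicitly opted out; `maybe_offload` is not a real TransformerEngine function.

```python
import torch

def maybe_offload(saved_tensor: torch.Tensor) -> torch.Tensor:
    # Offload saved tensors to host memory unless they were tagged with
    # offloading_activation = False (as the weights and biases above are).
    if getattr(saved_tensor, "offloading_activation", True):
        return saved_tensor.to("cpu", non_blocking=True)
    return saved_tensor
```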
@@ -225,6 +235,21 @@ class _GroupedLinear(torch.autograd.Function):
                 ctx.weight_objects = []
                 for weight in weights:
                     ctx.weight_objects.append(weight)
+            if (
+                fine_grained_activation_offloading
+                and weights[0].requires_grad
+                and fuse_wgrad_accumulation
+            ):
+                grad_added_to_main_grad_list = []
+                ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad")
+                for weight in weights:
+                    if ctx.grad_added_to_main_grad:
+                        grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad)
+                        weight.grad_added_to_main_grad = True
+                        ctx.weight_objects.append(weight)
+                    else:
+                        grad_added_to_main_grad_list.append(None)
+                ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list
             tensors_to_save, tensor_objects = prepare_for_saving(
                 *inputmats,
@@ -288,12 +313,15 @@ class _GroupedLinear(torch.autograd.Function):
             biases = saved_tensors[3 * N : 4 * N]
             main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs]
-            if ctx.cpu_offloading:
+            if ctx.cpu_offloading or ctx.fine_grained_activation_offloading:
                 if ctx.grad_added_to_main_grad:
                     for i, weight in enumerate(ctx.weight_objects):
                         origin_weights[i] = ctx.weight_objects[i]
                         ctx.weight_objects[i] = None
+                        if ctx.fine_grained_activation_offloading:
+                            origin_weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i]
                 if ctx.fuse_wgrad_accumulation:
                     for i in range(N):
                         origin_weights[i].main_grad = main_grads[i]
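
Taken together, this hunk and the forward-pass hunk above implement a save/restore of the `grad_added_to_main_grad` flag used with fused weight-gradient accumulation: the forward pass stashes each weight's original value on `ctx` and sets the flag to True while the weight object is carried across the autograd boundary, and the backward pass restores the original value once the weight is recovered. A condensed sketch of that pattern, with hypothetical helper names:

```python
def stash_grad_flags(ctx, weights):
    # Forward side: remember each weight's original flag and mark it True while
    # the weight object is kept alive on ctx for the offloaded backward pass.
    ctx.grad_added_to_main_grad_list = []
    for w in weights:
        if hasattr(w, "grad_added_to_main_grad"):
            ctx.grad_added_to_main_grad_list.append(w.grad_added_to_main_grad)
            w.grad_added_to_main_grad = True
        else:
            ctx.grad_added_to_main_grad_list.append(None)


def restore_grad_flags(ctx, weights):
    # Backward side: put the original flag back once the weight gradient
    # has been handled.
    for w, prev in zip(weights, ctx.grad_added_to_main_grad_list):
        if prev is not None:
            w.grad_added_to_main_grad = prev
```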