Unverified commit d52ed471 authored by vthumbe1503, committed by GitHub

FSDP2 Allgather Perf improvement and support for FusedAdam with FSDP2 (#2370)



* fix ci issue
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert back testing changes
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* remove quantizer copy + fused adam working
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix test
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix mxfp8 bug, god knows who created it
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/optimizers/fused_adam.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update comment
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 89cc2a7e
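For context on what the commit enables end to end: Transformer Engine modules whose FP8 weights are sharded by FSDP2's `fully_shard` can now be optimized directly with TE's `FusedAdam`, which unwraps the resulting `DTensor` parameters (first hunk below). A minimal sketch of that setup, assuming a CUDA build of Transformer Engine and an already-initialized process group (e.g. under `torchrun`); the `fully_shard` import path varies across torch releases, so treat this as illustrative rather than the repo's own test:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.optimizers import FusedAdam
# Newer torch exposes fully_shard from torch.distributed.fsdp;
# older 2.x releases ship it in torch.distributed._composable.fsdp.
from torch.distributed.fsdp import fully_shard

# Assumes torch.distributed is already initialized (e.g. via torchrun).
with te.fp8_model_init():            # parameters are created as Float8Tensors
    model = te.Linear(1024, 1024, device="cuda")

fully_shard(model)                   # parameters become DTensor-wrapped shards
optim = FusedAdam(model.parameters(), lr=1e-4)  # accepts DTensor(Float8Tensor)

with te.fp8_autocast():
    out = model(torch.randn(16, 1024, device="cuda"))
out.sum().backward()
optim.step()
```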
@@ -11,6 +11,7 @@ from typing import Optional
 import warnings
 import torch
+from torch.distributed._tensor import DTensor
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
 from .multi_tensor_apply import multi_tensor_applier
@@ -567,8 +568,10 @@ class FusedAdam(torch.optim.Optimizer):
                     unscaled_lists[name].append(unscaled)
                     scaled_lists[name].append(state[name])
                     state_scales[name].append(self._scales[p][name])
-            if isinstance(p, Float8Tensor):
+            if isinstance(p, Float8Tensor) or (
+                isinstance(p, DTensor) and isinstance(p._local_tensor, Float8Tensor)
+            ):
+                p = p._local_tensor if isinstance(p, DTensor) else p
                 out_dtype = p._fp8_dtype
                 p_fp8_model.append(p._data.data)
                 scale, amax, scale_inv = get_fp8_meta(p)
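The FusedAdam hunk above hinges on one pattern: under FSDP2, each parameter is a `DTensor` whose `_local_tensor` is the local `Float8Tensor` shard, and the FP8 fields (`_data`, `_fp8_dtype`) live on that shard. A minimal sketch of the unwrapping, with a tiny stand-in class instead of TE's real `Float8Tensor`:

```python
import torch
from dataclasses import dataclass

@dataclass
class FakeFP8Shard:
    """Illustrative stand-in for the Float8Tensor fields FusedAdam touches."""
    _data: torch.Tensor  # raw uint8 payload the fused kernel updates in place
    _fp8_dtype: str      # e.g. "E4M3"

def local_fp8_shard(p):
    """Peel a DTensor-style wrapper to reach the shard the optimizer updates."""
    local = getattr(p, "_local_tensor", None)  # DTensor exposes _local_tensor
    return local if local is not None else p

shard = FakeFP8Shard(torch.zeros(8, dtype=torch.uint8), "E4M3")
assert local_fp8_shard(shard) is shard  # a bare Float8Tensor passes through
```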
@@ -713,9 +713,8 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
                 [transpose, t_shape] + list(args[2:]),
                 kwargs,
             )
-            # deep copy the scale inverse tensor and quantizer as well.
             scale_inv = tensor._scale_inv.detach().clone()
-            quantizer = tensor._quantizer.copy()
+            quantizer = tensor._quantizer  # Deep-copied in constructor
             out_tensor = Float8Tensor(
                 data=func_out,
                 shape=func_out.shape,
@@ -820,7 +819,7 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
         # sure that the updated quantized weight tensor has the same scale inverse across all shards.
         self._quantizer.amax_reduction_group = mesh.get_group()
         self._quantizer.with_amax_reduction = True
-        quantizer = self._quantizer.copy()  # quantizer to be used for allgathered weights
         fsdp_state = _get_module_fsdp_state(module)
         reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
         # If weights are resharded after forward pass, it's enough to set the quantizer usages
@@ -833,9 +832,13 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
             is_backward_pass = training_state == TrainingState.PRE_BACKWARD
             # On Hopper/L40, only one of data/transpose is needed, depending on
             # whether this is the forward or backward pass, so set the usages accordingly.
-            quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass)
+            rowwise_usage = not is_backward_pass
+            columnwise_usage = is_backward_pass
+        else:
+            rowwise_usage = True
+            columnwise_usage = self._quantizer.columnwise_usage
         sharded_tensors = (self._data,)
-        metadata = (self._scale_inv, self._fp8_dtype, quantizer)
+        metadata = (self._scale_inv, rowwise_usage, columnwise_usage, self._fp8_dtype)
         return sharded_tensors, metadata

     def fsdp_post_all_gather(
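The new pre-all-gather logic reduces to a small decision table: when weights are resharded after forward, only one FP8 layout ever needs to travel (rowwise for forward GEMMs, columnwise/transpose for grad GEMMs); when they persist, forward must also carry the columnwise copy if the quantizer produces one. A self-contained sketch of that rule (the helper is illustrative, not TE API):

```python
def select_usages(is_backward_pass: bool,
                  reshard_after_forward: bool,
                  quantizer_columnwise: bool) -> tuple:
    """Decide which FP8 copies to all-gather: (rowwise, columnwise)."""
    if reshard_after_forward:
        # Weights are re-gathered each pass, so only one layout is ever needed:
        # rowwise for forward GEMMs, columnwise (transpose) for grad GEMMs.
        return (not is_backward_pass, is_backward_pass)
    # Weights gathered in forward are reused in backward, so forward must
    # also carry the columnwise copy if the quantizer ever produces one.
    return (True, quantizer_columnwise)

assert select_usages(False, True, True) == (True, False)   # forward, resharded
assert select_usages(True, True, True) == (False, True)    # backward, resharded
assert select_usages(False, False, True) == (True, True)   # forward, persistent
```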
@@ -861,7 +864,7 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
         """
         (data,) = all_gather_outputs
-        (fp8_scale_inv, fp8_dtype, quantizer) = metadata
+        (fp8_scale_inv, rowwise_usage, columnwise_usage, fp8_dtype) = metadata
         orig_shape = data.size()
         # The quantizer has only columnwise usage set for the backward pass.
         # On Blackwell+ architectures the transpose is not needed at all,
@@ -870,20 +873,27 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
         if out is not None:
             out._data = data
         else:
+            # We'll get here the first time post-all-gather is called.
+            # The Float8Tensor constructor makes a copy of the quantizer to
+            # save as its own quantizer; subsequent iterations reuse that same
+            # quantizer. The copy is needed in the first iteration because we
+            # need different quantizers for the sharded and allgathered tensors,
+            # and self._quantizer belongs to the sharded parameter.
             fp8_args = {
                 "shape": orig_shape,
                 "dtype": param_dtype,
                 "fp8_scale_inv": fp8_scale_inv,
                 "fp8_dtype": fp8_dtype,
-                "quantizer": quantizer,
+                "quantizer": self._quantizer,
                 "requires_grad": False,
                 "data": data,
             }
             out = Float8Tensor(**fp8_args)
+        out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)
         out.update_usage(
-            rowwise_usage=quantizer.rowwise_usage,
-            columnwise_usage=quantizer.columnwise_usage,
+            rowwise_usage=rowwise_usage,
+            columnwise_usage=columnwise_usage,
         )
         return out, all_gather_outputs
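A detail worth spelling out from the `out is not None` branch above: FSDP2 passes the tensor returned by the first post-all-gather back in on every later call, so the Float8Tensor wrapper (and the quantizer copied in its constructor) is built exactly once and only its payload is swapped afterwards. A toy sketch of that caching contract, with a stand-in holder class:

```python
import torch

class ToyFP8Holder:
    """Illustrative stand-in for the allgathered Float8Tensor wrapper."""
    def __init__(self, data: torch.Tensor):
        self._data = data

def post_all_gather(gathered: torch.Tensor, out: "ToyFP8Holder | None") -> "ToyFP8Holder":
    if out is not None:
        out._data = gathered       # steady state: swap the payload only
        return out
    return ToyFP8Holder(gathered)  # first call: build the wrapper once

first = post_all_gather(torch.empty(4, dtype=torch.uint8), None)
again = post_all_gather(torch.empty(4, dtype=torch.uint8), first)
assert again is first  # the wrapper (and its quantizer copy) is created once
```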
@@ -552,7 +552,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
             rowwise_scale_inv=rowwise_scale_inv,
             columnwise_data=columnwise_data,
             columnwise_scale_inv=columnwise_scale_inv,
-            quantizer=tensor._quantizer.copy(),
+            quantizer=tensor._quantizer,
             requires_grad=False,
             fp8_dtype=tensor._fp8_dtype,
         )
@@ -583,7 +583,6 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
         fsdp_state = _get_module_fsdp_state(module)
         reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
-        quantizer = self._quantizer.copy()
         # Remove padding from scale inverses before allgather.
         # Rowwise scale_inv should be divisible by [128, 4], columnwise by [4, 128].
         rowwise_scale_inv = self._rowwise_scale_inv
@@ -601,9 +600,8 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
         if columnwise_scale_inv.size(0) != flattened_in_shape0:
             columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0]
-        sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
-        # If weights are resharded after forward pass, then its enough to set the quantizer usages
-        # based on whether its forward or backward pass for the allgathered weights.
+        # If weights are resharded after forward pass, it's enough to send one row/col
+        # usage based on whether it's the forward or backward pass for the allgathered weights.
         # If not resharded after forward pass, the same weights allgathered in forward
         # are used again in backward. Hence, if the columnwise data/scale_inv are needed,
         # they must be sent along for the allgather in the forward pass itself.
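The padding removed here is plain round-up arithmetic: the rowwise scale_inv buffer is kept at dimensions that are multiples of [128, 4] (and [4, 128] for columnwise), so the unpadded extent has to be sliced out before the all-gather and re-padded afterwards. A small sketch of the rounding rule (helper name is illustrative):

```python
def _round_up(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple (padding rule for scale_inv dims)."""
    return (n + multiple - 1) // multiple * multiple

# A logical 100x3 rowwise scale_inv lives in a buffer padded to [128, 4] tiles,
# so it is sliced back to 100x3 before the all-gather and re-padded afterwards.
assert (_round_up(100, 128), _round_up(3, 4)) == (128, 4)
# Columnwise uses the transposed tile shape [4, 128].
assert (_round_up(3, 4), _round_up(100, 128)) == (4, 128)
```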
@@ -611,18 +609,24 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
             training_state = fsdp_state._fsdp_param_group._training_state
             is_backward_pass = training_state == TrainingState.PRE_BACKWARD
             # Allgather only the necessary tensors based on forward/backward pass.
-            quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass)
+            rowwise_usage = not is_backward_pass
+            columnwise_usage = is_backward_pass
             sharded_tensors = (
                 (self._columnwise_data, columnwise_scale_inv)
                 if is_backward_pass
-                else sharded_tensors
+                else (self._rowwise_data, rowwise_scale_inv)
             )
         else:
-            if quantizer.columnwise_usage:
+            # Rowwise usage is always needed for the forward pass.
+            rowwise_usage = True
+            sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
+            columnwise_usage = self._quantizer.columnwise_usage
+            if columnwise_usage:
+                # If weights are not resharded after forward, both rowwise and
+                # columnwise data/scale_inv need to be allgathered.
                 sharded_tensors += (self._columnwise_data, columnwise_scale_inv)
-        metadata = (self._fp8_dtype, quantizer)
+        metadata = (self._fp8_dtype, rowwise_usage, columnwise_usage)
         return sharded_tensors, metadata

     def fsdp_post_all_gather(
@@ -645,12 +649,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
             Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: Allgathered MXFP8Tensor and a tuple of the
                 internal tensors used by the MXFP8Tensor, computed after the allgather.
         """
-        fp8_dtype, quantizer = metadata
-        rowwise_data, rowwise_scale_inv = (
-            all_gather_outputs[:2] if quantizer.rowwise_usage else (None, None)
-        )
+        fp8_dtype, rowwise_usage, columnwise_usage = metadata
+        rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] if rowwise_usage else (None, None)
         columnwise_data, columnwise_scale_inv = (
-            all_gather_outputs[-2:] if quantizer.columnwise_usage else (None, None)
+            all_gather_outputs[-2:] if columnwise_usage else (None, None)
         )
         # Pad the scale_inv tensors to multiples of [128, 4] for rowwise and [4, 128] for columnwise.
@@ -675,8 +677,13 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
             out._rowwise_scale_inv = rowwise_scale_inv
             out._columnwise_data = columnwise_data
             out._columnwise_scale_inv = columnwise_scale_inv
-            out._quantizer = quantizer
         else:
+            # We'll get here the first time post-all-gather is called.
+            # The MXFP8Tensor constructor makes a copy of the quantizer to
+            # save as its own quantizer; subsequent iterations reuse that same
+            # quantizer. The copy is needed in the first iteration because we
+            # need different quantizers for the sharded and allgathered tensors,
+            # and self._quantizer belongs to the sharded parameter.
             out = MXFP8Tensor(
                 rowwise_data=rowwise_data,
                 rowwise_scale_inv=rowwise_scale_inv,
@@ -685,9 +692,9 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
                 fp8_dtype=fp8_dtype,
                 dtype=param_dtype,
                 shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
-                quantizer=quantizer,
+                quantizer=self._quantizer,
             )
+        out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)
         return out, all_gather_outputs

     @classmethod
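One consequence of the new metadata format shown in the hunks above: `all_gather_outputs` has a variable length, with the rowwise buffers (data, scale_inv) in the first two slots and the columnwise buffers in the last two, so four entries arrive only when both usages were requested. A hedged sketch of that unpacking (function name is illustrative):

```python
from typing import Optional, Tuple
import torch

def unpack_outputs(
    outputs: Tuple[torch.Tensor, ...], rowwise: bool, columnwise: bool
) -> Tuple[Optional[torch.Tensor], ...]:
    """Split gathered buffers into (row_data, row_scale, col_data, col_scale)."""
    row = outputs[:2] if rowwise else (None, None)      # rowwise: first two slots
    col = outputs[-2:] if columnwise else (None, None)  # columnwise: last two slots
    return (*row, *col)

bufs = tuple(torch.empty(2, dtype=torch.uint8) for _ in range(4))
r0, _, _, c1 = unpack_outputs(bufs, True, True)       # both layouts gathered
assert r0 is bufs[0] and c1 is bufs[3]
_, _, c0, c1 = unpack_outputs(bufs[:2], True, False)  # forward-only gather
assert c0 is None and c1 is None
```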