Unverified Commit d52ed471 authored by vthumbe1503's avatar vthumbe1503 Committed by GitHub
Browse files

FSDP2 Allgather Perf improvement and support for FusedAdam with FSDP2 (#2370)



* fix ci issue
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert back testing changes
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* remove quantizer copy + fused adam working
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix test
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix mxfp8 bug, god knows who created it
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/optimizers/fused_adam.py
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* Update comment
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
parent 89cc2a7e
...@@ -11,6 +11,7 @@ from typing import Optional ...@@ -11,6 +11,7 @@ from typing import Optional
import warnings import warnings
import torch import torch
from torch.distributed._tensor import DTensor
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
from .multi_tensor_apply import multi_tensor_applier from .multi_tensor_apply import multi_tensor_applier
...@@ -567,8 +568,10 @@ class FusedAdam(torch.optim.Optimizer): ...@@ -567,8 +568,10 @@ class FusedAdam(torch.optim.Optimizer):
unscaled_lists[name].append(unscaled) unscaled_lists[name].append(unscaled)
scaled_lists[name].append(state[name]) scaled_lists[name].append(state[name])
state_scales[name].append(self._scales[p][name]) state_scales[name].append(self._scales[p][name])
if isinstance(p, Float8Tensor) or (
if isinstance(p, Float8Tensor): isinstance(p, DTensor) and isinstance(p._local_tensor, Float8Tensor)
):
p = p._local_tensor if isinstance(p, DTensor) else p
out_dtype = p._fp8_dtype out_dtype = p._fp8_dtype
p_fp8_model.append(p._data.data) p_fp8_model.append(p._data.data)
scale, amax, scale_inv = get_fp8_meta(p) scale, amax, scale_inv = get_fp8_meta(p)
......
...@@ -713,9 +713,8 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): ...@@ -713,9 +713,8 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
[transpose, t_shape] + list(args[2:]), [transpose, t_shape] + list(args[2:]),
kwargs, kwargs,
) )
# deep copy the scale inverse tensor and quantizer as well.
scale_inv = tensor._scale_inv.detach().clone() scale_inv = tensor._scale_inv.detach().clone()
quantizer = tensor._quantizer.copy() quantizer = tensor._quantizer # Deep-copied in constructor
out_tensor = Float8Tensor( out_tensor = Float8Tensor(
data=func_out, data=func_out,
shape=func_out.shape, shape=func_out.shape,
...@@ -820,7 +819,7 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): ...@@ -820,7 +819,7 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
# sure that updated Quantized weight tensor have same scale inverse across all shards. # sure that updated Quantized weight tensor have same scale inverse across all shards.
self._quantizer.amax_reduction_group = mesh.get_group() self._quantizer.amax_reduction_group = mesh.get_group()
self._quantizer.with_amax_reduction = True self._quantizer.with_amax_reduction = True
quantizer = self._quantizer.copy() # quantizer to be used for allgathered weights
fsdp_state = _get_module_fsdp_state(module) fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
# If weights are resharded after forward pass, then its enough to set the quantizer usages # If weights are resharded after forward pass, then its enough to set the quantizer usages
...@@ -833,9 +832,13 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): ...@@ -833,9 +832,13 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
is_backward_pass = training_state == TrainingState.PRE_BACKWARD is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# In case of hopper/L40, only one of data/transpose is needed # In case of hopper/L40, only one of data/transpose is needed
# based on forward or backward pass. So setting the quantizer usages appropriately. # based on forward or backward pass. So setting the quantizer usages appropriately.
quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass) rowwise_usage = not is_backward_pass
columnwise_usage = is_backward_pass
else:
rowwise_usage = True
columnwise_usage = self._quantizer.columnwise_usage
sharded_tensors = (self._data,) sharded_tensors = (self._data,)
metadata = (self._scale_inv, self._fp8_dtype, quantizer) metadata = (self._scale_inv, rowwise_usage, columnwise_usage, self._fp8_dtype)
return sharded_tensors, metadata return sharded_tensors, metadata
def fsdp_post_all_gather( def fsdp_post_all_gather(
...@@ -861,7 +864,7 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): ...@@ -861,7 +864,7 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
""" """
(data,) = all_gather_outputs (data,) = all_gather_outputs
(fp8_scale_inv, fp8_dtype, quantizer) = metadata (fp8_scale_inv, rowwise_usage, columnwise_usage, fp8_dtype) = metadata
orig_shape = data.size() orig_shape = data.size()
# Quantizer has only columnwise usage set for backward pass # Quantizer has only columnwise usage set for backward pass
# In Blackwell+ architectures, transpose is not needed at all, # In Blackwell+ architectures, transpose is not needed at all,
...@@ -870,20 +873,27 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor): ...@@ -870,20 +873,27 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
if out is not None: if out is not None:
out._data = data out._data = data
else: else:
# We ll be here when post all gather is called the first time.
# Float8Tensor constructor makes a copy of the quantizer to
# save as its own quantizer. For the consequent iterations,
# the same quantizer is used. Copy is needed in the first iteration,
# since we need different quantizers for sharded and allgathered tensors.
# and self._quantizer belongs to the sharded parameter.
fp8_args = { fp8_args = {
"shape": orig_shape, "shape": orig_shape,
"dtype": param_dtype, "dtype": param_dtype,
"fp8_scale_inv": fp8_scale_inv, "fp8_scale_inv": fp8_scale_inv,
"fp8_dtype": fp8_dtype, "fp8_dtype": fp8_dtype,
"quantizer": quantizer, "quantizer": self._quantizer,
"requires_grad": False, "requires_grad": False,
"data": data, "data": data,
} }
out = Float8Tensor(**fp8_args) out = Float8Tensor(**fp8_args)
out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)
out.update_usage( out.update_usage(
rowwise_usage=quantizer.rowwise_usage, rowwise_usage=rowwise_usage,
columnwise_usage=quantizer.columnwise_usage, columnwise_usage=columnwise_usage,
) )
return out, all_gather_outputs return out, all_gather_outputs
......
...@@ -552,7 +552,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -552,7 +552,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
rowwise_scale_inv=rowwise_scale_inv, rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data, columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv, columnwise_scale_inv=columnwise_scale_inv,
quantizer=tensor._quantizer.copy(), quantizer=tensor._quantizer,
requires_grad=False, requires_grad=False,
fp8_dtype=tensor._fp8_dtype, fp8_dtype=tensor._fp8_dtype,
) )
...@@ -583,7 +583,6 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -583,7 +583,6 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
fsdp_state = _get_module_fsdp_state(module) fsdp_state = _get_module_fsdp_state(module)
reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward
quantizer = self._quantizer.copy()
# Remove padding from scale inverses before allgather # Remove padding from scale inverses before allgather
# Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128] # Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128]
rowwise_scale_inv = self._rowwise_scale_inv rowwise_scale_inv = self._rowwise_scale_inv
...@@ -601,9 +600,8 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -601,9 +600,8 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
if columnwise_scale_inv.size(0) != flattened_in_shape0: if columnwise_scale_inv.size(0) != flattened_in_shape0:
columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0] columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0]
sharded_tensors = (self._rowwise_data, rowwise_scale_inv) # If weights are resharded after forward pass, then its enough to send one row/col
# If weights are resharded after forward pass, then its enough to set the quantizer usages # usage based on whether its forward or backward pass for the allgathered weights.
# based on whether its forward or backward pass for the allgathered weights.
# If not resharded after forward pass, the same weights allgathered in forward # If not resharded after forward pass, the same weights allgathered in forward
# are used again in backward. And hence if we need the columnwise data/scale_inv, # are used again in backward. And hence if we need the columnwise data/scale_inv,
# we need to send them as well for allgather in forward pass itself. # we need to send them as well for allgather in forward pass itself.
...@@ -611,18 +609,24 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -611,18 +609,24 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
training_state = fsdp_state._fsdp_param_group._training_state training_state = fsdp_state._fsdp_param_group._training_state
is_backward_pass = training_state == TrainingState.PRE_BACKWARD is_backward_pass = training_state == TrainingState.PRE_BACKWARD
# Allgather only the necessary tensors based on forward/backward pass # Allgather only the necessary tensors based on forward/backward pass
quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass) rowwise_usage = not is_backward_pass
columnwise_usage = is_backward_pass
sharded_tensors = ( sharded_tensors = (
(self._columnwise_data, columnwise_scale_inv) (self._columnwise_data, columnwise_scale_inv)
if is_backward_pass if is_backward_pass
else sharded_tensors else (self._rowwise_data, rowwise_scale_inv)
) )
else: else:
if quantizer.columnwise_usage: # rowwise usage is always needed for forward pass.
rowwise_usage = True
sharded_tensors = (self._rowwise_data, rowwise_scale_inv)
columnwise_usage = self._quantizer.columnwise_usage
if columnwise_usage:
# If weights are not resharded after forward, then both # If weights are not resharded after forward, then both
# rowwise and columnwise data/scale_inv need to be allgathered. # rowwise and columnwise data/scale_inv need to be allgathered.
sharded_tensors += (self._columnwise_data, columnwise_scale_inv) sharded_tensors += (self._columnwise_data, columnwise_scale_inv)
metadata = (self._fp8_dtype, quantizer)
metadata = (self._fp8_dtype, rowwise_usage, columnwise_usage)
return sharded_tensors, metadata return sharded_tensors, metadata
def fsdp_post_all_gather( def fsdp_post_all_gather(
...@@ -645,12 +649,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -645,12 +649,10 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: Allgathered MXFP8Tensor and tuple of internal tensors Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: Allgathered MXFP8Tensor and tuple of internal tensors
used by the MXFP8Tensor that was being computed after allgather. used by the MXFP8Tensor that was being computed after allgather.
""" """
fp8_dtype, quantizer = metadata fp8_dtype, rowwise_usage, columnwise_usage = metadata
rowwise_data, rowwise_scale_inv = ( rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] if rowwise_usage else (None, None)
all_gather_outputs[:2] if quantizer.rowwise_usage else (None, None)
)
columnwise_data, columnwise_scale_inv = ( columnwise_data, columnwise_scale_inv = (
all_gather_outputs[-2:] if quantizer.columnwise_usage else (None, None) all_gather_outputs[-2:] if columnwise_usage else (None, None)
) )
# Add padding to scale_inv tensors to be multiples of [128, 4]for rowwise and [4, 128] for columnwise # Add padding to scale_inv tensors to be multiples of [128, 4]for rowwise and [4, 128] for columnwise
...@@ -675,8 +677,13 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -675,8 +677,13 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
out._rowwise_scale_inv = rowwise_scale_inv out._rowwise_scale_inv = rowwise_scale_inv
out._columnwise_data = columnwise_data out._columnwise_data = columnwise_data
out._columnwise_scale_inv = columnwise_scale_inv out._columnwise_scale_inv = columnwise_scale_inv
out._quantizer = quantizer
else: else:
# We ll be here when post all gather is called the first time.
# MXFP8Tensor constructor makes a copy of the quantizer to
# save as its own quantizer. For the consequent iterations,
# the same quantizer is used. Copy is needed in the first iteration,
# since we need different quantizers for sharded and allgathered tensors.
# and self._quantizer belongs to the sharded parameter.
out = MXFP8Tensor( out = MXFP8Tensor(
rowwise_data=rowwise_data, rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv, rowwise_scale_inv=rowwise_scale_inv,
...@@ -685,9 +692,9 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): ...@@ -685,9 +692,9 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
fp8_dtype=fp8_dtype, fp8_dtype=fp8_dtype,
dtype=param_dtype, dtype=param_dtype,
shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape,
quantizer=quantizer, quantizer=self._quantizer,
) )
out._quantizer.set_usage(rowwise=rowwise_usage, columnwise=columnwise_usage)
return out, all_gather_outputs return out, all_gather_outputs
@classmethod @classmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment