Unverified Commit b14a3b62 authored by Kunlun Li, committed by GitHub

Make FP8 weights compatible with older MCore version (#2342)



* Make cast_master_weights_to_fp8 compatible with older MCore version
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* Rename keep_columnwise to manual_post_all_gather_processing & Optimize unit test
Signed-off-by: kunlunl <kunlunl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove redundant _test_mini_optimizer()
Signed-off-by: kunlunl <kunlunl@nvidia.com>

---------
Signed-off-by: kunlunl <kunlunl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent f3b97c26
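
Before the diff, a minimal usage sketch of the function being changed. The import path, the helper name refresh_fp8_weights, and the surrounding setup are illustrative assumptions, not taken from this commit:

    # Assumed import path for illustration; the commit shows only the module contents.
    from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_fp8

    def refresh_fp8_weights(model_weights, master_weights, start_offsets, dp_group):
        # Default path: leaving manual_post_all_gather_processing at False keeps the
        # previous behavior (compatible with older Megatron-Core versions), where the
        # post-processing is triggered automatically during the next forward pass.
        cast_master_weights_to_fp8(
            model_weights,
            master_weights,
            start_offsets,
            dp_group,
            manual_post_all_gather_processing=False,
        )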
@@ -48,7 +48,12 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):

def cast_master_weights_to_fp8(
    model_weights,
    master_weights,
    start_offsets,
    group,
    fsdp_shard_model_weights=None,
    manual_post_all_gather_processing=False,
):
    r"""Helper function to cast master weights to FP8 primary weights.

@@ -69,6 +74,11 @@ def cast_master_weights_to_fp8(
    fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are
        not sharded. Otherwise, it means that the model weights are sharded and we get
        target model weights data storage using the FSDP shard model weights.
    manual_post_all_gather_processing : bool, default = `False`.
        If False, post-processing is triggered automatically during the next forward pass.
        If True, the timing of calling `post_all_gather_processing` is left to the user.
        Note that users must call `post_all_gather_processing` if this is set to True,
        otherwise the weights won't be updated correctly.
    """
@@ -129,21 +139,18 @@ def cast_master_weights_to_fp8(
            f"cast_master_weights_to_fp8 for {type(quantizer)} is not supported yet"
        )
    extra_args = [group, use_fsdp_shard_model_weights, manual_post_all_gather_processing]
    if len(delayed_scaling_params) > 0:
        _cast_master_weights_to_fp8_delayed_scaling(delayed_scaling_params, *extra_args)
    if len(current_scaling_params) > 0:
        _cast_master_weights_to_fp8_current_scaling(current_scaling_params, *extra_args)
    if len(blockwise_scaling_params) > 0:
        _cast_master_weights_to_fp8_blockwise_scaling(blockwise_scaling_params, *extra_args)
def _cast_master_weights_to_fp8_delayed_scaling(
    params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
):
    r"""Helper function to cast master weights to FP8 primary weights for delayed scaling.

    Parameters

@@ -160,6 +167,13 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False):
    amaxes, scales, scale_invs = [], [], []
    for model_weight, master_weight, start_offset, shard_model_weight_raw in params:
        if not manual_post_all_gather_processing:
            # Reset the transpose cache for all model weights.
            # We cannot create the transpose cache here because users (like Megatron) may want to
            # overlap the all-gather of model weights with the forward pass, so the model weight
            # has not been updated yet at this point.
            model_weight._reset_caches()

        quantizer = model_weight._get_quantizer()
        amaxes.append(quantizer.amax.view(1))

@@ -219,7 +233,9 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False):
    )
def _cast_master_weights_to_fp8_current_scaling(
    params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
):
    r"""Helper function to cast master weights to FP8 primary weights for current scaling.

    Parameters

@@ -297,6 +313,13 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False):
    for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
        params, scales
    ):
        if not manual_post_all_gather_processing:
            # Reset the transpose cache for all model weights.
            # We cannot create the transpose cache here because users (like Megatron) may want to
            # overlap the all-gather of model weights with the forward pass, so the model weight
            # has not been updated yet at this point.
            model_weight._reset_caches()

        # If master weight is None, it means that the master weight of the current model weight
        # is in other DP ranks.
        if master_weight is None:
@@ -322,7 +345,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False):

def _cast_master_weights_to_fp8_blockwise_scaling(
    params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
):
    r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling.

@@ -421,6 +444,13 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
    for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
        params, scales
    ):
        if not manual_post_all_gather_processing:
            # Clear the columnwise data for all model weights.
            # We cannot create the columnwise data here because users (like Megatron) may want to
            # overlap the all-gather of model weights with the forward pass, so the model weight
            # has not been updated yet at this point.
            model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)

        # If master weight is None, it means that the master weight of the current model weight
        # is in other DP ranks.
        if master_weight is None:
...
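
The comments added in each helper explain the design choice: when manual_post_all_gather_processing is False, the transpose cache (or columnwise data) is cleared during the cast, because the gathered model weights are not yet updated and the post-processing runs automatically in the next forward pass; when it is True, the caller takes over that step, which lets it overlap the weight all-gather with other work. A rough sketch of the manual path follows; the two callables and the wrapper name are hypothetical placeholders, and only cast_master_weights_to_fp8 and the post_all_gather_processing name from the docstring come from this commit:

    # Assumed import path for illustration only.
    from transformer_engine.pytorch.tensor.utils import cast_master_weights_to_fp8

    def refresh_fp8_weights_manually(
        model_weights,
        master_weights,
        start_offsets,
        dp_group,
        all_gather_model_weights,    # placeholder: caller's async all-gather of model weights
        post_all_gather_processing,  # placeholder: caller's post-all-gather processing hook
    ):
        # Cast the FP32 master shards into the FP8 model weights, leaving the transpose
        # cache / columnwise data handling to the caller.
        cast_master_weights_to_fp8(
            model_weights,
            master_weights,
            start_offsets,
            dp_group,
            manual_post_all_gather_processing=True,
        )
        # The all-gather of the updated weights can be overlapped with other work, but
        # the post-processing must run before the next forward pass, otherwise the
        # weights are not updated correctly (see the docstring in the diff above).
        handle = all_gather_model_weights()
        handle.wait()
        post_all_gather_processing()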