Unverified commit fdb87afc authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Reset recipe state in fusible operations when FP8 amax history length changes (#1985)



* Fix bug where TE ops were not updating fp8_meta dicts
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename reset_recipe_state function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update error message when initializing meta device quantized weight without recipe
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d1967d55
......@@ -712,6 +712,12 @@ def _make_graphed_callables(
elif isinstance(m, BasicOperation):
for mode in ("forward", "backward"):
if m.num_quantizers(mode):
m._fp8_metas[mode][
"fp8_group"
] = FP8GlobalStateManager.get_fp8_group()
m._fp8_metas[mode][
"recipe"
] = FP8GlobalStateManager.get_fp8_recipe()
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
m._fp8_metas[mode],
)
......@@ -756,7 +762,7 @@ def save_fp8_tensors(
m.adjust_amax_history_length(fp8_recipe.amax_history_len)
module_tensors = m.get_fp8_meta_tensors()
elif isinstance(m, BasicOperation):
m.reset_recipe_type(recipe=fp8_recipe)
m.reset_recipe_state(recipe=fp8_recipe)
module_tensors = m._save_fp8_metas()
fp8_tensors.append(module_tensors)
return fp8_tensors
......
......@@ -294,9 +294,9 @@ class BasicLinear(BasicOperation):
raise RuntimeError(
"Tried to quantize weight with deferred initialization "
"due to meta device, but no quantizer was available. "
"This is most likely because fp8_model_init was called "
"with enabled=True and recipe=None, instead of providing "
"a recipe to use for quantization."
"This is most likely because the weight was initialized "
"within fp8_model_init, but the forward pass was not "
"performed within fp8_autocast."
)
quantizer.set_usage(
rowwise=True,
......@@ -315,8 +315,8 @@ class BasicLinear(BasicOperation):
if self.weight.device.type == "meta":
self.reset_parameters()
def reset_recipe_type(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_type(recipe=recipe)
def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_state(recipe=recipe)
if recipe is not None and not FP8GlobalStateManager.with_fp8_parameters():
# Make quantizers use internal tensors
......
......@@ -398,27 +398,30 @@ class OperationFuser:
break
# Early exit if fusion parameters haven't changed
need_reset = False
recipe_type = type(recipe)
fusion_params = (recipe_type, first_op_requiring_backward)
if fusion_params == (self.recipe_type, self.first_op_requiring_backward):
if fusion_params != (self.recipe_type, self.first_op_requiring_backward):
# Recipe type or grad requirements have changed
need_reset = True
elif (
recipe is not None
and recipe.delayed()
and self._last_amax_history_len != recipe.amax_history_len
):
# FP8 delayed scaling has changed amax history length
need_reset = True
if not need_reset:
return
# Initialize ops if recipe type has changed
if self.recipe_type != recipe_type:
# Check if this is the first iteration
if self.recipe_type is None:
for op in self._basic_ops:
op.pre_first_fuser_forward()
# Inform ops that the recipe type has changed
# Reset recipe state
for op in self._basic_ops:
op.reset_recipe_state(recipe=recipe)
# Check if this is the first iteration
if self.recipe_type is None:
for op in self._basic_ops:
op.reset_recipe_type(recipe=recipe)
# Check if amax history was invalidated
elif isinstance(recipe, DelayedScaling):
if recipe.amax_history_len != self._last_amax_history_len:
raise RuntimeError(
"Detected change of amax history length. "
"Changing the length of amax history is currently not supported."
)
op.pre_first_fuser_forward()
# Prepare basic op lists for fusions
forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
......
......@@ -183,7 +183,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self._quantizers: Optional[dict[str, list[Quantizer]]] = None
with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
self.reset_recipe_type(recipe=recipe)
self.reset_recipe_state(recipe=recipe)
@property
def is_fused_op(self) -> bool:
......@@ -215,7 +215,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
return self.get_quantizer("backward", 0)
return None
def reset_recipe_type(
def reset_recipe_state(
self,
*,
recipe: Optional[Recipe],
......@@ -228,6 +228,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self._quantizers = None
return
# Communication group for FP8 amax reductions
fp8_group = FP8GlobalStateManager.get_fp8_group()
# Skip resetting recipe type if it did not actually change.
# This could happen for example if calling BasicOperation.forward directly, as in that
# case, the OperationFuser is not persistent, or when loading from a checkpoint
......@@ -247,7 +250,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
break
if need_to_reset_recipe_state:
# Quantization recipe state for forward and backward pass
# Construct quantization recipe states
self._fp8_metas = {"forward": None, "backward": None}
self._quantizers = {"forward": [], "backward": []}
for mode in ("forward", "backward"):
......@@ -272,11 +275,38 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
self._fp8_metas[mode] = {
fp8_meta_key: recipe_state,
"recipe": recipe,
"fp8_group": FP8GlobalStateManager.get_fp8_group(),
"fp8_group": fp8_group,
}
# Construct builder class for quantized tensors
self._quantizers[mode] = recipe_state.make_quantizers()
else:
# Update quantization recipe states
for mode in ("forward", "backward"):
if self._fp8_metas[mode] is None:
continue
self._fp8_metas[mode]["recipe"] = recipe
self._fp8_metas[mode]["fp8_group"] = fp8_group
# Update amax history for FP8 delayed scaling
if recipe.delayed():
fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
forward=(mode == "forward"),
)
recipe_state = self._fp8_metas[mode][fp8_meta_key]
current_length = recipe_state.amax_history.size(0)
target_length = recipe.amax_history_len
if target_length < current_length:
with torch.no_grad():
recipe_state.amax_history = recipe_state.amax_history[
:target_length
].clone()
elif target_length > current_length:
with torch.no_grad():
recipe_state.amax_history = torch.nn.functional.pad(
recipe_state.amax_history,
pad=(0, 0, 0, target_length - current_length),
)
# Add meta tensors to global buffer to participate in reduction
for mode in ("forward", "backward"):
......@@ -574,7 +604,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
# Get op's quantizer state, initializing if needed
if self._fp8_metas is None or self._fp8_metas[mode] is None:
with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
self.reset_recipe_type(recipe=state[mode]["recipe"])
self.reset_recipe_state(recipe=state[mode]["recipe"])
fp8_meta = self._fp8_metas[mode]
# Load extra items
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.