Unverified commit fdb87afc authored by Tim Moon, committed by GitHub

[PyTorch] Reset recipe state in fusible operations when FP8 amax history length changes (#1985)



* Fix bug where TE ops were not updating fp8_meta dicts
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Rename reset_recipe_state function
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update error message when initializing meta device quantized weight without recipe
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent d1967d55
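The diff below lets fusible operations rebuild their quantization recipe state when the FP8 amax history length changes, instead of raising an error. A minimal sketch of the user-facing scenario (not part of the diff; the op choice and sizes are illustrative, assuming the transformer_engine.pytorch ops API):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

op = te.ops.BasicLinear(1024, 1024)
x = torch.randn(32, 1024, device="cuda")

with te.fp8_autocast(fp8_recipe=DelayedScaling(amax_history_len=16)):
    y = op(x)

# Previously this raised "Detected change of amax history length";
# now the ops reset their recipe state and resize the history
with te.fp8_autocast(fp8_recipe=DelayedScaling(amax_history_len=32)):
    y = op(x)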
@@ -712,6 +712,12 @@ def _make_graphed_callables(
             elif isinstance(m, BasicOperation):
                 for mode in ("forward", "backward"):
                     if m.num_quantizers(mode):
+                        m._fp8_metas[mode][
+                            "fp8_group"
+                        ] = FP8GlobalStateManager.get_fp8_group()
+                        m._fp8_metas[mode][
+                            "recipe"
+                        ] = FP8GlobalStateManager.get_fp8_recipe()
                         FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(
                             m._fp8_metas[mode],
                         )
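For context, this hunk runs while CUDA graphs are being captured: each op's fp8_meta dict is refreshed from the active autocast state before its tensors are registered for amax reduction. A hedged usage sketch of the capture entry point (the make_graphed_callables keyword names are assumptions, not taken from this diff):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Sizes and op choice are illustrative
model = te.ops.Sequential(te.ops.BasicLinear(1024, 1024))
sample_input = torch.randn(32, 1024, device="cuda")

# Capture is where the hunk above refreshes "recipe" and "fp8_group"
graphed_model = te.make_graphed_callables(
    model,
    (sample_input,),
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(amax_history_len=16),
)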
@@ -756,7 +762,7 @@ def save_fp8_tensors(
                 m.adjust_amax_history_length(fp8_recipe.amax_history_len)
                 module_tensors = m.get_fp8_meta_tensors()
             elif isinstance(m, BasicOperation):
-                m.reset_recipe_type(recipe=fp8_recipe)
+                m.reset_recipe_state(recipe=fp8_recipe)
                 module_tensors = m._save_fp8_metas()
             fp8_tensors.append(module_tensors)
     return fp8_tensors
...
@@ -294,9 +294,9 @@ class BasicLinear(BasicOperation):
             raise RuntimeError(
                 "Tried to quantize weight with deferred initialization "
                 "due to meta device, but no quantizer was available. "
-                "This is most likely because fp8_model_init was called "
-                "with enabled=True and recipe=None, instead of providing "
-                "a recipe to use for quantization."
+                "This is most likely because the weight was initialized "
+                "within fp8_model_init, but the forward pass was not "
+                "performed within fp8_autocast."
             )
         quantizer.set_usage(
             rowwise=True,
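The rewritten error message reflects the supported flow: weights created under fp8_model_init (possibly deferred on a meta device) only receive a quantizer once a forward pass runs inside fp8_autocast. A minimal sketch, assuming the standard transformer_engine.pytorch entry points:

import torch
import transformer_engine.pytorch as te

# Weight is created in quantized form (no recipe known yet)
with te.fp8_model_init(enabled=True):
    linear = te.ops.BasicLinear(1024, 1024)

# The quantizer comes from the autocast recipe, so the forward
# pass must run inside fp8_autocast
x = torch.randn(32, 1024, device="cuda")
with te.fp8_autocast():
    y = linear(x)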
@@ -315,8 +315,8 @@ class BasicLinear(BasicOperation):
         if self.weight.device.type == "meta":
             self.reset_parameters()

-    def reset_recipe_type(self, *, recipe: Optional[Recipe]) -> None:
-        super().reset_recipe_type(recipe=recipe)
+    def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
+        super().reset_recipe_state(recipe=recipe)
         if recipe is not None and not FP8GlobalStateManager.with_fp8_parameters():
             # Make quantizers use internal tensors
...
@@ -398,27 +398,30 @@ class OperationFuser:
                 break

         # Early exit if fusion parameters haven't changed
+        need_reset = False
         recipe_type = type(recipe)
         fusion_params = (recipe_type, first_op_requiring_backward)
-        if fusion_params == (self.recipe_type, self.first_op_requiring_backward):
+        if fusion_params != (self.recipe_type, self.first_op_requiring_backward):
+            # Recipe type or grad requirements have changed
+            need_reset = True
+        elif (
+            recipe is not None
+            and recipe.delayed()
+            and self._last_amax_history_len != recipe.amax_history_len
+        ):
+            # FP8 delayed scaling has changed amax history length
+            need_reset = True
+        if not need_reset:
             return

-        # Initialize ops if recipe type has changed
-        if self.recipe_type != recipe_type:
-            # Check if this is the first iteration
-            if self.recipe_type is None:
-                for op in self._basic_ops:
-                    op.pre_first_fuser_forward()
-            # Inform ops that the recipe type has changed
-            for op in self._basic_ops:
-                op.reset_recipe_type(recipe=recipe)
-        # Check if amax history was invalidated
-        elif isinstance(recipe, DelayedScaling):
-            if recipe.amax_history_len != self._last_amax_history_len:
-                raise RuntimeError(
-                    "Detected change of amax history length. "
-                    "Changing the length of amax history is currently not supported."
-                )
+        # Reset recipe state
+        for op in self._basic_ops:
+            op.reset_recipe_state(recipe=recipe)
+
+        # Check if this is the first iteration
+        if self.recipe_type is None:
+            for op in self._basic_ops:
+                op.pre_first_fuser_forward()

         # Prepare basic op lists for fusions
         forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)]
...
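Restated as a standalone predicate for clarity (illustrative only; the prev_* arguments stand in for the fuser's cached state, and the real check also compares the first op requiring backward):

from typing import Optional

from transformer_engine.common.recipe import Recipe

def needs_recipe_reset(
    recipe: Optional[Recipe],
    prev_recipe_type: Optional[type],
    prev_amax_history_len: Optional[int],
) -> bool:
    # Recipe type changed (including None <-> recipe transitions)
    if type(recipe) is not prev_recipe_type:
        return True
    # FP8 delayed scaling changed the amax history length
    if recipe is not None and recipe.delayed():
        return recipe.amax_history_len != prev_amax_history_len
    return False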
@@ -183,7 +183,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         self._quantizers: Optional[dict[str, list[Quantizer]]] = None
         with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
         recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
-        self.reset_recipe_type(recipe=recipe)
+        self.reset_recipe_state(recipe=recipe)

     @property
     def is_fused_op(self) -> bool:
@@ -215,7 +215,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             return self.get_quantizer("backward", 0)
         return None

-    def reset_recipe_type(
+    def reset_recipe_state(
         self,
         *,
         recipe: Optional[Recipe],
@@ -228,6 +228,9 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             self._quantizers = None
             return

+        # Communication group for FP8 amax reductions
+        fp8_group = FP8GlobalStateManager.get_fp8_group()
+
         # Skip resetting recipe type if it did not actually change.
         # This could happen for example if calling BasicOperation.forward directly, as in that
         # case, the OperationFuser is not persistent, or when loading from a checkpoint
@@ -247,7 +250,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
                 break

         if need_to_reset_recipe_state:
-            # Quantization recipe state for forward and backward pass
+            # Construct quantization recipe states
             self._fp8_metas = {"forward": None, "backward": None}
             self._quantizers = {"forward": [], "backward": []}
             for mode in ("forward", "backward"):
@@ -272,11 +275,38 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
                 self._fp8_metas[mode] = {
                     fp8_meta_key: recipe_state,
                     "recipe": recipe,
-                    "fp8_group": FP8GlobalStateManager.get_fp8_group(),
+                    "fp8_group": fp8_group,
                 }

                 # Construct builder class for quantized tensors
                 self._quantizers[mode] = recipe_state.make_quantizers()
+        else:
+            # Update quantization recipe states
+            for mode in ("forward", "backward"):
+                if self._fp8_metas[mode] is None:
+                    continue
+                self._fp8_metas[mode]["recipe"] = recipe
+                self._fp8_metas[mode]["fp8_group"] = fp8_group
+
+                # Update amax history for FP8 delayed scaling
+                if recipe.delayed():
+                    fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key(
+                        forward=(mode == "forward"),
+                    )
+                    recipe_state = self._fp8_metas[mode][fp8_meta_key]
+                    current_length = recipe_state.amax_history.size(0)
+                    target_length = recipe.amax_history_len
+                    if target_length < current_length:
+                        with torch.no_grad():
+                            recipe_state.amax_history = recipe_state.amax_history[
+                                :target_length
+                            ].clone()
+                    elif target_length > current_length:
+                        with torch.no_grad():
+                            recipe_state.amax_history = torch.nn.functional.pad(
+                                recipe_state.amax_history,
+                                pad=(0, 0, 0, target_length - current_length),
+                            )

         # Add meta tensors to global buffer to participate in reduction
         for mode in ("forward", "backward"):
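The resize logic above acts on an amax history of shape (history_len, num_quantized_tensors). A small worked example of both branches (sizes are illustrative):

import torch

amax_history = torch.rand(16, 3)  # history_len=16, 3 quantized tensors

# Shrinking to length 8 keeps the first 8 rows
shrunk = amax_history[:8].clone()
assert shrunk.shape == (8, 3)

# Growing to length 32 appends zero rows: pad=(0, 0, 0, 16) leaves
# the last dim untouched and pads the first dim by 16 at the end
grown = torch.nn.functional.pad(amax_history, pad=(0, 0, 0, 16))
assert grown.shape == (32, 3)
assert grown[16:].eq(0).all()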
@@ -574,7 +604,7 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
             # Get op's quantizer state, initializing if needed
             if self._fp8_metas is None or self._fp8_metas[mode] is None:
                 with fp8_autocast(fp8_recipe=state[mode]["recipe"]):
-                    self.reset_recipe_type(recipe=state[mode]["recipe"])
+                    self.reset_recipe_state(recipe=state[mode]["recipe"])
             fp8_meta = self._fp8_metas[mode]

         # Load extra items
...
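This load path re-creates missing recipe state under an autocast that matches the checkpointed recipe, so a freshly constructed op can load quantizer state directly. A hedged round-trip sketch (assuming the op's recipe state travels through the usual state_dict mechanism):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

recipe = DelayedScaling(amax_history_len=16)
op = te.ops.BasicLinear(1024, 1024)
x = torch.randn(32, 1024, device="cuda")
with te.fp8_autocast(fp8_recipe=recipe):
    op(x)  # populates recipe state

state = op.state_dict()

# A fresh op has no recipe state; loading initializes it on demand
fresh = te.ops.BasicLinear(1024, 1024)
fresh.load_state_dict(state)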