"csrc/vscode:/vscode.git/clone" did not exist on "beb89f68b448a43ac112b48e3834f80a2df626cb"
Unverified commit bfca2e33, authored by Tim Moon and committed by GitHub
Browse files

[PyTorch] Update amax pointers when reallocating amax history in fusible ops (#2044)



* Update amax pointers when reallocating amax history in fusible ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update weight tensor quantizer when recipe state is reset
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent de6afe24
...@@ -301,6 +301,7 @@ class BasicLinear(BasicOperation): ...@@ -301,6 +301,7 @@ class BasicLinear(BasicOperation):
rowwise=True, rowwise=True,
columnwise=torch.is_grad_enabled(), columnwise=torch.is_grad_enabled(),
) )
quantizer.internal = False
with torch.no_grad(): with torch.no_grad():
weight = quantizer(weight) weight = quantizer(weight)
...@@ -317,11 +318,32 @@ class BasicLinear(BasicOperation): ...@@ -317,11 +318,32 @@ class BasicLinear(BasicOperation):
def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_state(recipe=recipe) super().reset_recipe_state(recipe=recipe)
if recipe is not None and not FP8GlobalStateManager.with_fp8_parameters(): # Input/grad output quantizers use internal tensors
# Make quantizers use internal tensors input_quantizer = self.get_quantizer("forward", 0)
self.get_input_quantizer().internal = True grad_output_quantizer = self.get_quantizer("backward", 0)
self.get_grad_output_quantizer().internal = True if input_quantizer is not None:
self.get_quantizer("forward", 1).internal = True input_quantizer.internal = True
if grad_output_quantizer is not None:
grad_output_quantizer.internal = True
# Handle weight quantizer
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
weight_quantizer = self.get_quantizer("forward", 1)
if weight_quantizer is None:
pass
elif is_quantized_tensor(getattr(self, "weight", None)):
# Make sure weight param has correct quantizer
weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
weight_quantizer.internal = False
self.weight.update_quantizer(weight_quantizer.copy())
else:
# Use internal tensors if quantized weights will not be
# exposed externally
weight_quantizer.internal = (
not FP8GlobalStateManager.with_fp8_parameters()
and not getattr(self, "_with_quantized_weight", False)
)
@staticmethod @staticmethod
def _functional_forward( def _functional_forward(
......
...@@ -110,14 +110,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -110,14 +110,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
basic_op_extra_inputs.append(xs) basic_op_extra_inputs.append(xs)
# Get environment state
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
is_grad_enabled = func_ctx is not None
# Attempt to fuse operations if necessary
fuser.maybe_fuse_ops(is_grad_enabled, recipe, input_, basic_op_extra_inputs)
# Apply forward ops # Apply forward ops
x = input_ x = input_
extra_outputs = [None] * fuser._num_basic_ops extra_outputs = [None] * fuser._num_basic_ops
...@@ -167,7 +159,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -167,7 +159,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
extra_outputs_flat.extend(ys) extra_outputs_flat.extend(ys)
# Save context for backward pass # Save context for backward pass
if is_grad_enabled: if func_ctx is not None:
# Flatten list of saved tensors # Flatten list of saved tensors
to_save = [] to_save = []
...@@ -180,12 +172,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -180,12 +172,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
ctx._saved_tensors_range = (range_start, range_end) ctx._saved_tensors_range = (range_start, range_end)
# Save tensors for backward # Save tensors for backward
if with_quantized_compute:
tensors_to_save, tensor_objects = prepare_for_saving(*to_save) tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
func_ctx.save_for_backward(*tensors_to_save) func_ctx.save_for_backward(*tensors_to_save)
func_ctx.tensor_objects = tensor_objects func_ctx.tensor_objects = tensor_objects
else:
func_ctx.save_for_backward(*to_save)
# Other context # Other context
func_ctx.backward_ops = fuser._backward_ops func_ctx.backward_ops = fuser._backward_ops
...@@ -195,7 +184,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -195,7 +184,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
func_ctx.num_extra_inputs = fuser.num_extra_inputs func_ctx.num_extra_inputs = fuser.num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
func_ctx.with_quantized_compute = with_quantized_compute
# Mark output tensors as not deletable in backward # Mark output tensors as not deletable in backward
for tensor in [x] + extra_outputs_flat: for tensor in [x] + extra_outputs_flat:
...@@ -223,10 +211,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function): ...@@ -223,10 +211,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_op_ctxs = func_ctx.basic_op_ctxs basic_op_ctxs = func_ctx.basic_op_ctxs
# Restore saved tensors # Restore saved tensors
if func_ctx.with_quantized_compute:
saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors) saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
else:
saved_tensors = func_ctx.saved_tensors
# Unflatten list of saved tensors # Unflatten list of saved tensors
for ctx in basic_op_ctxs: for ctx in basic_op_ctxs:
...@@ -460,8 +445,24 @@ class OperationFuser: ...@@ -460,8 +445,24 @@ class OperationFuser:
if basic_op_kwargs is None: if basic_op_kwargs is None:
basic_op_kwargs = [{}] * self._num_basic_ops basic_op_kwargs = [{}] * self._num_basic_ops
# Unflatten list of extra tensor inputs
extra_inputs_copy = list(extra_inputs)
basic_op_extra_inputs = []
for op in self._basic_ops:
xs, extra_inputs_copy = _split_tuple(extra_inputs_copy, op.num_extra_inputs)
basic_op_extra_inputs.append(xs)
# Get environment state
recipe = None
if FP8GlobalStateManager.is_fp8_enabled():
recipe = FP8GlobalStateManager.get_fp8_recipe()
is_grad_enabled = torch.is_grad_enabled()
# Attempt to fuse operations if necessary
self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)
# Fuser forward pass # Fuser forward pass
if torch.is_grad_enabled(): if is_grad_enabled:
forward_func = _OperationFuserAutogradFunction.apply forward_func = _OperationFuserAutogradFunction.apply
args = [] args = []
else: else:
......
...@@ -294,6 +294,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -294,6 +294,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
forward=(mode == "forward"), forward=(mode == "forward"),
) )
recipe_state = self._fp8_metas[mode][fp8_meta_key] recipe_state = self._fp8_metas[mode][fp8_meta_key]
# Reallocate amax history if needed
current_length = recipe_state.amax_history.size(0) current_length = recipe_state.amax_history.size(0)
target_length = recipe.amax_history_len target_length = recipe.amax_history_len
if target_length < current_length: if target_length < current_length:
...@@ -308,6 +310,25 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta): ...@@ -308,6 +310,25 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
pad=(0, 0, 0, target_length - current_length), pad=(0, 0, 0, target_length - current_length),
) )
# Update quantizers with new amax pointers
self._quantizers[mode] = recipe_state.make_quantizers()
# Update the global buffers with new amax pointers
if FP8GlobalStateManager.get_buffer_info() in self._fp8_metas[mode]:
pos, buffer_key = self._fp8_metas[mode][
FP8GlobalStateManager.get_buffer_info()
]
if buffer_key in FP8GlobalStateManager.global_amax_buffer:
assert (
buffer_key in FP8GlobalStateManager.global_amax_history_buffer
), "TE internal error during amax history change."
FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
recipe_state.amax_history[0]
)
FP8GlobalStateManager.global_amax_history_buffer[buffer_key][
pos
] = recipe_state.amax_history
# Add meta tensors to global buffer to participate in reduction # Add meta tensors to global buffer to participate in reduction
for mode in ("forward", "backward"): for mode in ("forward", "backward"):
if ( if (
...@@ -686,7 +707,6 @@ class FusedOperation(FusibleOperation): ...@@ -686,7 +707,6 @@ class FusedOperation(FusibleOperation):
return self.basic_ops[-1].get_grad_output_quantizer() return self.basic_ops[-1].get_grad_output_quantizer()
def pre_first_fuser_forward(self) -> None: def pre_first_fuser_forward(self) -> None:
"""Preprocessing before first fuser forward pass"""
for op in self.basic_ops: for op in self.basic_ops:
op.pre_first_fuser_forward() op.pre_first_fuser_forward()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment