Unverified commit bfca2e33, authored by Tim Moon and committed by GitHub

[PyTorch] Update amax pointers when reallocating amax history in fusible ops (#2044)



* Update amax pointers when reallocating amax history in fusible ops
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update weight tensor quantizer when recipe state is reset
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent de6afe24
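
The core problem this commit fixes: delayed-scaling quantizers hold references to rows of the recipe state's amax_history tensor, so reallocating that tensor (for example when recipe.amax_history_len changes) leaves them pointing at stale storage unless the pointers are refreshed. Below is a minimal, self-contained PyTorch sketch of the failure mode and the fix; ToyQuantizer and ToyRecipeState are hypothetical stand-ins, not Transformer Engine's actual classes (only the make_quantizers name is borrowed from the diff below).

```python
import torch
import torch.nn.functional as F


class ToyQuantizer:
    """Hypothetical stand-in for a delayed-scaling quantizer that records amaxes in place."""

    def __init__(self, amax_row: torch.Tensor):
        self.amax_row = amax_row  # view into the recipe state's amax_history[0]

    def record_amax(self, tensor: torch.Tensor) -> None:
        self.amax_row.copy_(tensor.abs().amax())


class ToyRecipeState:
    """Hypothetical stand-in for a recipe state that owns the amax history buffer."""

    def __init__(self, history_len: int):
        self.amax_history = torch.zeros(history_len, 1)

    def make_quantizers(self) -> list:
        # Quantizers keep a reference to the most recent amax row
        return [ToyQuantizer(self.amax_history[0])]


state = ToyRecipeState(history_len=16)
(quantizer,) = state.make_quantizers()

# Reallocating the history (here: growing it with F.pad, as in the diff below)
# produces a brand-new tensor, so the quantizer's cached row still points at
# the old, orphaned storage.
state.amax_history = F.pad(state.amax_history, pad=(0, 0, 0, 16))
quantizer.record_amax(torch.randn(8))
assert state.amax_history[0].item() == 0.0  # amax landed in the stale buffer

# The fix mirrors this commit: rebuild the quantizers from the recipe state so
# their amax pointers reference the reallocated tensor.
(quantizer,) = state.make_quantizers()
quantizer.record_amax(torch.randn(8))
assert state.amax_history[0].item() != 0.0  # amax recorded in the new buffer
```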
@@ -301,6 +301,7 @@ class BasicLinear(BasicOperation):
rowwise=True,
columnwise=torch.is_grad_enabled(),
)
quantizer.internal = False
with torch.no_grad():
weight = quantizer(weight)
@@ -317,11 +318,32 @@ class BasicLinear(BasicOperation):
def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
super().reset_recipe_state(recipe=recipe)
if recipe is not None and not FP8GlobalStateManager.with_fp8_parameters():
# Make quantizers use internal tensors
self.get_input_quantizer().internal = True
self.get_grad_output_quantizer().internal = True
self.get_quantizer("forward", 1).internal = True
# Input/grad output quantizers use internal tensors
input_quantizer = self.get_quantizer("forward", 0)
grad_output_quantizer = self.get_quantizer("backward", 0)
if input_quantizer is not None:
input_quantizer.internal = True
if grad_output_quantizer is not None:
grad_output_quantizer.internal = True
# Handle weight quantizer
# Note: This function may be called in base class constructor,
# before any basic linear attrs have been set.
weight_quantizer = self.get_quantizer("forward", 1)
if weight_quantizer is None:
pass
elif is_quantized_tensor(getattr(self, "weight", None)):
# Make sure weight param has correct quantizer
weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
weight_quantizer.internal = False
self.weight.update_quantizer(weight_quantizer.copy())
else:
# Use internal tensors if quantized weights will not be
# exposed externally
weight_quantizer.internal = (
not FP8GlobalStateManager.with_fp8_parameters()
and not getattr(self, "_with_quantized_weight", False)
)
@staticmethod
def _functional_forward(
......
@@ -110,14 +110,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs)
basic_op_extra_inputs.append(xs)
# Get environment state
with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None
is_grad_enabled = func_ctx is not None
# Attempt to fuse operations if necessary
fuser.maybe_fuse_ops(is_grad_enabled, recipe, input_, basic_op_extra_inputs)
# Apply forward ops
x = input_
extra_outputs = [None] * fuser._num_basic_ops
@@ -167,7 +159,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
extra_outputs_flat.extend(ys)
# Save context for backward pass
if is_grad_enabled:
if func_ctx is not None:
# Flatten list of saved tensors
to_save = []
@@ -180,12 +172,9 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
ctx._saved_tensors_range = (range_start, range_end)
# Save tensors for backward
if with_quantized_compute:
tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
func_ctx.save_for_backward(*tensors_to_save)
func_ctx.tensor_objects = tensor_objects
else:
func_ctx.save_for_backward(*to_save)
tensors_to_save, tensor_objects = prepare_for_saving(*to_save)
func_ctx.save_for_backward(*tensors_to_save)
func_ctx.tensor_objects = tensor_objects
# Other context
func_ctx.backward_ops = fuser._backward_ops
@@ -195,7 +184,6 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
func_ctx.num_extra_inputs = fuser.num_extra_inputs
func_ctx.num_extra_outputs = len(extra_outputs_flat)
func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
func_ctx.with_quantized_compute = with_quantized_compute
# Mark output tensors as not deletable in backward
for tensor in [x] + extra_outputs_flat:
@@ -223,10 +211,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
basic_op_ctxs = func_ctx.basic_op_ctxs
# Restore saved tensors
if func_ctx.with_quantized_compute:
saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
else:
saved_tensors = func_ctx.saved_tensors
saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors)
# Unflatten list of saved tensors
for ctx in basic_op_ctxs:
@@ -460,8 +445,24 @@ class OperationFuser:
if basic_op_kwargs is None:
basic_op_kwargs = [{}] * self._num_basic_ops
# Unflatten list of extra tensor inputs
extra_inputs_copy = list(extra_inputs)
basic_op_extra_inputs = []
for op in self._basic_ops:
xs, extra_inputs_copy = _split_tuple(extra_inputs_copy, op.num_extra_inputs)
basic_op_extra_inputs.append(xs)
# Get environment state
recipe = None
if FP8GlobalStateManager.is_fp8_enabled():
recipe = FP8GlobalStateManager.get_fp8_recipe()
is_grad_enabled = torch.is_grad_enabled()
# Attempt to fuse operations if necessary
self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs)
# Fuser forward pass
if torch.is_grad_enabled():
if is_grad_enabled:
forward_func = _OperationFuserAutogradFunction.apply
args = []
else:
......
@@ -294,6 +294,8 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
forward=(mode == "forward"),
)
recipe_state = self._fp8_metas[mode][fp8_meta_key]
# Reallocate amax history if needed
current_length = recipe_state.amax_history.size(0)
target_length = recipe.amax_history_len
if target_length < current_length:
@@ -308,6 +310,25 @@
pad=(0, 0, 0, target_length - current_length),
)
# Update quantizers with new amax pointers
self._quantizers[mode] = recipe_state.make_quantizers()
# Update the global buffers with new amax pointers
if FP8GlobalStateManager.get_buffer_info() in self._fp8_metas[mode]:
pos, buffer_key = self._fp8_metas[mode][
FP8GlobalStateManager.get_buffer_info()
]
if buffer_key in FP8GlobalStateManager.global_amax_buffer:
assert (
buffer_key in FP8GlobalStateManager.global_amax_history_buffer
), "TE internal error during amax history change."
FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = (
recipe_state.amax_history[0]
)
FP8GlobalStateManager.global_amax_history_buffer[buffer_key][
pos
] = recipe_state.amax_history
# Add meta tensors to global buffer to participate in reduction
for mode in ("forward", "backward"):
if (
@@ -686,7 +707,6 @@ class FusedOperation(FusibleOperation):
return self.basic_ops[-1].get_grad_output_quantizer()
def pre_first_fuser_forward(self) -> None:
"""Preprocessing before first fuser forward pass"""
for op in self.basic_ops:
op.pre_first_fuser_forward()
......
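
For reference, a hedged sketch of the end-to-end scenario the changes above target: running a fusible op under delayed-scaling recipes with different amax_history_len values, which forces the amax history to be reallocated between iterations. The te.ops.BasicLinear, fp8_autocast, and DelayedScaling names are assumed from this repository; exact signatures and import paths may differ by version.

```python
# Sketch only: exercises the amax-history reallocation path by switching
# between recipes with different amax_history_len values.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

linear = te.ops.BasicLinear(1024, 1024)  # fusible op (assumed export path)
x = torch.randn(32, 1024, device=linear.weight.device, requires_grad=True)

short_history = DelayedScaling(amax_history_len=16)
long_history = DelayedScaling(amax_history_len=1024)

for recipe in (short_history, long_history, short_history):
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = linear(x)
    # With this commit, the op's quantizers and the global amax reduction
    # buffers are refreshed to point at the reallocated amax history, so the
    # backward pass sees consistent scaling state after the recipe change.
    y.sum().backward()
```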