Unverified Commit 9ca89e97 authored by Tim Moon, committed by GitHub

[PyTorch] Avoid initializing recipe state in fusible op base class constructor (#2421)



Do not initialize recipe state in the base op class

Op attributes may not be set yet when the base class constructor runs. Move recipe state initialization to the linear op constructor.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 9f61f8a5
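
The fix addresses a classic constructor-ordering pitfall: the base class __init__ called reset_recipe_state(), but subclasses override that method with logic that reads attributes they only set after super().__init__() returns. Below is a minimal, self-contained sketch of the hazard, using hypothetical names (BaseOp, LinearOp, init_recipe_in_base) rather than the actual Transformer Engine classes:

class BaseOp:
    def __init__(self, *, init_recipe_in_base: bool = False):
        if init_recipe_in_base:
            # Virtual dispatch reaches the subclass override before the
            # subclass has set its attributes -> AttributeError.
            self.reset_recipe_state()

    def reset_recipe_state(self):
        pass

class LinearOp(BaseOp):
    def __init__(self, in_features, out_features, *, init_recipe_in_base=False):
        super().__init__(init_recipe_in_base=init_recipe_in_base)
        self.in_features = in_features
        self.out_features = out_features
        # Safe point: all attributes exist, so the subclass initializes
        # recipe state itself -- the approach this commit takes.
        self.reset_recipe_state()

    def reset_recipe_state(self):
        # Reads attributes that only LinearOp.__init__ defines.
        self.recipe_shape = (self.out_features, self.in_features)

LinearOp(4, 8)                               # works: recipe state set after attrs
try:
    LinearOp(4, 8, init_recipe_in_base=True)
except AttributeError as err:
    print("base-class init runs too early:", err)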

@@ -901,15 +901,15 @@ class TestBasicOps:
             dtype=dtype,
             accumulate_into_main_grad=accumulate_into_main_grad,
         )
-        forward = te_ops.Sequential(
-            te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
-            op,
-            te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
-        )
         with torch.no_grad():
             op.weight.copy_(w_test)
             del w_test
             op.weight.main_grad = torch.full_like(op.weight, 0.5, dtype=torch.float32)
+        forward = te_ops.Sequential(
+            te_ops.Quantize(forward=quantized_input, backward=quantized_grad_input),
+            op,
+            te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output),
+        )
         with te.autocast(enabled=quantized_compute, recipe=recipe):
             y_test = forward(x_test)
         y_test.backward(dy_test)

@@ -137,8 +137,10 @@ class BasicLinear(BasicOperation):
             out_features=out_features,
         )

-        # Whether weight tensor is natively quantized
+        # Initialize recipe state if needed for natively quantized weight
         self._with_quantized_weight: bool = FP8GlobalStateManager.with_fp8_parameters()
+        if self._with_quantized_weight:
+            self.reset_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe())

         # Initialize parameters if needed
         weight = torch.empty(
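
The guard above makes recipe-state construction eager only when weights are natively quantized; otherwise _fp8_metas and _quantizers stay None (see the next hunk), presumably until a recipe is actually supplied, e.g. under te.autocast. A self-contained sketch of that eager-vs-lazy split, with QuantizedWeightOp as a made-up stand-in (the real flag comes from FP8GlobalStateManager):

class QuantizedWeightOp:
    def __init__(self, with_quantized_weight: bool, recipe=None):
        self._recipe_state = None
        if with_quantized_weight:
            # Eager: quantized weights need recipe state at construction.
            self.reset_recipe_state(recipe)

    def reset_recipe_state(self, recipe):
        self._recipe_state = {"recipe": recipe}

    def forward(self, recipe=None):
        if self._recipe_state is None:
            # Lazy: plain weights defer recipe state to first compute.
            self.reset_recipe_state(recipe)
        return self._recipe_state

print(QuantizedWeightOp(True, recipe="fp8").forward())   # eager path
print(QuantizedWeightOp(False).forward(recipe="fp8"))    # lazy path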

@@ -188,9 +188,6 @@ class BasicOperation(FusibleOperation, metaclass=abc.ABCMeta):
         # Objects for quantization
         self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None
         self._quantizers: Optional[dict[str, list[Quantizer]]] = None
-        with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters()
-        recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None
-        self.reset_recipe_state(recipe=recipe)

     @property
     def is_fused_op(self) -> bool:
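
For context, FP8GlobalStateManager.with_fp8_parameters() is true inside Transformer Engine's FP8-parameter context, so after this change recipe state is set up in BasicLinear's own constructor exactly when quantized weights are requested. A usage sketch, assuming the usual transformer_engine.pytorch entry points (fp8_model_init, ops.BasicLinear) and an FP8-capable GPU:

import transformer_engine.pytorch as te
from transformer_engine.pytorch import ops as te_ops

# Outside the context: plain weights, no recipe state at construction.
linear = te_ops.BasicLinear(1024, 1024)

# Inside the context: with_fp8_parameters() is True, so BasicLinear.__init__
# now calls reset_recipe_state(...) itself instead of relying on the
# (removed) base-class initialization.
with te.fp8_model_init():
    fp8_linear = te_ops.BasicLinear(1024, 1024)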