"docs/debug/api.rst" did not exist on "c9ea6be92948e1ec553037f1a04900617b9f7f6b"
Unverified commit b6b3abce authored by Paweł Gadziński, committed by GitHub

[PyTorch debug] Improve precision debug tools performance (#1909)



* turn on userbuffers for layers without debug
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* working change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests and fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* update nvinspect version
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix ci
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@@ -47,6 +47,7 @@ from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ...common.recipe import DelayedScaling, Recipe
 from ...debug.pytorch.debug_state import TEDebugState
 from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor
+from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled

 __all__ = ["initialize_ub", "destroy_ub"]

@@ -564,6 +565,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         super().__init__()
         assert torch.cuda.is_available(), "TransformerEngine needs CUDA."
         self.name = None
+        self.next_iter_when_debug_should_be_run = 0
         self.fp8_initialized = False
         self.fp8 = False
         self.fp8_calibration = False
@@ -1416,12 +1418,55 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
             wgrad_accumulation_and_reduce_hook()

+    def is_debug_iter(self) -> bool:
+        """
+        Checks whether debug should be enabled for this layer in the current iteration.
+        """
+        debug = TEDebugState.debug_enabled
+        if not debug:
+            return False
+        self._validate_name()
+        # If the layer is run for the first time in a new iteration,
+        # we need to check whether debug should be enabled for this layer -
+        # in previous iterations the debug features may have reported that
+        # no feature will be active for this layer for multiple upcoming iterations.
+        started_new_iteration = TEDebugState.get_iteration() != getattr(
+            self, "debug_last_iteration", None
+        )
+        if started_new_iteration:
+            if self.next_iter_when_debug_should_be_run is None:
+                debug = False
+            else:
+                debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run
+        self.debug_last_iteration = TEDebugState.get_iteration()
+        return debug
+
+    def no_debug_features_active(self, quantizers):
+        """
+        Checks whether no debug feature is active for this layer.
+        """
+        run_current = any_feature_enabled(quantizers)
+        # Features can report that they will not be enabled for a particular layer
+        # for multiple upcoming iterations.
+        self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers)
+        if not run_current:
+            return True
+        if self.primary_weights_in_fp8:
+            raise RuntimeError("FP8 weights are not supported in debug mode.")
+        return False
+
     def _validate_name(self):
         """
         Validate name passed to the module.
         This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM.
         If no name is assigned, it creates a default name with layer count as the variable.
         """
+        if self.name is not None:
+            return
         assert TEDebugState.debug_enabled

         import nvdlfw_inspect.api as debug_api
@@ -1470,29 +1515,3 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             " Please check the recipes assigned during fp8_model_init() and"
             " fp8_autocast() calls."
         )
-
-    def _turn_off_unsupported_features_in_debug(self):
-        if (
-            getattr(self, "ub_bulk_wgrad", False)
-            or getattr(self, "ub_bulk_dgrad", False)
-            or getattr(self, "ub_overlap_ag", False)
-            or getattr(self, "ub_overlap_rs_dgrad", False)
-            or getattr(self, "ub_overlap_rs", False)
-        ):
-            import nvdlfw_inspect.api as debug_api
-
-            debug_api.log_message(
-                "UserBuffers are not supported in debug module. "
-                "Using UB optimization will not affect the debug module. ",
-                level=logging.WARNING,
-            )
-            if hasattr(self, "ub_bulk_wgrad"):
-                self.ub_bulk_wgrad = None
-            if hasattr(self, "ub_bulk_dgrad"):
-                self.ub_bulk_dgrad = None
-            if hasattr(self, "ub_overlap_ag"):
-                self.ub_overlap_ag = None
-            if hasattr(self, "ub_overlap_rs_dgrad"):
-                self.ub_overlap_rs_dgrad = None
-            if hasattr(self, "ub_overlap_rs"):
-                self.ub_overlap_rs = None
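
The two helpers added above are what make layers cheap when debug is nominally enabled but idle: is_debug_iter() caches, per layer and per training iteration, whether any debug work is needed, and no_debug_features_active() records the next iteration at which the features ask to be re-evaluated. A minimal, self-contained sketch of that gating logic follows; MockDebugState and ToyLayer are hypothetical stand-ins for this sketch, not Transformer Engine API.

class MockDebugState:
    """Hypothetical stand-in for TEDebugState: tracks the global training iteration."""

    debug_enabled = True
    iteration = 0


class ToyLayer:
    """Hypothetical layer holding only the debug-gating state."""

    def __init__(self):
        self.next_iter_when_debug_should_be_run = 0  # 0 = evaluate debug features immediately
        self.debug_last_iteration = None

    def is_debug_iter(self) -> bool:
        if not MockDebugState.debug_enabled:
            return False
        debug = True
        # Re-evaluate the schedule only when a new iteration starts for this layer.
        if MockDebugState.iteration != self.debug_last_iteration:
            if self.next_iter_when_debug_should_be_run is None:
                debug = False  # features reported they will never run for this layer
            else:
                debug = MockDebugState.iteration >= self.next_iter_when_debug_should_be_run
        self.debug_last_iteration = MockDebugState.iteration
        return debug


layer = ToyLayer()
layer.next_iter_when_debug_should_be_run = 5  # as if the features said "skip until iteration 5"
for it in range(7):
    MockDebugState.iteration = it
    print(it, layer.is_debug_iter())  # False for iterations 0-4, True from iteration 5 on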
@@ -62,9 +62,7 @@ from ..tensor.quantized_tensor import (
     restore_from_saved,
 )
 from ...debug.pytorch.debug_state import TEDebugState
-from ...debug.pytorch.utils import any_feature_enabled
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
 from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase

@@ -162,6 +160,13 @@ class _LayerNormLinear(torch.autograd.Function):
         with_input_all_gather = parallel_mode == "column" and sequence_parallel

         # Configure Userbuffers communication (comm+GEMM overlap)
+        if debug:  # turn off userbuffers in debug mode
+            ub_overlap_ag_fprop = False
+            ub_overlap_rs_fprop = False
+            ub_overlap_ag_dgrad = False
+            ub_overlap_rs_dgrad = False
+            ub_bulk_wgrad = False
+            ub_bulk_dgrad = False
         ub_obj = None
         ub_type = None
         ub_overlap_ag_fprop = (

@@ -179,9 +184,7 @@ class _LayerNormLinear(torch.autograd.Function):
             if input_quantizer is None:
                 raise ValueError("Missing quantizer for input tensor")
             input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input)
-            if with_input_all_gather and isinstance(
-                input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
-            ):
+            if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather():
                 # All-gather is not supported with FP8 column-wise data
                 input_quantizer.set_usage(columnwise=False)

@@ -638,7 +641,7 @@ class _LayerNormLinear(torch.autograd.Function):
             quantizer = None
             if ctx.input_quantizer is not None:
                 quantizer = ctx.input_quantizer
-                if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+                if quantizer.supports_only_rowwise_all_gather():
                     # If data is in FP8, we compute FP8 transposes manually
                     quantizer.set_usage(rowwise=True, columnwise=False)
                 else:

@@ -1163,8 +1166,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
         self.name = name
-        if TEDebugState.debug_enabled:
-            self._turn_off_unsupported_features_in_debug()  # turn off userbuffers

         if tp_group is None:
             self.tp_size = tp_size

@@ -1471,9 +1472,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
         """
         if is_in_onnx_export_mode():
             return self.onnx_forward(inp, fp8_output)
-        debug = TEDebugState.debug_enabled
-        if debug:
-            self._validate_name()
+
+        debug = self.is_debug_iter()

         if FP8GlobalStateManager.fp8_graph_capturing():
             skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()

@@ -1504,13 +1504,9 @@ class LayerNormLinear(TransformerEngineBaseModule):
             else self._get_debug_quantizers(fp8_output, fp8_grad)
         )
         if debug:
-            if not any_feature_enabled(quantizers):
-                # If no feature is used, then run faster implementation with debug = False.
-                quantizers = self._get_quantizers(fp8_output, fp8_grad)
+            if self.no_debug_features_active(quantizers):
                 debug = False
+                quantizers = self._get_quantizers(fp8_output, fp8_grad)
-
-            if isinstance(weight_tensor, QuantizedTensor):
-                raise RuntimeError("FP8 weights are not supported in debug mode.")

         (
             input_quantizer,
             ...
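
In each of the three modules the forward() change follows the same shape: build debug quantizers only while debug is on, then drop back to the regular quantizers (and the regular, faster code path) as soon as the layer reports that no debug feature touches it. Below is a hedged sketch of that selection pattern; get_regular_quantizers and get_debug_quantizers are hypothetical names standing in for the modules' private _get_quantizers/_get_debug_quantizers methods.

def select_quantizers(layer, debug: bool):
    """Illustrative only: mirrors the fast-path fallback used in the forward() methods above."""
    quantizers = (
        layer.get_debug_quantizers() if debug else layer.get_regular_quantizers()
    )
    if debug and layer.no_debug_features_active(quantizers):
        # No debug feature is active for this layer right now: fall back to the
        # regular quantizers and run the non-debug (faster) implementation.
        debug = False
        quantizers = layer.get_regular_quantizers()
    return debug, quantizers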
@@ -68,7 +68,6 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ._common import apply_normalization, WeightGradStore
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ..tensor.quantized_tensor import (
-    QuantizedTensor,
     QuantizedTensorBase,
     Quantizer,
     prepare_for_saving,

@@ -78,7 +77,6 @@ from ..cpp_extensions import (
     general_gemm,
 )
 from ..export import is_in_onnx_export_mode, assert_warmed_up
-from ...debug.pytorch.utils import any_feature_enabled
 from ...debug.pytorch.debug_state import TEDebugState

 __all__ = ["LayerNormMLP"]

@@ -223,6 +221,12 @@ class _LayerNormMLP(torch.autograd.Function):
         device = inp.device

         # Configure Userbuffers communication (comm+GEMM overlap)
+        if debug:  # turn off userbuffers in debug mode
+            ub_overlap_ag = False
+            ub_overlap_rs = False
+            ub_overlap_rs_dgrad = False
+            ub_bulk_wgrad = False
+            ub_bulk_dgrad = False
         ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered
         ub_overlap_rs = ub_overlap_rs and is_grad_enabled

@@ -238,9 +242,7 @@ class _LayerNormMLP(torch.autograd.Function):
             if fc1_input_quantizer is None:
                 raise ValueError("Missing quantizer for FC1 input tensor")
             fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input)
-            if sequence_parallel and isinstance(
-                fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
-            ):
+            if sequence_parallel and fc1_input_quantizer.supports_only_rowwise_all_gather():
                 # All-gather is not supported with FP8 column-wise data
                 fc1_input_quantizer.set_usage(columnwise=False)

@@ -1523,9 +1525,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
         )
         self.name = name
-        if TEDebugState.debug_enabled:
-            self._turn_off_unsupported_features_in_debug()  # turn off userbuffers

         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

         if tp_group is None:

@@ -1728,9 +1727,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
         """
         if is_in_onnx_export_mode():
             return self.onnx_forward(inp)
-        debug = TEDebugState.debug_enabled
-        if debug:
-            self._validate_name()
+
+        debug = self.is_debug_iter()

         if FP8GlobalStateManager.fp8_graph_capturing():
             skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()

@@ -1754,12 +1752,9 @@ class LayerNormMLP(TransformerEngineBaseModule):
             else self._get_debug_quantizers(fp8_output)
         )
         if debug:
-            if not any_feature_enabled(quantizers):
-                quantizers = self._get_quantizers(fp8_output)
+            if self.no_debug_features_active(quantizers):
                 debug = False
+                quantizers = self._get_quantizers(fp8_output)
-
-            if isinstance(self.fc1_weight, QuantizedTensor):
-                raise RuntimeError("FP8 weights are not supported in debug mode.")

         # Get quantizers
         (
         ...
@@ -68,7 +68,6 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..export import is_in_onnx_export_mode, assert_warmed_up
 from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload
 from ...debug.pytorch.debug_state import TEDebugState
-from ...debug.pytorch.utils import any_feature_enabled

 __all__ = ["Linear"]

@@ -137,6 +136,12 @@ class _Linear(torch.autograd.Function):
         )

         # Configure Userbuffers communication (comm+GEMM overlap)
+        if debug:  # turn off userbuffers in debug mode
+            ub_overlap_rs_fprop = False
+            ub_overlap_ag_fprop = False
+            ub_overlap_rs_dgrad = False
+            ub_bulk_wgrad = False
+            ub_bulk_dgrad = False
         ub_obj = None
         ub_type = None
         if ub_overlap_rs_fprop:

@@ -356,8 +361,9 @@ class _Linear(torch.autograd.Function):
             and own_quantized_input
             and isinstance(inputmat, QuantizedTensorBase)
         ):
-            if ctx.backward_input_needs_gather and isinstance(
-                quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+            if (
+                ctx.backward_input_needs_gather
+                and weight_quantizer.supports_only_rowwise_all_gather()
             ):
                 # All-gather is not supported with FP8 column-wise data
                 inputmat.update_usage(rowwise_usage=True, columnwise_usage=False)

@@ -589,7 +595,7 @@ class _Linear(torch.autograd.Function):
         else:
             # Quantize input tensor
             quantizer = ctx.input_quantizer
-            if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+            if quantizer.supports_only_rowwise_all_gather():
                 # All-gather is not supported with FP8 column-wise data
                 quantizer.set_usage(
                     rowwise=True,

@@ -607,7 +613,7 @@ class _Linear(torch.autograd.Function):
         quantizer = None
         if ctx.fp8 or ctx.debug:
             quantizer = ctx.input_quantizer
-            if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
+            if quantizer.supports_only_rowwise_all_gather():
                 # If data is in FP8, we compute FP8 transposes manually
                 quantizer.set_usage(rowwise=True, columnwise=False)
             else:

@@ -1077,9 +1083,6 @@ class Linear(TransformerEngineBaseModule):
         self.save_original_input = save_original_input
         self.name = name
-        if TEDebugState.debug_enabled:
-            self._turn_off_unsupported_features_in_debug()  # turn off userbuffers

         self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)

         if device == "meta":

@@ -1341,9 +1344,7 @@ class Linear(TransformerEngineBaseModule):
         if is_in_onnx_export_mode():
             return self.onnx_forward(inp, fp8_output)

-        debug = TEDebugState.debug_enabled
-        if debug:
-            self._validate_name()
+        debug = self.is_debug_iter()

         if FP8GlobalStateManager.fp8_graph_capturing():
             skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()

@@ -1373,14 +1374,11 @@ class Linear(TransformerEngineBaseModule):
             if not debug
             else self._get_debug_quantizers(fp8_output, fp8_grad)
         )
+
         if debug:
-            if not any_feature_enabled(quantizers):
-                # If no feature is used, then run faster implementation with debug = False.
-                quantizers = self._get_quantizers(fp8_output, fp8_grad)
+            if self.no_debug_features_active(quantizers):
                 debug = False
+                quantizers = self._get_quantizers(fp8_output, fp8_grad)
-
-            if isinstance(weight_tensor, QuantizedTensor):
-                raise RuntimeError("FP8 weights are not supported in debug mode.")

         (
             input_quantizer,
             ...
@@ -184,6 +184,12 @@ class Float8Quantizer(Quantizer):
     def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
         return DelayedScaling

+    def supports_only_rowwise_all_gather(self) -> bool:
+        """
+        Float8Quantizer supports only rowwise all-gather
+        """
+        return True
+

 class Float8CurrentScalingQuantizer(Quantizer):
     """Builder class for FP8 tensors with per-tensor current scaling
     ...

@@ -361,6 +367,12 @@ class Float8CurrentScalingQuantizer(Quantizer):
     def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
         return Float8CurrentScaling

+    def supports_only_rowwise_all_gather(self) -> bool:
+        """
+        Float8CurrentScalingQuantizer supports only rowwise all-gather
+        """
+        return True
+

 class Float8Tensor(Float8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data
     ...
@@ -260,6 +260,10 @@ class Quantizer(abc.ABC):
     def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
         """Returns recipe class that is compatible with this quantizer"""

+    def supports_only_rowwise_all_gather(self) -> bool:
+        """Returns True if the quantizer supports only rowwise all-gather"""
+        return False
+

 class _QuantizeFunc(torch.autograd.Function):
     """Cast to FP8 from other dtype"""
     ...
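
The isinstance checks against Float8Quantizer and Float8CurrentScalingQuantizer in the modules are replaced by this capability method, so the modules no longer need to import the concrete quantizer classes. A toy sketch of the dispatch follows; the class names here are hypothetical and are not the real Transformer Engine hierarchy.

class Quantizer:
    """Hypothetical base class for the sketch."""

    def supports_only_rowwise_all_gather(self) -> bool:
        # Default: column-wise data may participate in all-gather.
        return False


class PerTensorFP8Quantizer(Quantizer):
    """Stands in for quantizers whose FP8 data cannot be all-gathered column-wise."""

    def supports_only_rowwise_all_gather(self) -> bool:
        return True


def plan_usage(quantizer: Quantizer, needs_all_gather: bool) -> dict:
    """Mirrors the set_usage(columnwise=False) branches in the modules above."""
    usage = {"rowwise": True, "columnwise": True}
    if needs_all_gather and quantizer.supports_only_rowwise_all_gather():
        usage["columnwise"] = False
    return usage


print(plan_usage(PerTensorFP8Quantizer(), needs_all_gather=True))
# -> {'rowwise': True, 'columnwise': False}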