Unverified commit 2645eaec authored by Paweł Gadziński, committed by GitHub

[Pytorch] NVIDIA-DL-Framework-Inspect support – part 3 – tests (#1612)



* tests drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* move dir
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* tests fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1d903f5e
@@ -49,7 +49,7 @@ def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None):
             fp8_dtype = tex.DType.kFloat8E5M2
         amax = tensor.abs().max().float()
         one = torch.ones(1, device=tensor.device)
-        scale = _default_sf_compute(amax, one, fp8_max)
+        scale = _default_sf_compute(amax, one, fp8_max, 0)
         quantizer = Float8Quantizer(scale, amax, fp8_dtype)
     else:
...
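The extra trailing argument above is assumed to be `_default_sf_compute`'s `margin` parameter (an exponent backoff on the computed scale). As a minimal standalone sketch of the usual delayed-scaling formula — not TE's implementation — the scale maps `amax` to the format's maximum, divided by `2**margin`:

```python
import torch

def default_sf_sketch(amax: torch.Tensor, fp8_max: float, margin: int = 0) -> torch.Tensor:
    """Illustrative per-tensor scale: amax * scale ~= fp8_max / 2**margin."""
    scale = (fp8_max / amax) / (2.0**margin)
    # Fall back to 1.0 when amax is zero or non-finite.
    return torch.where(torch.isfinite(scale), scale, torch.ones_like(scale))

print(default_sf_sketch(torch.tensor([3.5]), fp8_max=448.0))  # 448 is the E4M3 max
```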
@@ -120,7 +120,6 @@ class LogFp8TensorStats(BaseLogTensorStats):
         if not rowwise:
             return  # tensor was already seen rowwise in the other gemm
-        tensor = tensor._data
         options = (
             config.get("start_step", None),
             config.get("end_step", None),
...
@@ -15,6 +15,7 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.tensor import Quantizer
 from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8Tensor,
+    Float8Quantizer,
     Float8CurrentScalingQuantizer,
 )
 from transformer_engine.debug.features.api import TEConfigAPIMapper
@@ -39,7 +40,7 @@ def per_tensor_cast(
     }, "[NVTORCH INSPECT ERROR] Only 2 FP8 types: E4M3 and E5M2 are supported in TE."
     tensor = tensor.contiguous()
-    quantizer = Float8CurrentScalingQuantizer(fp8_dtype)
+    quantizer = Float8CurrentScalingQuantizer(fp8_dtype, device=tensor.device)
     if out is not None:
         quantizer.update_quantized(tensor, out)
@@ -118,7 +119,7 @@ class PerTensorScaling(TEConfigAPIMapper):
             if key not in ["gemm", "tensor"]:
                 raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
-        assert isinstance(default_quantizer, Float8CurrentScalingQuantizer), (
+        assert isinstance(default_quantizer, Float8Quantizer), (
             f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor: "
             "Per-tensor current scaling can be used only within `DelayedScaling` recipe autocast."
             f" {layer_name}"
...
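The corrected assert reflects that under the `DelayedScaling` recipe the module hands this feature a `Float8Quantizer`, which the feature then replaces with a current-scaling one. A hedged sketch of that flow (reading the FP8 dtype off `default_quantizer.dtype` is an assumption here, not confirmed by the diff):

```python
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
)

def swap_to_current_scaling(default_quantizer, tensor):
    # Only a DelayedScaling-recipe Float8Quantizer may be swapped out.
    assert isinstance(default_quantizer, Float8Quantizer)
    return Float8CurrentScalingQuantizer(default_quantizer.dtype, device=tensor.device)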
@@ -96,7 +96,10 @@ STATS = {
     "max": (torch.max, lambda buffers: max(_get(buffers, "max"))),
     "sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))),
     "mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))),
-    "numel": (lambda x: x.numel(), lambda buffers: sum(_get(buffers, "numel"))),
+    "numel": (
+        lambda x: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(),
+        lambda buffers: sum(_get(buffers, "numel")),
+    ),
     "l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))),
     "l2_norm_square": (
         lambda x: torch.sum(x**2),
@@ -137,7 +140,7 @@ STATS = {
         - min(_get(buffers, "dynamic_range_bottom")),
     ),
     "underflows%": (
-        lambda x: (x == 0).sum() / x.numel() * 100,
+        lambda x: (x.get_data_tensors()[0] == 0).sum() / x.get_data_tensors()[0].numel() * 100,
         lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")),
     ),
 }
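Both `numel` and `underflows%` now handle quantized tensors that lack a plain elementwise interface by reading the raw data through `get_data_tensors()[0]`. The two-level shape of each `STATS` entry — a per-shard stat plus a reducer over the collected per-shard buffers — can be sketched standalone as below; `_get(buffers, name)` is inlined as a plain list, so this is illustrative rather than the module's code:

```python
import torch

# Each entry: (per-shard stat, reducer over the per-shard buffers).
SKETCH = {
    "numel": (lambda x: x.numel(), lambda buffers: sum(buffers)),
    "mean": (
        lambda x: (x.sum().item(), x.numel()),
        lambda buffers: sum(s for s, _ in buffers) / sum(n for _, n in buffers),
    ),
}

shards = [torch.ones(4), torch.full((2,), 4.0)]
per_shard = [SKETCH["mean"][0](s) for s in shards]
print(SKETCH["mean"][1](per_shard))  # (4 + 8) / 6 = 2.0
```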
@@ -18,6 +18,7 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.tensor.quantized_tensor import (
     QuantizedTensor,
     Quantizer,
+    QuantizedTensorBase,
     prepare_for_saving,
     restore_from_saved,
 )
@@ -299,6 +300,7 @@ class DebugQuantizer(Quantizer):
                 iteration=self.iteration,
                 dtype=dtype,
             )
+            if dtype is not None:
                 if columnwise_gemm_tensor.dtype != dtype:
                     raise ValueError("Dtype does not match the output of the modify_tensor call")
         if self.rowwise_tensor_plan == API_CALL_MODIFY:
@@ -311,6 +313,7 @@ class DebugQuantizer(Quantizer):
                 iteration=self.iteration,
                 dtype=dtype,
             )
+            if dtype is not None:
                 if rowwise_gemm_tensor.dtype != dtype:
                     raise ValueError("Dtype does not match the output of the modify_tensor call")
@@ -332,6 +335,7 @@ class DebugQuantizer(Quantizer):
             quantizer=self,
             layer_name=self.layer_name,
             tensor_name=self.tensor_name,
+            original_tensor=tensor,
         )

     def process_gemm_output(self, tensor: torch.Tensor):
@@ -456,7 +460,7 @@ class DebugQuantizer(Quantizer):
         return False


-class DebugQuantizedTensor:
+class DebugQuantizedTensor(QuantizedTensorBase):
     """
     Class containing quantized tensors after debug. Depending on configuration
     it can contain one or two different objects. These objects can be accessed by the method
@@ -470,6 +474,7 @@ class DebugQuantizedTensor:
         quantizer,
         layer_name=None,
         tensor_name=None,
+        original_tensor=None,
     ):
         self.rowwise_gemm_tensor = rowwise_gemm_tensor
@@ -477,6 +482,7 @@ class DebugQuantizedTensor:
         self.quantizer = quantizer
         self._layer_name = layer_name
         self._tensor_name = tensor_name
+        self._original_tensor = original_tensor

     def prepare_for_saving(self):
         """Prepare for saving method override"""
@@ -524,5 +530,5 @@ class DebugQuantizedTensor:
         """Size of the tensor."""
         return self.rowwise_gemm_tensor.size()

-    def update_usage(self, rowwise_usage: bool, columnwise_usage: bool):
+    def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None):
         """Update usage of the tensor."""
@@ -1239,12 +1239,18 @@ def gather_along_first_dim(
         final_quantizer = (
             None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
         )
+        # Temporary fix for TP communication of Float8BlockwiseQTensorBase
+        if isinstance(rowwise, Float8BlockwiseQTensorBase):
+            rowwise = inp._original_tensor
         rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
         out_obj.rowwise_gemm_tensor = rowwise_total
         if rowwise is not columnwise:
             final_quantizer_columnwise = (
                 None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
             )
+            # Temporary fix for TP communication of Float8BlockwiseQTensorBase
+            if isinstance(columnwise, Float8BlockwiseQTensorBase):
+                columnwise = inp._original_tensor
             columnwise_total, _ = gather_along_first_dim(
                 columnwise, process_group, False, final_quantizer_columnwise
             )
...
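This workaround keeps tensor-parallel gathers working for blockwise FP8: since `Float8BlockwiseQTensorBase` cannot yet be gathered directly, the code falls back to the high-precision tensor stashed in `_original_tensor`, gathers that, and lets `final_quantizer` requantize the result. A minimal sketch of the communication primitive the fallback relies on (an all-gather along the first dimension, assuming an initialized process group):

```python
import torch
import torch.distributed as dist

def gather_first_dim(t: torch.Tensor, group=None) -> torch.Tensor:
    """All-gather a plain (unquantized) tensor along dim 0."""
    world_size = dist.get_world_size(group)
    out = torch.empty(
        (world_size * t.shape[0], *t.shape[1:]), dtype=t.dtype, device=t.device
    )
    dist.all_gather_into_tensor(out, t.contiguous(), group=group)
    return out
```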
@@ -1057,7 +1057,12 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             if (
                 isinstance(
                     grad_output_.get_tensor(True),
-                    (QuantizedTensor, Float8TensorBase, MXFP8TensorBase),
+                    (
+                        QuantizedTensor,
+                        Float8TensorBase,
+                        MXFP8TensorBase,
+                        Float8BlockwiseQTensorBase,
+                    ),
                 )
                 and ctx.use_bias
             ):
...
@@ -193,6 +193,7 @@ class _LayerNormLinear(torch.autograd.Function):
         # or if a gather of ln_out must be in high precision.
         with_quantized_norm = (
             fp8
+            and not debug
             and not return_layernorm_output
             and not return_layernorm_output_gathered
             and not force_hp_blockwise_ln_out_gather
...