Commit f8c2af4c authored by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
......@@ -2,4 +2,4 @@
#
# See LICENSE for license information.
"""Python interface for dot product attention"""
"""Debug features."""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Feature doing nothing, used for testing purposes."""
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper
@Registry.register_feature(namespace="transformer_engine")
class TestDummyFeature(TEConfigAPIMapper):
"""
This feature is used only in tests. It invokes look_at_tensor_before_process
and does nothing.
If no features are used, then TE layer automatically switches to the non-debug mode.
This feature is invoked for each GEMM to prevent this behavior.
"""
@api_method
def inspect_tensor_enabled(self, *_args, **_kwargs):
"""API call used to determine whether to run look_at_tensor_before_process
in the forward pass."""
return True
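# A hedged sketch (not part of Transformer Engine) of a slightly larger feature in
# the same style: it also enables inspect_tensor and prints the tensor shape. The
# class name and the print statement are illustrative assumptions.
@Registry.register_feature(namespace="transformer_engine")
class PrintTensorShape(TEConfigAPIMapper):
    """Illustrative feature: prints the shape of every inspected tensor."""

    @api_method
    def inspect_tensor_enabled(self, *_args, **_kwargs):
        """Run inspect_tensor for every configured tensor."""
        return True

    @api_method
    def inspect_tensor(self, config, layer_name, tensor_name, tensor, iteration, tp_group):
        """Print the shape of the inspected high-precision tensor."""
        print(f"{layer_name}/{tensor_name} @ step {iteration}: {tuple(tensor.shape)}")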
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""API definition for nvidia-dlframework-inspect."""
import copy
from typing import Dict, Union
from nvdlfw_inspect.base import BaseNamespaceAPI, BaseConfigAPIMapper
from nvdlfw_inspect.registry import Registry
import torch
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import all_tensor_types
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
class TEConfigAPIMapper(BaseConfigAPIMapper):
"""Class responsible for determining which NV DLFW Inspect API should be run for each tensor and gemm."""
def parse_config_and_api(self, config, **kwargs):
"""Process the config and returns True if the config and api args match, along with processed config."""
processed_config = None
config_copy = copy.deepcopy(config)
gemm_parsing = kwargs.get("gemm_parsing", False)
tensor_parsing = kwargs.get("tensor_parsing", False)
if gemm_parsing:
# parse with GEMM and/or tensor
processed_config = self._process_transformer_engine_config(config_copy, **kwargs)
elif tensor_parsing:
# parse with only tensor
processed_config = self._process_tensor_config(config_copy, kwargs["tensor_name"])
if not processed_config:
return False, None
if "enabled" in processed_config:
processed_config.pop("enabled")
return True, processed_config
def _validate_gemm(self, gemm):
assert gemm in ["fprop", "wgrad", "dgrad"], (
f"[NVTORCH INSPECT ERROR] Invalid gemm: {gemm}. It must be one of the ['fprop',"
" 'wgrad', 'dgrad']."
)
def _process_transformer_engine_config(self, config, **kwargs):
"""
Return config specific to a particular tensor name and gemm that matches the api args.
"""
if "gemms_struct" in config:
for cfg in config["gemms_struct"]:
self._validate_gemm(cfg["gemm"])
if cfg["gemm"] == kwargs["gemm"]:
if kwargs["tensor_parsing"]:
cfg = self._process_tensor_config(cfg, kwargs["tensor_name"])
if not cfg:
return None
cfg_copy = copy.deepcopy(cfg)
config.pop("gemms_struct")
assert (
"enabled" not in cfg_copy
), "[NVTORCH INSPECT ERROR] Enabled field should not be part of gemms_struct"
config.update(cfg_copy)
return config
return None
if "gemms" in config:
for gemm in config["gemms"]:
self._validate_gemm(gemm)
if kwargs["gemm"] in config["gemms"]:
if kwargs["tensor_parsing"]:
cfg = self._process_tensor_config(config, kwargs["tensor_name"])
if not cfg:
return None
config["gemm"] = kwargs["gemm"]
config.pop("gemms")
return config
return None
raise ValueError(
"[NVTORCH INSPECT ERROR] Provide 'gemms_struct: List[Dict]' or 'gemms: List[str]'"
" in the config yaml"
)
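# Hedged illustration of the two config layouts parsed above (all values are
# hypothetical; real values come from the feature section of config.yaml):
#   flat form:       {"enabled": True, "gemms": ["fprop", "dgrad"], "tensors": ["activation"]}
#   structured form: {"enabled": True, "gemms_struct": [
#                        {"gemm": "fprop", "tensors": ["activation", "weight"]},
#                        {"gemm": "dgrad", "tensors": ["gradient"]}]}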
required_kwargs = {
"fp8_gemm_enabled": ["gemm"],
"modify_tensor_enabled": ["tensor_name", "gemm"],
"modify_tensor": ["tensor_name", "gemm"],
"inspect_tensor": ["tensor_name"],
"inspect_tensor_postquantize": ["tensor_name"],
"inspect_tensor_enabled": ["tensor_name"],
"inspect_tensor_postquantize_enabled": ["tensor_name"],
"default": ["tensor_name", "gemm"],
}
# pylint: disable=unused-argument
class TEDefaultFeatures:
"""Transformer Engine API calls default behavior."""
def fp8_gemm_enabled(self, config: Dict, layer_name: str, gemm: str, iteration: int) -> bool:
"""
If the tensor is not processed using *modify_tensor* and the fp8 recipe is enabled,
then the decision whether to cast it to fp8 is based on the value returned by the call *fp8_gemm_enabled*.
If the tensor is processed using *modify_tensor* or fp8 autocast is not enabled,
the result of this call does not matter.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is True
"""
return True # if it is False, the FP8 GEMM is turned off; otherwise nothing happens.
def modify_tensor_enabled(
self,
config: Dict,
layer_name: str,
gemm: str,
tensor_name: str,
iteration: int,
) -> bool:
"""
It is used to determine whether *modify_tensor* will be run for a given GEMM and tensor name. It has **higher priority** than fp8_gemm: if *modify_tensor_enabled* returns True, then the *modify_tensor* call is invoked for the respective tensor no matter what.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is False
"""
return False
def modify_tensor(
self,
config: Dict,
layer_name: str,
gemm: str,
tensor_name: str,
tensor: torch.Tensor,
default_quantizer: Quantizer,
iteration: int,
out: Union[torch.Tensor, QuantizedTensor],
) -> Union[torch.Tensor, QuantizedTensor, None]:
"""
It allows tensor modification.
For example, feature `FakeQuant` uses it to emulate casting to FP8.
It can be invoked at most once for each tensor within a given GEMM operation.
This call is invoked if `modify_tensor_enabled` returns `True` and the feature is enabled for the *tensor_name* and *gemm*.
Then it is called **instead of** the default quantization.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
tensor: torch.Tensor
tensor in high precision,
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
default_quantizer : Quantizer
quantizer which is used to cast the tensor to lower precision
if *modify_tensor* is not invoked. For example,
feature per tensor scale uses it to obtain FP8 dtype of the tensor.
If the recipe indicates that the tensor is not cast - for example,
if running without FP8 autocast, then `default_quantizer=None`,
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
out: Union[torch.Tensor, QuantizedTensor]
output tensor, used in the weight caching mechanism.
Returns
-------
Union[torch.Tensor, transformer_engine.pytorch.QuantizedTensor, None]
can be `torch.Tensor` or one of the Transformer Engine's `QuantizedTensor` -
the rule is that both tensors returned for each GEMM should have the same type.
If both are `Float8Tensor`, then GEMM is run in FP8.
If both are `torch.Tensor`, GEMM is run in high precision.
Please take that into account, especially if only one tensor of the GEMM
is processed by `modify_tensor()`. For example, `FakeQuant`
disables FP8 GEMM to ensure that the second tensor is also in high precision.
If the tensor is not the input for any GEMM - namely `output`,
`wgrad` and `dgrad` - the return type would match the input type.
Should return `None` if `out` is not `None`.
"""
raise NotImplementedError(
"modify_tensor_enabled() returned True, modify_tensor() was invoked, but it is not"
" handled by any API."
)
def inspect_tensor(
self,
config: Dict,
layer_name: str,
tensor_name: str,
tensor: torch.Tensor,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
) -> None:
"""
The feature is invoked if *inspect_tensor_enabled* returns `True`. It can be used to obtain information on the high precision tensor. For example, it is run by the `LogTensorStats` feature.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
tensor: torch.Tensor
tensor in high precision,
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
tp_group: torch.distributed.ProcessGroup
process group for the tensor parallel group. This is used for weight statistics reduction.
It is not the reduction group from debug_api.
Returns
-------
Should return nothing.
"""
def inspect_tensor_postquantize(
self,
config: Dict,
layer_name: str,
tensor_name: str,
gemm: str,
tensor: torch.Tensor,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
) -> None:
"""
Similar to *inspect_tensor*, but run after the FP8 cast or *modify_tensor*, if either of them is run. If neither the FP8 cast nor *modify_tensor* is invoked, then *inspect_tensor_postquantize* is not invoked either. The `LogFp8TensorStats` feature uses this call to collect FP8 statistics after the quantization.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
tensor: torch.Tensor
tensor in fp8 or processed tensor after the modify_tensor call,
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
tp_group: torch.distributed.ProcessGroup
process group for the tensor parallel group. This is used for weight statistics reduction.
It is not the reduction group from debug_api.
Returns
-------
Should return nothing.
"""
def inspect_tensor_enabled(
self,
config: Dict,
layer_name: str,
tensor_name: str,
iteration: int,
) -> bool:
"""
It is a routing call, which is run at the initialization of the layer. If it returns True, then *inspect_tensor* for a given GEMM and tensor will be invoked.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`].
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is False
"""
return False
def inspect_tensor_postquantize_enabled(
self,
config: Dict,
layer_name: str,
gemm: str,
tensor_name: str,
iteration: int,
) -> bool:
"""
It is a routing call, which is run at the initialization of the layer.
If it returns True, then *inspect_tensor_postquantize* for
a given GEMM and tensor will be invoked.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is False
"""
return False
@Registry.register_namespace_api(namespace="transformer_engine")
class TransformerEngineAPI(BaseNamespaceAPI):
"""
Transformer Engine API class that contains default APIs that are invoked when a config is not provided
or a layer is not selected in the config.
TransformerEngine specific features must override these APIs wherever required.
The overridden APIs will be invoked whenever the corresponding feature is enabled in the config.
"""
def __init__(self):
BaseNamespaceAPI.__init__(self)
self._default_api_impl = TEDefaultFeatures()
self._cacheable_api_kwargs_map = {
"fp8_gemm": ["gemm"],
"modify_tensor": ["tensor_name", "gemm"],
"inspect_tensor": ["tensor_name"],
"inspect_tensor_postquantize": ["tensor_name"],
"inspect_tensor_enabled": ["tensor_name"],
"inspect_tensor_postquantize_enabled": ["tensor_name"],
"modify_tensor_enabled": ["tensor_name"],
}
def is_multiple_feature_invocation_allowed(self, api_name):
"""
Check if API allows executing multiple features for a single call
"""
return api_name in {
"fp8_gemm_enabled",
"inspect_tensor",
"inspect_tensor_postquantize",
"inspect_tensor_enabled",
"inspect_tensor_postquantize_enabled",
}
def input_assertions_hook(self, api_name, **kwargs):
"""
Assert that the required args are passed as kwargs in the API call for all TransformerEngine-specific APIs.
"""
if api_name in required_kwargs:
for kwarg in required_kwargs[api_name]:
assert kwarg in kwargs, (
f"[NVTORCH INSPECT ERROR] Cannot route API, too ambiguous. Provide {kwarg} in"
f" {api_name}."
)
else:
for kwarg in required_kwargs["default"]:
assert kwarg in kwargs, (
f"[NVTORCH INSPECT ERROR] Cannot route API, too ambiguous. Provide {kwarg} in"
f" {api_name}."
)
def routing_condition(self, api_name, config, _, feature_obj, **kwargs):
"""
Overridden APIs are selected based on the GEMM name in the config and kwargs.
"""
tensor_parsing = "tensor_name" in required_kwargs[api_name]
gemm_parsing = "gemm" in required_kwargs[api_name]
status, modified_config = feature_obj.parse_config_and_api(
config, gemm_parsing=gemm_parsing, tensor_parsing=tensor_parsing, **kwargs
)
return status, modified_config
def output_assertions_hook(self, api_name, ret, **kwargs):
"""Output hooks used to check correctness of the outputs of the API calls."""
if "enabled" in api_name or api_name == "fp8_gemm":
assert isinstance(ret, bool)
if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]:
assert ret is None
if api_name == "modify_tensor":
assert type(ret) in all_tensor_types
if (
type(ret) == torch.Tensor # pylint: disable=unidiomatic-typecheck
and "dtype" in kwargs
):
if kwargs["dtype"] is not None:
assert ret.dtype == kwargs["dtype"]
def step(self):
"""This function is called by the nvidia-dlframework-inspect after every debug_api.step()"""
STATS_BUFFERS.log_stats()
def end_debug(self):
"""This function is called by the nvidia-dlframework-inspect after every debug_api.end_debug()"""
TEDebugState.reset()
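# Hedged end-to-end sketch of the nvidia-dlframework-inspect driver loop that
# triggers the hooks above. The config path, feature directory, and the exact
# initialize() keyword arguments are assumptions to be checked against the
# nvdlfw_inspect documentation; step() and end_debug() map to the methods above.
#
#   import nvdlfw_inspect.api as debug_api
#
#   debug_api.initialize(
#       config_file="./debug_config.yaml",
#       feature_dirs=["transformer_engine/debug/features"],
#   )
#   for _ in range(num_steps):
#       loss = model(batch).sum()   # TE layers call the inspect/modify hooks
#       loss.backward()
#       debug_api.step()            # -> TransformerEngineAPI.step(): log stats
#   debug_api.end_debug()           # -> TransformerEngineAPI.end_debug()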
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect"""
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper
@Registry.register_feature(namespace="transformer_engine")
class DisableFP8GEMM(TEConfigAPIMapper):
"""
GEMM operations are executed in higher precision, even when FP8 autocast is enabled.
Parameters
----------
gemms: List[str]
list of gemms to disable
- fprop
- dgrad
- wgrad
Example
-------
.. code-block:: yaml
example_disable_fp8_gemm:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8GEMM:
enabled: True
gemms: [dgrad, wgrad]
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for choice between high-precision and FP8 GEMM execution."""
for key in config:
if key != "gemm":
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behaviour in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableFP8Layer Feature support for nvidia-dlframework-inspect"""
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
@Registry.register_feature(namespace="transformer_engine")
class DisableFP8Layer:
"""
Disables all FP8 GEMMs in the layer.
Example
-------
.. code-block:: yaml
example_disable_fp8_layer:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8Layer:
enabled: True
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
for key in config:
if key not in ["enabled", "gemm"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If FP8 training, disable FP8 for the selected layers if this feature is enabled in config.
debug_api.log_message("FP8 Disabled", layer_name)
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False
def parse_config_and_api(self, config, **_kwargs):
"""Determines whether to run the API
DisableFP8Layer is the only feature provided by the Transformer Engine
which does not inherit from TEConfigAPIMapper - this mapper is primarly responsible for
parsing gemms and tensors fields from the config, which are not needed for this feature.
Explanation of the parse_config_and_api can be found in the
nvidia-dlframework-inspect documentation.
"""
return config["enabled"], None
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FakeQuant Feature support for nvidia-dlframework-inspect"""
from typing import Optional
import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
from nvdlfw_inspect.utils import append_parent_docstring
import transformer_engine_torch as tex
from transformer_engine.debug.features.api import TEConfigAPIMapper
from transformer_engine.common.recipe import Format
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.fp8 import _default_sf_compute
def fake_quantize(tensor: torch.Tensor, fp8_format: str, out=None):
"""Input tensor is quantized to fp8 and then dequantized."""
assert tensor.dtype in (
torch.float,
torch.float16,
torch.bfloat16,
), "[NVTORCH INSPECT ERROR] Unsupported tensor type."
assert tensor.is_cuda, "[NVTORCH INSPECT ERROR] Must be a GPU tensor."
assert fp8_format in {
"FP8E4M3",
"FP8E5M2",
"MXFP8E4M3",
"MXFP8E5M2",
}, (
"[NVTORCH INSPECT ERROR] Only 4 FP8 types: FP8E4M3, FP8E5M2, MXFP8E4M3, MXFP8E5M2 are"
" supported in TE."
)
if fp8_format in ["FP8E4M3", "FP8E5M2"]:
if fp8_format == "FP8E4M3":
fp8_max = Format.E4M3.value.max_fwd
fp8_dtype = tex.DType.kFloat8E4M3
else:
fp8_max = Format.E5M2.value.max_fwd
fp8_dtype = tex.DType.kFloat8E5M2
amax = tensor.abs().max().float()
one = torch.ones(1, device=tensor.device)
scale = _default_sf_compute(amax, one, fp8_max)
quantizer = Float8Quantizer(scale, amax, fp8_dtype)
else:
quantizer = MXFP8Quantizer(fp8_dtype=fp8_format)
if out is not None:
out.copy_(quantizer(tensor).dequantize())
return None
return quantizer(tensor).dequantize()
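# Hedged usage sketch of fake_quantize() (not part of the feature): requires a CUDA
# device; the tensor shape and dtype below are arbitrary assumptions.
def _example_fake_quantize():
    """Round-trip a random bfloat16 tensor through the FP8E4M3 grid."""
    x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
    x_fq = fake_quantize(x, "FP8E4M3")
    # Same shape as the input; values are rounded to representable FP8 values.
    assert x_fq.shape == x.shape
    return x_fq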
@Registry.register_feature(namespace="transformer_engine")
@append_parent_docstring(parent=TEConfigAPIMapper)
class FakeQuant(TEConfigAPIMapper):
"""
Disables FP8 GEMM. Fake quantizes chosen tensors to FP8 - using per-tensor scaling factor, not delayed scaling - and runs high-precision GEMM.
.. figure:: ./img/fake_quant.svg
:align: center
Fig 1: Comparison of FP8 FPROP GEMM with the same GEMM in BF16 with fake quantization of activation tensor. Green tensors have the same values, but different dtypes.
Parameters
----------
gemms/gemms_struct: List[str]
list of gemms to fake quantize
- fprop
- dgrad
- wgrad
tensors/tensors_struct: List[str]
list of tensors to fake quantize
- activation
- gradient
- weight
- output
- wgrad
- dgrad
quant_format: str
specifies the FP8 format to use:
- FP8E5M2
- FP8E4M3
- MXFP8E5M2
- MXFP8E4M3
Example
-------
.. code-block:: yaml
example_fake_quant_fp8:
enabled: True
layers:
layer_types: [transformer_layer.layernorm_mlp.fc1]
transformer_engine:
FakeQuant:
enabled: True
quant_format: FP8E5M2
gemms_struct:
- gemm: fprop
tensors: [activation, weight]
- gemm: dgrad
tensors: [gradient]
"""
def _supported_formats(self):
"""Returns formats that one can fake quantize tensor to."""
return ["FP8E4M3", "FP8E5M2", "MXFP8E4M3", "MXFP8E5M2"]
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
return False
@api_method
def modify_tensor_enabled(
self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run process_tensor() in the forward."""
return True
@api_method
def modify_tensor(
self,
config,
layer_name: str,
gemm: str,
tensor_name: str,
tensor: torch.Tensor,
iteration: int,
default_quantizer: Quantizer,
out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
): # pylint: disable=unused-argument
"""API call used to process the tensor."""
for key in config.keys():
if key not in ["gemm", "tensor", "quant_format"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
if "quant_format" not in config:
raise ValueError(
f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor:"
f" quant_format missing for Tensor: {tensor_name} in the config yaml for"
" FakeQuant feature which is a required field"
)
if config["quant_format"] not in self._supported_formats():
raise ValueError(
f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor:"
f" quant_format: {config['quant_format']} for Tensor: {tensor_name} in the config"
" yaml for FakeQuant feature is not supported"
)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=process_tensor: {gemm}, {tensor_name}",
layer_name,
extra_cachable_args=(gemm, tensor_name),
)
quant_format = config["quant_format"]
q_tensor = fake_quantize(tensor, quant_format, out=out)
if dtype is not None and q_tensor is not None:
q_tensor = q_tensor.to(dtype)
return q_tensor
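# A hedged sketch (not part of Transformer Engine) of a second modify_tensor()
# feature in the same style as FakeQuant: a hypothetical ClipTensor feature that
# clamps selected tensors to [-max_abs, max_abs] and keeps the GEMM in high
# precision. The class name and the "max_abs" config key are assumptions.
@Registry.register_feature(namespace="transformer_engine")
class ClipTensor(TEConfigAPIMapper):
    """Illustrative feature: clamps chosen tensors instead of quantizing them."""

    @api_method
    def fp8_gemm_enabled(
        self, config, layer_name: str, gemm: str, iteration: int
    ):  # pylint: disable=unused-argument
        """Keep the GEMM in high precision, since the clamped tensor is not quantized."""
        return False

    @api_method
    def modify_tensor_enabled(
        self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
    ):  # pylint: disable=unused-argument
        """Run modify_tensor() for every configured tensor."""
        return True

    @api_method
    def modify_tensor(
        self,
        config,
        layer_name: str,
        gemm: str,
        tensor_name: str,
        tensor: torch.Tensor,
        iteration: int,
        default_quantizer: Quantizer,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ):  # pylint: disable=unused-argument
        """Clamp the tensor; both GEMM inputs stay in high precision."""
        clipped = tensor.clamp(-config["max_abs"], config["max_abs"])
        if out is not None:
            out.copy_(clipped)
            return None
        return clipped.to(dtype) if dtype is not None else clipped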
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LogFp8TensorStats Feature support for nvidia-dlframework-inspect"""
from typing import Dict, Union
import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.debug.pytorch.debug_state import TEDebugState
@Registry.register_feature(namespace="transformer_engine")
class LogFp8TensorStats(BaseLogTensorStats):
"""
This feature handles logging of FP8 tensor stats.
In a distributed setting, the auxiliary stats are computed on each rank and gathered after
the `debug_api.step()` call. Do not forget to invoke `debug_api.step()` at every step to log
stats!
`LogFp8TensorStats` supports micro-batching. If multiple forward/backward passes are invoked
per `debug_api.step()`, then stats for all tensors except weights will be accumulated.
`LogFp8TensorStats` can induce significant overhead. To mitigate this issue, logging stats
with `freq > 1` is recommended. If `LogFp8TensorStats` is not used in a given step, the
overhead is smaller. If no other feature is used for the layer, the TE layer will
run as fast as it would without `debug_api` initialized.
Parameters
----------
stats: List[str]
list of statistics to log
- underflows% - percentage of elements of the tensor equal to 0,
tensors/tensors_struct: List[str]
list of tensors to log
- activation
- gradient
- weight
freq: Optional[int], default = 1
frequency of logging stats, stats will be logged every `freq` steps
start_step: Optional[int], default = None
start step of logging stats
end_step: Optional[int], default = None
end step of logging stats
start_end_list: Optional[list([int, int])], default = None
non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step
Example
-------
.. code-block:: yaml
example_fp8_tensor_stat_collection:
enabled: True
layers:
layer_types: [layernorm_linear]
transformer_engine:
LogFp8TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [underflows%]
freq: 1
- tensor: gradient
stats: [underflows%]
freq: 5
start_step: 0
end_step: 80
"""
def _get_supported_stats_list(self):
"""Returns stats this feature can log."""
return {"underflows%"}
@api_method
def inspect_tensor_postquantize_enabled(
self, config: Dict, layer_name: str, gemm: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run inspect_tensor_postquantize() in the forward."""
# check whether logging should happen in this iteration
return self._check_params(config, layer_name, iteration=iteration)
@api_method
def inspect_tensor_postquantize(
self,
config: Dict,
layer_name: str,
tensor_name: str,
tensor: Union[torch.Tensor, QuantizedTensor],
rowwise: bool,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
):
"""
API call used to collect the data about the tensor after process_tensor()/quantization.
"""
assert type(tensor) in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase], (
f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be a quantized tensor when using"
" log_fp8_tensor_stats. Use log_tensor_stats for high precision tensors."
)
# This API can be invoked twice - with the tensor and with the transpose.
# We want to collect the stats once.
if not rowwise:
return # tensor was already seen rowwise in the other gemm
tensor = tensor._data
options = (
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
"fp8",
)
skip_reduction = False
reduction_group = debug_api.get_tensor_reduction_group()
reduce_within_microbatch = tensor_name != "weight"
if tensor_name == "weight":
if TEDebugState.weight_tensor_tp_group_reduce:
reduction_group = tp_group
else:
skip_reduction = True
for stat in config["stats"]:
assert (
stat in self._get_supported_stats_list()
), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported."
STATS_BUFFERS.try_add_buffer(
layer_name=layer_name,
tensor_name=tensor_name,
stats=config["stats"],
options=options,
reduction_group=reduction_group,
reduce_within_microbatch=reduce_within_microbatch,
)
STATS_BUFFERS.feed(layer_name, tensor_name, options, tensor, iteration, skip_reduction)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=inspect_tensor_postquantize: {tensor_name}",
layer_name,
extra_cachable_args=(tensor_name,),
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LogTensorStats Feature support for nvidia-dlframework-inspect"""
from typing import Dict, Union
import torch
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
import nvdlfw_inspect.api as debug_api
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
@Registry.register_feature(namespace="transformer_engine")
class LogTensorStats(BaseLogTensorStats):
"""
This feature handles the logging of basic tensor statistics.
For a distributed setting, the auxiliary stats are computed for each node and gathered after the `debug_api.step()` call. Do not forget to invoke `debug_api.step()` at every step to log stats!
`LogTensorStats` supports micro-batching. If multiple forward/backward passes are invoked per `debug_api.step()`, then stats for all tensors except weights will be accumulated.
`LogTensorStats` can induce significant overhead. To mitigate this issue, logging stats with `freq > 1` is recommended. If `LogTensorStats` is not used in a given step, the overhead is smaller. Moreover, if no other feature is used for the layer, the TE layer will run as fast as it would without `debug_api` initialized.
Parameters
----------
stats: List[str]
list of statistics to log
- min
- max
- mean
- std
- l1_norm
- l2_norm
- cur_amax – maximal absolute value of a tensor,
- dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)`
tensors/tensors_struct: List[str]
list of tensors to log
- activation
- gradient
- weight
- output
- wgrad
- dgrad
freq: Optional[int], default = 1
frequency of logging stats, stats will be logged every `freq` steps
start_step: Optional[int], default = None
start step of logging stats
end_step: Optional[int], default = None
end step of logging stats
start_end_list: Optional[list([int, int])], default = None
non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step
Example
-------
.. code-block:: yaml
example_tensor_stat_collection:
enabled: True
layers:
layer_name_regex_pattern: .*(fc1|self_attention).*
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [mean]
freq: 10
start_step: 5
end_step: 100
- tensor: gradient
stats: [mean, max, min]
freq: 2
start_end_list: [[0, 20], [80, 100]]
- tensor: weight
stats: [dynamic_range]
"""
def _get_supported_stats_list(self):
"""Returns stats this feature can log."""
return BaseLogTensorStats._get_supported_stats_list(None) | {"cur_amax", "dynamic_range"}
@api_method
def inspect_tensor_enabled(
self, config: Dict, layer_name: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run look_at_tensor_before_process() in the forward."""
return self._check_params(config, layer_name, iteration=iteration)
@api_method
def inspect_tensor(
self,
config: Dict,
layer_name: str,
tensor_name: str,
tensor: Union[torch.Tensor, QuantizedTensor],
iteration: int,
tp_group: torch.distributed.ProcessGroup,
):
"""API call used to collect the data about the tensor before process_tensor()/quantization."""
assert (
type(tensor) not in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase]
and tensor.dtype != torch.uint8
), (
f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be in high precision when using"
" log_tensor_stats. Use log_fp8_tensor_stats for FP8 tensors."
)
options = (
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
)
skip_reduction = False
reduction_group = debug_api.get_tensor_reduction_group()
reduce_within_microbatch = tensor_name != "weight"
if tensor_name == "weight":
if TEDebugState.weight_tensor_tp_group_reduce:
reduction_group = tp_group
else:
skip_reduction = True
for stat in config["stats"]:
assert (
stat in self._get_supported_stats_list()
), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported."
STATS_BUFFERS.try_add_buffer(
layer_name=layer_name,
tensor_name=tensor_name,
stats=config["stats"],
options=options,
reduction_group=reduction_group,
reduce_within_microbatch=reduce_within_microbatch,
)
STATS_BUFFERS.feed(layer_name, tensor_name, options, tensor, iteration, skip_reduction)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=look_at_tensor_before_process: {tensor_name}",
layer_name,
extra_cachable_args=(tensor_name,),
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""PerTensorScaling Feature support for nvidia-dlframework-inspect"""
from typing import Optional
import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
Float8CurrentScalingQuantizer,
)
from transformer_engine.debug.features.api import TEConfigAPIMapper
def per_tensor_cast(
tensor: torch.Tensor, fp8_dtype: tex.DType, out: Float8Tensor = None
) -> Float8Tensor:
"""
This function computes the scaling factor based on the tensor amax and then casts the tensor to FP8.
"""
assert tensor.dtype in (
torch.float,
torch.float16,
torch.bfloat16,
), "[NVTORCH INSPECT ERROR] Unsupported tensor type for per tensor current scaling"
assert tensor.is_cuda, "[NVTORCH INSPECT ERROR] Must be a GPU tensor."
assert fp8_dtype in {
tex.DType.kFloat8E4M3,
tex.DType.kFloat8E5M2,
}, "[NVTORCH INSPECT ERROR] Only 2 FP8 types: E4M3 and E5M2 are supported in TE."
tensor = tensor.contiguous()
quantizer = Float8CurrentScalingQuantizer(fp8_dtype)
if out is not None:
quantizer.update_quantized(tensor, out)
return None
return quantizer(tensor)
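# Hedged usage sketch of per_tensor_cast() (not part of the feature): requires a CUDA
# device; the tensor shape and dtype below are arbitrary assumptions.
def _example_per_tensor_cast():
    """Cast a random bfloat16 tensor to FP8 using a per-tensor current scale."""
    x = torch.randn(64, 64, device="cuda", dtype=torch.bfloat16)
    x_fp8 = per_tensor_cast(x, tex.DType.kFloat8E4M3)
    # The Float8Tensor keeps the original shape; dequantize() recovers a
    # high-precision approximation of x.
    return x_fp8.dequantize()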
@Registry.register_feature(namespace="transformer_engine")
class PerTensorScaling(TEConfigAPIMapper):
"""
Allows using per-tensor current scaling for specific tensors.
Can be used only within `DelayedScaling` recipe autocast.
Parameters
----------
gemms/gemms_struct: List[str]
list of gemms to enable per-tensor current scaling for
- fprop
- dgrad
- wgrad
tensors/tensors_struct: List[str]
list of tensors to enable per-tensor current scaling for
- activation
- gradient
- weight
Example
-------
.. code-block:: yaml
example_per_tensor_scaling:
enabled: True
layers:
layer_types: [transformer_layer.self_attn.layernorm_q]
transformer_engine:
PerTensorScaling:
enabled: True
gemms: [dgrad]
tensors: [weight, activation]
"""
@api_method
def fp8_gemm(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
return False
@api_method
def modify_tensor_enabled(
self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run process_tensor() in the forward."""
return True
@api_method
def modify_tensor(
self,
config,
layer_name: str,
gemm: str,
tensor_name: str,
tensor: torch.Tensor,
iteration: int,
default_quantizer: Quantizer,
out: Optional[Float8Tensor] = None,
dtype: Optional[torch.dtype] = None,
): # pylint: disable=unused-argument
"""API call used to process the tensor."""
for key in config.keys():
if key not in ["gemm", "tensor"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
assert isinstance(default_quantizer, Float8CurrentScalingQuantizer), (
f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor: "
"Per-tensor current scaling can be used only within `DelayedScaling` recipe autocast."
f" {layer_name}"
)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=process_tensor: {gemm}, {tensor_name}",
layer_name,
extra_cachable_args=(gemm, tensor_name),
)
fp8_tensor = per_tensor_cast(tensor, default_quantizer.dtype, out=out)
return fp8_tensor
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Utils for the debug features.
"""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Buffer used for LogTensorStats and LogFp8TensorStats features.
Buffers are fed with tensors; they compute the necessary stats and save them.
When log() is called, they gather the stats from all nodes, compute the combined final stats, and log them.
"""
from collections import defaultdict
import torch
from nvdlfw_inspect.utils import gather_along_first_dim
from nvdlfw_inspect.logging import MetricLogger
from transformer_engine.debug.features.utils.stats_computation import (
STATS,
DEPENDENCIES,
stats_to_num,
)
class _Buffer:
"""
A buffer stores temporary statistics for one tensor of one layer.
It can also synchronize them between nodes and log the final stats.
"""
def __init__(self, layer_name, tensor_name, stats, reduction_group, reduce_within_microbatch):
self.layer_name = layer_name
self.tensor_name = tensor_name
self.reduction_group = reduction_group
self.reduce_within_microbatch = reduce_within_microbatch
self.stats_to_log = stats
self.stats_to_compute = set()
for stat in stats:
self.stats_to_compute = self.stats_to_compute | DEPENDENCIES[stat]
self._buffer = torch.zeros(len(STATS), dtype=torch.float32).cuda()
self._new_buffer = self._buffer.clone()
self._tmp_buffer = self._buffer.clone()
# In the case of data parallelism, the layer may not be run on some nodes.
# modified is set to True once the buffer is fed on this node.
# Nodes that were not run are not taken into account.
self.modified = torch.tensor([False], dtype=torch.bool).cuda()
self.iteration = None
self.skip_reduction = False
def _reset_before_next_step(self):
"""Resets the state after the logging."""
self.modified[0] = False
def _gather_helper_stats(self) -> torch.Tensor:
"""
If the tensor stats should be accumulated across many nodes,
this method gathers the stats from the nodes where the buffer was modified.
"""
if self.skip_reduction:
return self._buffer.unsqueeze(0)
mask = gather_along_first_dim(self.modified, process_group=self.reduction_group)[0]
gathered_buffer, _ = gather_along_first_dim(
self._buffer.unsqueeze(0), process_group=self.reduction_group
)
return gathered_buffer[mask.to(bool)]
def feed(self, tensor, iteration):
"""
feed() is used to add a tensor for computing the statistics.
Because of microbatching, feed() can be called multiple
times for one log().
The main reason for this design is the need to combine results for already processed
tensors with the result of the new tensor.
"""
self.iteration = iteration
# If the stats are not reduced within the microbatch and the
# buffer was already fed, we do not change the stats for the tensor.
# This is used for weights with microbatching.
if self.modified[0] and not self.reduce_within_microbatch:
return
# save stats for tensor to tmp buffer
for stat_name in self.stats_to_compute:
fn, _ = STATS[stat_name]
self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor)
# [num_buffers, num_stats]
buffers = torch.cat((self._buffer.unsqueeze(0), self._tmp_buffer.unsqueeze(0)), dim=0)
for stat_name in self.stats_to_compute:
fn, combinator = STATS[stat_name]
if self.modified[0]:
self._new_buffer[stats_to_num[stat_name]] = combinator(buffers)
else:
fn = STATS[stat_name][0]
self._new_buffer[stats_to_num[stat_name]] = fn(tensor)
self._buffer.copy_(self._new_buffer)
self.modified[0] = True
def log(self):
"""
Logs the tensor stats and resets the buffer.
"""
# [num_active_nodes, num_stats]
gathered_helper_stats = self._gather_helper_stats()
if not self.modified[0]:
return {}
output = {}
for stat_name in self.stats_to_log:
combiner = STATS[stat_name][1]
stat_value = combiner(gathered_helper_stats)
MetricLogger.log_scalar(
f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration
)
output[(self.layer_name, self.tensor_name, stat_name, self.iteration)] = (
stat_value # for debugging purposes
)
self._reset_before_next_step()
return output
class StatsBuffers:
"""
The StatsBuffers class holds the statistics buffers for all tensors.
It is used to feed the tensors into the correct buffers.
"""
def __init__(self):
self.buffers = {} # (layer_name, tensor_name) -> buffer
self.reduction_group_to_buffer = defaultdict(list)
def reset(self):
"""Resets all buffers."""
self.buffers = {} # (layer_name, tensor_name) -> buffer
self.reduction_group_to_buffer = defaultdict(list)
def try_add_buffer(
self, layer_name, tensor_name, stats, options, reduction_group, reduce_within_microbatch
):
"""If buffer for such combination of stats/tensor_name/... is not present, this method creates it."""
if (layer_name, tensor_name, options) in self.buffers:
return
buffer = _Buffer(layer_name, tensor_name, stats, reduction_group, reduce_within_microbatch)
self.buffers[(layer_name, tensor_name, options)] = buffer
self.reduction_group_to_buffer[reduction_group].append(buffer)
def feed(self, layer_name, tensor_name, options, tensor, iteration, skip_reduction):
"""Feeds the tensor into the respective buffer."""
buffer = self.buffers[(layer_name, tensor_name, options)]
buffer.feed(tensor, iteration)
buffer.skip_reduction = skip_reduction
def log_stats(self):
"""Logs the stats from all the buffers."""
output = {}
for reduction_group, buffers in self.reduction_group_to_buffer.items():
changed_buffers = [
(i, buffer)
for i, buffer in enumerate(buffers)
if gather_along_first_dim(
buffer.modified.unsqueeze(0), process_group=reduction_group
)[0].any()
]
for _, buffer in changed_buffers:
stats = buffer.log()
output.update(stats)
return output
STATS_BUFFERS = StatsBuffers()
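# Hedged sketch (not part of the module) of how a single _Buffer accumulates the
# "mean" statistic across two micro-batches; the layer and tensor names are
# hypothetical. Requires a CUDA device; skip_reduction avoids any distributed gather.
def _example_microbatch_accumulation():
    """Feed two micro-batches and read back the combined mean."""
    buf = _Buffer("layer0", "activation", ["mean"], None, True)
    buf.skip_reduction = True
    buf.feed(torch.ones(4, device="cuda"), iteration=0)
    buf.feed(3 * torch.ones(4, device="cuda"), iteration=0)
    # "mean" is combined as sum(sums) / sum(numels): (4 + 12) / (4 + 4) == 2.0
    return STATS["mean"][1](buf._gather_helper_stats())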
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Mathematical functions used for tensor statistics computation.
"""
import math
import torch
MAX_FP8_VALUE_INT8 = 126
@torch.compile
def _compute_dynamic_range_top(tensor):
"""Computes the log2 of the amax of the tensor"""
tensor_abs = tensor.abs()
tensor_abs = tensor_abs[tensor_abs != 0]
if tensor_abs.any():
    amax = tensor_abs.max().float()
else:
    amax = torch.tensor(1, device=tensor.device).to(torch.float)
return torch.log2(amax)
def _compute_dynamic_range_bottom(tensor):
"""Computes the log2 of the amin of the tensor"""
tensor_abs = tensor.abs()
tensor_abs = tensor_abs[tensor_abs != 0]
if tensor_abs.any():
amin = tensor_abs.min().float()
else:
amin = torch.tensor(1, device=tensor.device).to(torch.float)
return torch.log2(amin)
def compute_variance(variances, numels, sums):
"""Welford algorithm is used for numerically stable distributed variance computation."""
mean = torch.sum(sums) / torch.sum(numels)
means = sums / numels
var = torch.sum(numels * (variances + torch.pow((means - mean), 2))) / torch.sum(numels)
return var
def compute_std(variances, numels, sums):
"""Computates standard deviation."""
return torch.sqrt(compute_variance(variances, numels, sums))
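# Hedged sanity check (not part of the module) of the combination above: merging two
# chunks' (variance, numel, sum) triples with compute_variance() reproduces the
# population variance of the concatenated tensor. Per-chunk variances use
# unbiased=False so the identity is exact; STATS["variance"] feeds torch.var's
# default unbiased estimate, which is a close approximation for large tensors.
def _example_variance_combination():
    """Combine per-chunk statistics and compare them to the global variance."""
    a = torch.tensor([0.0, 1.0])
    b = torch.tensor([2.0, 3.0, 4.0])
    variances = torch.stack([a.var(unbiased=False), b.var(unbiased=False)])
    numels = torch.tensor([a.numel(), b.numel()], dtype=torch.float32)
    sums = torch.stack([a.sum(), b.sum()])
    combined = compute_variance(variances, numels, sums)
    assert torch.allclose(combined, torch.cat([a, b]).var(unbiased=False))  # both equal 2.0
    return combined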
# buffers is tensor of shape [nr_buffers, nr_stats]
def _get(buffers, stat_name):
stat_nr = stats_to_num[stat_name]
return buffers[:, stat_nr]
stats_to_num = {
"min": 0,
"max": 1,
"sum": 2,
"mean": 3,
"numel": 4,
"l1_norm": 5,
"l2_norm_square": 6,
"l2_norm": 7,
"variance": 8,
"cur_amax": 9,
"dynamic_range_top": 10,
"dynamic_range_bottom": 11,
"underflows_num": 12,
"std": 13,
"dynamic_range": 14,
"underflows%": 15,
}
DEPENDENCIES = {
"min": {"min"},
"max": {"max"},
"sum": {"sum"},
"mean": {"sum", "numel"},
"numel": {"numel"},
"l1_norm": {"l1_norm"},
"l2_norm_square": {"l2_norm_square", "numel"},
"l2_norm": {"l2_norm_square"},
"variance": {"variance", "numel", "sum"},
"cur_amax": {"cur_amax"},
"dynamic_range_top": {"dynamic_range_top"},
"dynamic_range_bottom": {"dynamic_range_bottom"},
"underflows_num": {"underflows_num"},
"std": {"variance", "numel", "sum"},
"dynamic_range": {"dynamic_range_top", "dynamic_range_bottom"},
"underflows%": {"underflows_num", "numel"},
}
STATS = {
"min": (torch.min, lambda buffers: min(_get(buffers, "min"))),
"max": (torch.max, lambda buffers: max(_get(buffers, "max"))),
"sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))),
"mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))),
"numel": (lambda x: x.numel(), lambda buffers: sum(_get(buffers, "numel"))),
"l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))),
"l2_norm_square": (
lambda x: torch.sum(x**2),
lambda buffers: sum(_get(buffers, "l2_norm_square")),
),
"l2_norm": (
lambda x: torch.norm(x, p=2),
lambda buffers: math.sqrt(sum(_get(buffers, "l2_norm_square"))),
),
"variance": (
torch.var,
lambda buffers: compute_variance(
_get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum")
),
),
"cur_amax": (lambda x: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))),
"dynamic_range_top": (
_compute_dynamic_range_top,
lambda buffers: max(_get(buffers, "dynamic_range_top")),
),
"dynamic_range_bottom": (
_compute_dynamic_range_bottom,
lambda buffers: min(_get(buffers, "dynamic_range_bottom")),
),
"underflows_num": (
lambda x: (x == 0).sum(),
lambda buffers: sum(_get(buffers, "underflows_num")),
),
"std": (
torch.std,
lambda buffers: compute_std(
_get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum")
),
),
"dynamic_range": (
lambda x: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x),
lambda buffers: max(_get(buffers, "dynamic_range_top"))
- min(_get(buffers, "dynamic_range_bottom")),
),
"underflows%": (
lambda x: (x == 0).sum() / x.numel() * 100,
lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")),
),
}
......@@ -20,66 +20,17 @@ All operations are designed to work seamlessly with JAX's functional programming
model and support automatic differentiation.
"""
# pylint: disable=wrong-import-position,wrong-import-order
# pylint: disable=wrong-import-position
import logging
import importlib
import importlib.util
from importlib.metadata import version
import sys
# This unused import is needed because the top level `transformer_engine/__init__.py`
# file catches an `ImportError` as a guard for cases where the given framework's
# extensions are not available.
import jax
from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension
from transformer_engine.common import load_framework_extension
load_framework_extension("jax")
def _load_library():
"""Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_jax"
if is_package_installed(module_name):
assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
assert is_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
"'pip3 install transformer-engine[jax]==VERSION'"
)
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
logging.info(
"Could not find package %s. Install transformer-engine using "
"'pip3 install transformer-engine[jax]==VERSION'",
module_name,
)
extension = _get_sys_extension()
try:
so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
try:
so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration:
so_dir = get_te_path()
so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)
_load_library()
from . import flax
from . import quantize
......
......@@ -89,8 +89,7 @@ class ActLuPrimitive(BasePrimitive):
6,
7,
8,
9,
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, scale_shapes, is_outer
) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer
inner_primitive = None
outer_primitive = None
......@@ -105,13 +104,12 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
te_act_lu_p abstract
"""
del act_enum, scale_shapes
del act_enum
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
......@@ -120,6 +118,11 @@ class ActLuPrimitive(BasePrimitive):
f" {x_aval.shape} and act_len {act_len}"
)
assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
"Current tensor scaling is not yet supported for fused activation and quantization."
" Please do activation in higher-precision then quantize with current tensor scaling."
)
out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
......@@ -151,13 +154,12 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
te_gated_act_lu_p lowering rules
"""
del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
del out_dtype, scale_dtype, act_len, is_outer
x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
......@@ -177,7 +179,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
......@@ -196,7 +197,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=False,
)
)
......@@ -225,7 +225,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
......@@ -248,7 +247,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
),
out_bdims,
)
......@@ -261,7 +259,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
......@@ -272,7 +269,6 @@ class ActLuPrimitive(BasePrimitive):
result_infos,
act_enum,
scale_dtype,
scale_shapes,
act_len,
is_outer,
) # Unused.
......@@ -326,7 +322,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
......@@ -387,7 +382,6 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=True,
)
)
......@@ -415,17 +409,16 @@ class ActLuPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
value_types,
result_types,
):
del out_dtype, act_enum, act_len, scale_dtype, scale_shapes, is_outer, mesh, result_types
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
x_rank = len(value_types[0].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank - 1, unique_var="i", flatten_axis=-2
x_rank - 1, unique_var="ActLuPrimitive_i", flatten_axis=-2
)
x_axes = scale_rules.input_spec + (f"x{x_rank-1}",)
out = (*x_axes[:-2], x_axes[-1])
......@@ -467,8 +460,8 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
name = "te_dact_dbias_quantize_ffi"
multiple_results = True
# out_dtype, scaling_mode, is_2x, scale_dtype, scale_shapes, is_dbias, act_enum, act_len, is_outer
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
# out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer
impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10)
inner_primitive = None
outer_primitive = None
......@@ -482,7 +475,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -491,7 +483,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
"""
te_dact_dbias_quantize_p abstract
"""
del act_enum, scale_shapes
del act_enum
dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_dtype
......@@ -500,9 +492,18 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
f" {x_aval.shape} and act_len {act_len}"
)
assert scale_aval.dtype == jnp.float32
assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
"Current tensor scaling is not supported for fused dact and quantization. Please do"
" dact in higher-precision then quantize with current tensor scaling."
)
ir_hidden_size = dz_aval.shape[-1]
gi_hidden_size = act_len * x_aval.shape[-1]
assert act_len * ir_hidden_size == gi_hidden_size
assert (
x_aval.shape[:-2] == dz_aval.shape[:-1]
), "dz and x should have the same leading dimensions"
out_shape = x_aval.shape
out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
......@@ -512,7 +513,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2)
if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2)
else:
colwise_out_shape = out_shape
......@@ -575,7 +576,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -584,7 +584,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
"""
te_dact_dbias_quantize_p lowering rules
"""
del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
del out_dtype, scale_dtype, act_len, is_outer
dz_aval, x_aval, scale_aval = ctx.avals_in
assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert x_aval.dtype == dz_aval.dtype
......@@ -609,7 +609,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -629,7 +628,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
......@@ -658,7 +656,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -690,7 +687,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
......@@ -704,7 +700,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -714,10 +709,14 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
result_infos,
):
del out_dtype, result_infos, act_enum
del scale_dtype, scale_shapes, act_len, is_outer
del scale_dtype, act_len, is_outer
x_spec = get_padded_spec(arg_infos[1])
scale_spec = get_padded_spec(arg_infos[2])
assert (
scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value
), "Partitioned current tensor scaling is not yet supported."
out_sharding = NamedSharding(
mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
)
......@@ -774,7 +773,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -826,8 +824,12 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh, PartitionSpec(*colwise_scale_inv_spec), desc="ActLuPrimitive.colwise_scale_inv"
)
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
# Ensure dz and x are partitioned the same way.
arg_shardings[0] = NamedSharding(
mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]), desc="DActLuDBiasQuantizePrimitive.dz"
)
arg_shardings = tuple(arg_shardings)
out_shardings = (
out_sharding,
colwise_out_sharding,
......@@ -847,7 +849,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
act_enum=act_enum,
act_len=act_len,
......@@ -874,7 +875,6 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_dbias,
act_enum,
act_len,
......@@ -883,11 +883,11 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
value_types,
result_types,
):
del out_dtype, scale_dtype, scale_shapes, act_enum, act_len, is_outer, mesh, result_types
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
x_rank = len(value_types[1].shape)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank, unique_var="i", flatten_axis=-2
x_rank, unique_var="DActLuDbiasQuantizePrimitive_i", flatten_axis=-2
)
x_axes = scale_rules.input_spec
out = x_axes
......@@ -1020,12 +1020,21 @@ def act_lu(
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((), ()),
is_outer=True,
)
out = out.reshape(output_shape)
return out
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform the activation in higher precision, then quantize after.
out = act_lu(
x=x.astype(jnp.float32),
activation_type=activation_type,
quantizer=None,
)
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
return out
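For context, a minimal JAX sketch of the distinction behind this fallback: current scaling derives the scale from the tensor being quantized right now, while delayed scaling reuses a scale built from amax values recorded in earlier steps. This is a sketch only; 448.0 is the float8_e4m3 maximum, and the clamp and helper name are illustrative, not TE API.
import jax.numpy as jnp
def current_vs_delayed_scale(x, amax_history, fp8_max=448.0):
    # Current tensor scaling: scale derives from this tensor's amax, computed now.
    scale_current = fp8_max / jnp.maximum(jnp.max(jnp.abs(x)), 1e-12)
    # Delayed tensor scaling: scale derives from amax recorded in prior steps.
    scale_delayed = fp8_max / jnp.maximum(jnp.max(amax_history), 1e-12)
    return scale_current, scale_delayed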
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
......@@ -1044,8 +1053,6 @@ def act_lu(
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
# output does not have act axis
scale_shapes=quantizer.get_scale_shapes(output_shape, flatten_axis=-1),
is_outer=True,
)
......@@ -1101,8 +1108,12 @@ def quantize_dact_dbias(
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
out = dact_lu(dz, x, activation_type, quantizer=None)
return _quantize_dbias_impl(out, quantizer, is_dbias=True, flatten_axis=-2)
out = dact_lu(
dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None
)
return _quantize_dbias_impl(
out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2
)
is_gated = act_len == 2
# TE/common does not support DelayedScaling2x for gated-act yet
......@@ -1134,7 +1145,6 @@ def quantize_dact_dbias(
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False, # unused
scale_dtype=jnp.float32, # unused
scale_shapes=((), ()), # unused
is_dbias=False,
act_enum=act_type_id,
act_len=act_len,
......@@ -1145,6 +1155,19 @@ def quantize_dact_dbias(
dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2)
return output.astype(x.dtype), dbias
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform dact in higher precision, then quantize after.
out = dact_lu(
dz=dz.astype(jnp.float32),
x=x.astype(jnp.float32),
activation_type=activation_type,
quantizer=None,
)
out, dbias = _quantize_dbias_impl(
out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2
)
return out, dbias
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
......@@ -1158,8 +1181,6 @@ def quantize_dact_dbias(
)
return out, dbias
out_shape = x.shape
(
rowwise_casted_output,
colwise_casted_output,
......@@ -1175,8 +1196,6 @@ def quantize_dact_dbias(
scaling_mode=quantizer.scaling_mode.value,
is_2x=quantizer.is_2x2x(),
scale_dtype=quantizer.get_scale_dtype(),
# output has act axis
scale_shapes=quantizer.get_scale_shapes(out_shape, flatten_axis=-2),
is_dbias=is_dbias,
act_enum=act_type_id,
act_len=act_len,
......@@ -1184,7 +1203,7 @@ def quantize_dact_dbias(
)
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
......
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""JAX te modules"""
from typing import Tuple, Sequence, Union, Dict, List
from typing import Tuple, Sequence, Union, Dict
from functools import partial, reduce
import operator
import jax
......@@ -21,7 +21,7 @@ from ..quantize import (
)
__all__ = ["gemm", "grouped_gemm"]
__all__ = ["gemm"]
num_cublas_streams = 4
......@@ -155,7 +155,7 @@ def _dequantize(x, scale_inv, dq_dtype):
4,
),
)
def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
def __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)
......@@ -173,13 +173,11 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
return out_3d
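The hard-coded dequantize above exists so XLA can pattern-match the multiply-then-dot sequence into a single FP8 GEMM. A hedged, self-contained sketch of that pattern follows; the argument names and the bfloat16 compute dtype are assumptions, not the exact TE code.
import jax
import jax.numpy as jnp
def fp8_dequant_dot_sketch(lhs_fp8, lhs_scale_inv, rhs_fp8, rhs_scale_inv, dim_nums):
    # Explicit dequantize: upcast the FP8 data, then multiply by scale_inv.
    lhs_dq = lhs_fp8.astype(jnp.bfloat16) * lhs_scale_inv.astype(jnp.bfloat16)
    rhs_dq = rhs_fp8.astype(jnp.bfloat16) * rhs_scale_inv.astype(jnp.bfloat16)
    # XLA can rewrite this dequantize -> dot_general chain into one FP8 GEMM.
    return jax.lax.dot_general(lhs_dq, rhs_dq, dim_nums,
                               precision=jax.lax.Precision.DEFAULT)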
def _jax_gemm_delayed_scaling_fp8(
def _jax_gemm_tensor_scaling_fp8(
lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
"""FP8 GEMM for XLA pattern match"""
assert (
rhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
), "rhs does not have delayed tensor scaling mode"
assert rhs.scaling_mode.is_tensor_scaling(), "rhs does not have tensor scaling mode"
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
if lhs.data_layout == "T":
......@@ -196,7 +194,7 @@ def _jax_gemm_delayed_scaling_fp8(
precision = (
jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
)
out_3d = __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
out_3d = __jitted_jax_gemm_tensor_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)
# Reshape [B, M, N] -> [..., M, N]
out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
......@@ -271,8 +269,8 @@ def _jax_gemm(
def _jax_gemm_fp8_impl(lhs, rhs):
if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode.is_tensor_scaling():
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums)
if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING:
return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)
......@@ -340,8 +338,9 @@ def gemm(
return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set)
"""
def swizzled_scale(scales):
"""Swizzle the scale tensor for FP8 GEMM"""
# Swizzle the scale tensor for FP8 GEMM
assert scales.ndim == 2
rows, cols = scales.shape
scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
......@@ -356,7 +355,7 @@ def grouped_gemm(
contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]],
bias_list: List[jnp.ndarray] = None,
) -> List[jnp.ndarray]:
"""Grouped GEMM for multiple pairs of tensors."""
# Grouped GEMM for multiple pairs of tensors.
assert (
len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
), "lhs_list, rhs_list, contracting_dims_list must have the same length"
......@@ -378,7 +377,7 @@ def grouped_gemm(
rhs_shape = rhs.data.shape
out_dtype = lhs.dq_dtype
# For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal data_layout
if lhs.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if lhs.scaling_mode.is_tensor_scaling():
assert not (
lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
), "FP8 GEMM does not support E5M2 * E5M2"
......@@ -406,7 +405,7 @@ def grouped_gemm(
if scaling_mode == ScalingMode.NO_SCALING:
lhs_3d = _shape_normalization(lhs, lhs_dn)
rhs_3d = _shape_normalization(rhs, rhs_dn)
elif scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
elif scaling_mode.is_tensor_scaling():
lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.data_layout == "N")
rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.data_layout == "T")
elif scaling_mode == ScalingMode.MXFP8_1D_SCALING:
......@@ -443,7 +442,7 @@ def grouped_gemm(
if scaling_mode == ScalingMode.NO_SCALING:
lhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
rhs_sinv_list_.append(jnp.ones(1, dtype=jnp.float32))
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if scaling_mode.is_tensor_scaling():
lhs_sinv_list_.append(lhs.scale_inv)
rhs_sinv_list_.append(rhs.scale_inv)
if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
......@@ -465,3 +464,4 @@ def grouped_gemm(
)
return out_list
"""
......@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type
import transformer_engine_jax
from ..sharding import get_padded_spec as te_get_padded_spec
from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeLayout
from ..quantize import ScaledTensorFactory, QuantizeLayout
TEDType = transformer_engine_jax.DType
......@@ -215,9 +215,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1,
@return: the output of 'f' with the colwise output calculated
"""
should_apply_war = (
quantizer is not None
and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING
and quantizer.is_2x2x()
quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x()
)
if not should_apply_war:
return None
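When the WAR applies, the fused kernel is asked for only the rowwise (1x) output and the colwise copy is built with a plain JAX transpose; this is valid for tensor scaling because the per-tensor scale is shared between both layouts. A minimal sketch of the transpose step, mirroring the one used by layernorm_fwd/rmsnorm_fwd later in this diff:
import jax.numpy as jnp
def colwise_from_rowwise(rowwise_data):
    # Move the last (hidden) axis to the front: (..., H) -> (H, ...).
    return jnp.transpose(rowwise_data, (-1, *range(rowwise_data.ndim - 1)))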
......
......@@ -26,7 +26,9 @@ from .misc import (
jax_dtype_to_te_dtype,
te_dtype_to_jax_dtype,
NamedSharding,
get_cudnn_version,
)
from .quantization import _quantize_dbias_impl
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory
from ..quantize import (
......@@ -85,6 +87,10 @@ def is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode: ScalingMode) -> bo
return int(os.getenv("NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE", "0")) == 1
# CuDNN version must be at least this to use MXFP8 fused normalization; otherwise an unfused norm followed by quantize is used
FUSED_MXFP8_NORM_CUDNN_MIN_VERSION = (9, 10, 0)
class NormFwdPrimitive(BasePrimitive):
"""
Layer Normalization Forward FP8 Primitive
......@@ -92,7 +98,7 @@ class NormFwdPrimitive(BasePrimitive):
name = "te_norm_forward_ffi"
multiple_results = True
impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12)
impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11)
inner_primitive = None
outer_primitive = None
......@@ -110,22 +116,37 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
LayerNorm fwd inner primitive abstract
"""
del scale_shapes
x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval is None or scale_aval.dtype == jnp.float32
assert (
scaling_mode != ScalingMode.MXFP8_1D_SCALING.value
or get_cudnn_version() >= FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
), (
"MXFP8 Fused Normalization is only supported in CuDNN version"
f" {FUSED_MXFP8_NORM_CUDNN_MIN_VERSION} or higher"
)
assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, (
"Current tensor scaling is not supported for fused norm and quantization. Please do"
" norm in higher-precision then quantize with current tensor scaling."
)
mu_rsigma_dtype = jnp.float32
if norm_type == NVTE_Norm_Type.LayerNorm:
assert gamma_aval.size == beta_aval.size
assert gamma_aval.dtype == beta_aval.dtype, (
f"gamma and beta should have the same dtype, but got {gamma_aval.dtype} and "
f"{beta_aval.dtype}"
)
out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigma_dtype)
......@@ -215,13 +236,12 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
LayerNorm fwd lowering rules
"""
del out_dtype, scale_dtype, scale_shapes, is_outer
del out_dtype, scale_dtype, is_outer
x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
......@@ -264,7 +284,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
......@@ -293,7 +312,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=False,
)
rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
......@@ -329,7 +347,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
):
"""
......@@ -363,7 +380,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
),
out_bdims,
)
......@@ -377,14 +393,13 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
result_infos,
):
del zero_centered_gamma, epsilon, out_dtype, result_infos
del scale_dtype, scale_shapes, is_outer
del scale_dtype, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_spec = (*x_spec[:-1], None)
......@@ -436,7 +451,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
arg_infos,
......@@ -488,7 +502,16 @@ class NormFwdPrimitive(BasePrimitive):
)
amax_sharding = NamedSharding(mesh, PartitionSpec(*amax_spec), desc="NormFwdPrimitive.amax")
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
# Enforce no sharding of hidden dim for x, gamma and beta
arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.x")
arg_shardings[2] = NamedSharding(
mesh, PartitionSpec(*g_spec[:-1], None), desc="NormFwdPrimitive.gamma"
)
arg_shardings[3] = NamedSharding(
mesh, PartitionSpec(*b_spec[:-1], None), desc="NormFwdPrimitive.beta"
)
arg_shardings = tuple(arg_shardings)
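A small illustration of the enforcement above (mesh axis names are hypothetical): whatever layout x arrives with, the hidden dimension is forced to be replicated.
from jax.sharding import PartitionSpec
x_spec = ("dp", None, "tp")  # hypothetical incoming spec for x of shape (B, S, H)
enforced = PartitionSpec(*x_spec[:-1], None)  # hidden dim replicated: P("dp", None, None)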
out_shardings = (
out_sharding,
colwise_out_sharding,
......@@ -521,7 +544,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode=scaling_mode,
is_2x=is_2x,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_outer=True,
)
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
......@@ -550,7 +572,6 @@ class NormFwdPrimitive(BasePrimitive):
scaling_mode,
is_2x,
scale_dtype,
scale_shapes,
is_outer,
mesh,
value_types,
......@@ -561,14 +582,13 @@ class NormFwdPrimitive(BasePrimitive):
epsilon,
out_dtype,
scale_dtype,
scale_shapes,
is_outer,
mesh,
result_types,
)
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var="i", flatten_axis=-1
len(value_types[0].shape), unique_var="NormFwdPrimitive_i", flatten_axis=-1
)
x_axes = scale_rules.input_spec
......@@ -908,14 +928,36 @@ def layernorm_fwd(
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((1,), (1,)),
is_outer=True,
)
return output, mu, rsigma
if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
):
out, mu, rsigma = layernorm_fwd(
x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None
)
out, _ = _quantize_dbias_impl(out, quantizer)
return out, mu, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform the norm in higher precision, then quantize after.
out, mu, rsigma = layernorm_fwd(
x=x,
gamma=gamma,
beta=beta,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
quantizer=None,
)
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
return out, mu, rsigma
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False
(
rowwise_casted_output,
......@@ -937,13 +979,12 @@ def layernorm_fwd(
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
is_outer=True,
)
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......@@ -1090,14 +1131,33 @@ def rmsnorm_fwd(
scaling_mode=ScalingMode.NO_SCALING.value,
is_2x=False,
scale_dtype=jnp.float32,
scale_shapes=((), ()),
is_outer=True,
)
return output, rsigma
if (
quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING
and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION
):
out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None)
out, _ = _quantize_dbias_impl(out, quantizer)
return out, rsigma
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Current scaling does not support fused operations. Perform the norm in higher precision, then quantize after.
out, rsigma = rmsnorm_fwd(
x=x,
gamma=gamma,
zero_centered_gamma=zero_centered_gamma,
epsilon=epsilon,
quantizer=None,
)
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype)
return out, rsigma
is_2x2x = quantizer.is_2x2x()
# TE/common normalization doesn't support 2x delayed scaling
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
is_2x2x = False
(
rowwise_casted_output,
......@@ -1119,13 +1179,12 @@ def rmsnorm_fwd(
scaling_mode=quantizer.scaling_mode.value,
is_2x=is_2x2x,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape),
is_outer=True,
)
quantizer.update(updated_amax)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling():
colwise_casted_output = jnp.transpose(
rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
)
......
......@@ -27,7 +27,13 @@ from .misc import (
)
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
from ..quantize import Quantizer, QuantizeLayout, DelayedScaleQuantizer, ScalingMode
from ..quantize import (
Quantizer,
QuantizeLayout,
DelayedScaleQuantizer,
ScalingMode,
compute_scale_from_amax,
)
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
......@@ -53,8 +59,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
6,
7,
8,
9,
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer
inner_primitive = None
outer_primitive = None
......@@ -68,14 +73,12 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
"""
te_dbias_quantize_p abstract
"""
del scale_shapes
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_shape = x_aval.shape
......@@ -94,7 +97,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis)
else:
colwise_out_shape = out_shape
......@@ -166,14 +169,13 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
"""
te_dbias_quantize_p lowering rules
"""
del out_dtype, scale_dtype, scale_shapes, is_outer
del out_dtype, scale_dtype, is_outer
x_aval, scale_aval = ctx.avals_in
assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
assert scale_aval.dtype == jnp.float32
......@@ -196,7 +198,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
......@@ -221,7 +222,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
is_outer=False,
)
......@@ -254,7 +254,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
):
......@@ -278,7 +277,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
),
out_bdims,
......@@ -291,14 +289,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
mesh,
arg_infos,
result_infos,
):
del (out_dtype, result_infos, scale_dtype, scale_shapes, is_outer) # Unused.
del (out_dtype, result_infos, scale_dtype, is_outer) # Unused.
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(
......@@ -307,7 +305,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
......@@ -363,7 +361,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
mesh,
......@@ -371,6 +368,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
result_infos,
):
del result_infos, is_outer
x_spec = get_padded_spec(arg_infos[0])
scale_spec = get_padded_spec(arg_infos[1])
out_sharding = NamedSharding(
......@@ -379,7 +377,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
desc="DBiasQuantizePrimitive.out_sharding",
)
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis)
else:
colwise_out_spec = x_spec
......@@ -445,7 +443,6 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout=q_layout,
flatten_axis=flatten_axis,
scale_dtype=scale_dtype,
scale_shapes=scale_shapes,
is_dbias=is_dbias,
is_outer=True,
)
......@@ -478,17 +475,18 @@ class DBiasQuantizePrimitive(BasePrimitive):
q_layout,
flatten_axis,
scale_dtype,
scale_shapes,
is_dbias,
is_outer,
mesh,
value_types,
result_types,
):
del out_dtype, scale_dtype, scale_shapes, is_outer, mesh, result_types
del out_dtype, scale_dtype, is_outer, mesh, result_types
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var="i", flatten_axis=flatten_axis
len(value_types[0].shape),
unique_var="DBiasQuantizePrimitive_i",
flatten_axis=flatten_axis,
)
x_axes = scale_rules.input_spec
......@@ -496,7 +494,7 @@ class DBiasQuantizePrimitive(BasePrimitive):
out = x_axes
if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value):
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
if ScalingMode(scaling_mode).is_tensor_scaling():
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis))
else:
colwise_out = x_axes
......@@ -612,6 +610,13 @@ def _quantize_dbias_impl(
return x, _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis)
return x, None
if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
# Globally reduce amax across all devices for current scaling so we have a single global scale.
# This differs from the PyTorch implementation, which uses a per-device local amax and scale and persists
# them until the tensor is dequantized (e.g. in the GEMM).
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
scale = compute_scale_from_amax(amax, quantizer.q_dtype)
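A hedged sketch of the resulting quantize step; the clamp and exact rounding are illustrative, and the real path continues through DBiasQuantizePrimitive below.
import jax.numpy as jnp
def quantize_with_current_scale(x, q_dtype=jnp.float8_e4m3fn):
    # One scale for the whole tensor, derived from the current amax.
    amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32)
    fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, dtype=jnp.float32)
    scale = fp8_max / jnp.maximum(amax, 1e-12)  # clamp is illustrative
    q = (x.astype(jnp.float32) * scale).astype(q_dtype)
    return q, 1.0 / scale  # scale_inv kept for later dequantize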
if isinstance(quantizer, DelayedScaleQuantizer):
scale = quantizer.scale
......@@ -630,12 +635,11 @@ def _quantize_dbias_impl(
q_layout=quantizer.q_layout.value,
flatten_axis=flatten_axis,
scale_dtype=quantizer.get_scale_dtype(),
scale_shapes=quantizer.get_scale_shapes(x.shape, flatten_axis=flatten_axis),
is_dbias=is_dbias,
is_outer=True,
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x():
colwise_scale_inv = rowwise_scale_inv
quantizer.update(updated_amax)
......
......@@ -28,8 +28,8 @@
#include "common/util/logging.h"
#include "extensions/ffi.h"
#include "extensions/misc.h"
#include "extensions/utils.h"
#include "transformer_engine/activation.h"
#include "utils.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
......
......@@ -44,11 +44,16 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
"Current tensor scaling does not support fused operations yet. Please call this primitive "
"in higher-precision then quantize with current scaling.");
if (is_fp8_dtype(out_dtype)) {
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
......@@ -152,6 +157,11 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid
auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
auto dbias_shape = std::vector<size_t>{hidden_size};
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
"Current tensor scaling does not support fused operations yet. Please call this primitive "
"in higher-precision then quantize with current scaling.");
// Evil hack to specify TE impl
// Note: nvte_quantize_dbias_dgelu chooses its internal impl based
// on what pointers are allocated, e.g. whether to output with
......@@ -219,6 +229,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis
NVTE_CHECK(
scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
"Current tensor scaling does not support fused operations yet. Please call this primitive "
"in higher-precision then quantize with current scaling.");
auto *output = output_buf->untyped_data();
auto *colwise_output = colwise_output_buf->untyped_data();
auto *dbias = dbias_buf->untyped_data();
......@@ -230,8 +245,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
// m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
// n = ir_dz_shape[-1] * act_len, ir_dz_shape == input_dims
auto act_len = act_input_dims[act_input_dims.size() - 2];
NVTE_CHECK(act_input_dims.back() == input_dims.back(),
"Shape mismatch between activation input and gradient input");
NVTE_CHECK(act_len == 1 || act_len == 2,
"The value of the activation dimension (axis=-2) must be 1 for non-gated or 2 for "
"gated activation, got ",
act_len);
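For readers coming from the JAX side, a hedged illustration of the act_len convention this check enforces (shapes only; which half acts as the gate is an assumption of the sketch, not the TE layout):
import jax
import jax.numpy as jnp
# Non-gated activation (act_len == 1): input carries a unit axis at -2.
x_relu = jnp.ones((4, 128, 1, 1024))
# Gated activation (act_len == 2): two halves stacked on axis -2.
x_gated = jnp.ones((4, 128, 2, 1024))
gate, value = x_gated[..., 0, :], x_gated[..., 1, :]
out = jax.nn.silu(gate) * value  # SwiGLU-style combine of the two halves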
auto m = product(act_input_dims, 0, act_input_dims.size() - 2);
auto n = input_dims.back();
......@@ -242,8 +260,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
auto dbias_shape = std::vector<size_t>{n * act_len};
std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());
auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
auto input_tensor =
TensorWrapper(input, input_shape, convert_ffi_datatype_to_te_dtype(input_buf.element_type()));
auto act_input_tensor = TensorWrapper(
act_input, act_input_shape, convert_ffi_datatype_to_te_dtype(act_input_buf.element_type()));
auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
output_tensor.set_rowwise_data(output, out_dtype, output_shape);
......@@ -251,7 +271,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
cudaMemsetAsync(amax, 0, sizeof(float), stream);
nvte_memset(amax, 0, sizeof(float), stream);
output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
output_tensor.set_rowwise_scale_inv(
......