If the tensor is not processed using *modify_tensor* and the FP8 recipe is enabled,
then the decision whether to cast it to FP8 is based on the value returned by
*fp8_gemm_enabled*. If the tensor is processed using *modify_tensor*, or FP8
autocast is not enabled, the result of this call does not matter.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is True
"""
return True  # If False is returned, the FP8 GEMM is turned off; otherwise nothing changes.
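For illustration, a minimal sketch of how a custom feature might override this hook to keep a single GEMM type in high precision; the standalone class is hypothetical and registration with debug_api is omitted::

    class DisableWgradFp8Gemm:
        def fp8_gemm_enabled(self, config, layer_name, gemm, iteration):
            # Returning False turns the FP8 GEMM off; returning True
            # leaves the recipe's decision unchanged.
            return gemm != "wgrad"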
def modify_tensor_enabled(
    self,
    config: Dict,
    layer_name: str,
    gemm: str,
    tensor_name: str,
    iteration: int,
) -> bool:
"""
Determines whether *modify_tensor* will be run for a given GEMM and tensor name.
It has **higher priority** than *fp8_gemm_enabled*: if *modify_tensor_enabled*
returns `True`, then the *modify_tensor* call is invoked for the respective
tensor regardless of the FP8 recipe.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is False
"""
return False
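A feature will typically route this call on its config; a minimal sketch, assuming a hypothetical `tensors` key in the feature's config.yaml section::

    class ModifyActivationsOnly:
        def modify_tensor_enabled(self, config, layer_name, gemm,
                                  tensor_name, iteration):
            # Run modify_tensor only for the tensors the user listed
            # under the (hypothetical) `tensors` config key.
            return tensor_name in config.get("tensors", ["activation"])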
def modify_tensor(
    self,
    config: Dict,
    layer_name: str,
    gemm: str,
    tensor_name: str,
    tensor: torch.Tensor,
    default_quantizer: Quantizer,
    iteration: int,
    out: Union[torch.Tensor, QuantizedTensor],
) -> Union[torch.Tensor, QuantizedTensor, None]:
"""
Allows tensor modification. For example, the `FakeQuant` feature uses it to
emulate casting to FP8. It is invoked at most once for each tensor within a
given GEMM operation. This call is made if `modify_tensor_enabled` returns
`True` and the feature is enabled for the given *tensor_name* and *gemm*;
it then runs **instead of** the default quantization.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
gemm: str
    one of [`fprop`, `dgrad`, `wgrad`],
tensor_name: str
    one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
tensor: torch.Tensor
    tensor in high precision,
default_quantizer : Quantizer
quantizer which is used to cast the tensor to lower precision
if *modify_tensor* is not invoked. For example,
feature per tensor scale uses it to obtain FP8 dtype of the tensor.
If the recipe indicates that the tensor is not cast - for example,
if running without FP8 autocast, then `default_quantizer=None`,
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
out: Union[torch.Tensor, QuantizedTensor]
    output tensor, used in the weight caching mechanism.
    It can be a `torch.Tensor` or one of Transformer Engine's `QuantizedTensor` types;
    the rule is that both tensors returned for a given GEMM must have the same type.
    If both are `Float8Tensor`, the GEMM is run in FP8.
    If both are `torch.Tensor`, the GEMM is run in high precision.
    Take this into account especially if only one tensor of the GEMM
    is processed by `modify_tensor()`. For example, `FakeQuant`
    disables the FP8 GEMM to ensure that the second tensor is also in high precision.
    If the tensor is not an input to any GEMM - namely `output`,
    `wgrad`, and `dgrad` - the return type should match the input type.
    Should return `None` if `out` is not `None`.
"""
raise NotImplementedError(
    "modify_tensor_enabled() returned True, modify_tensor() was invoked, but it is not"
    " handled by any API."
)
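A fake-quantization-style implementation could look roughly like the sketch below. It assumes the quantizer can be called on a tensor and that the result supports `.dequantize()`; it is not the actual `FakeQuant` code::

    def modify_tensor(self, config, layer_name, gemm, tensor_name,
                      tensor, default_quantizer, iteration, out=None):
        # Emulate a low-precision cast: quantize, then immediately
        # dequantize back to a high-precision torch.Tensor.
        if default_quantizer is not None:
            modified = default_quantizer(tensor).dequantize()
        else:
            modified = tensor  # no FP8 recipe: leave the tensor unchanged
        if out is not None:
            out.copy_(modified)  # weight caching path writes into `out`
            return None          # must return None when `out` is provided
        return modified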
def inspect_tensor(
    self,
    config: Dict,
    layer_name: str,
    tensor_name: str,
    tensor: torch.Tensor,
    iteration: int,
    tp_group: torch.distributed.ProcessGroup,
) -> None:
"""
This call is invoked if *inspect_tensor_enabled* returns `True`. It can be used
to obtain information about the high precision tensor; for example, it is run
by the `LogTensorStats` feature.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
tensor: torch.Tensor
tensor in high precision,
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
tp_group: torch.distributed.ProcessGroup
    process group for the tensor parallel group, used for weight statistics
    reduction. This is not the reduction group from debug_api.
Returns
-------
Should return nothing.
"""
def inspect_tensor_postquantize(
    self,
    config: Dict,
    layer_name: str,
    tensor_name: str,
    gemm: str,
    tensor: torch.Tensor,
    iteration: int,
    tp_group: torch.distributed.ProcessGroup,
) -> None:
"""
Similar to *inspect_tensor*, but run after the FP8 cast or the *modify_tensor*
call, if either is performed. If neither the FP8 cast nor *modify_tensor* is
invoked, then *inspect_tensor_postquantize* is not invoked either. The
`LogFp8Stats` feature uses this call to collect FP8 statistics after quantization.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
tensor: torch.Tensor
    tensor in FP8, or the processed tensor after the *modify_tensor* call,
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
tp_group: torch.distributed.ProcessGroup
    process group for the tensor parallel group, used for weight statistics
    reduction. This is not the reduction group from debug_api.
Returns
-------
Should return nothing.
"""
def inspect_tensor_enabled(
    self,
    config: Dict,
    layer_name: str,
    tensor_name: str,
    iteration: int,
) -> bool:
"""
A routing call, run at the initialization of the layer. If it returns `True`,
then *inspect_tensor* will be invoked for the given tensor.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`].
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is False
"""
return False
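Since this routing call runs once at layer initialization, it usually just inspects the config; a sketch, with `tensors` being a hypothetical config key::

    def inspect_tensor_enabled(self, config, layer_name, tensor_name, iteration):
        # Inspect only the tensors the user listed in config.yaml.
        return tensor_name in config.get("tensors", [])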
def inspect_tensor_postquantize_enabled(
    self,
    config: Dict,
    layer_name: str,
    gemm: str,
    tensor_name: str,
    iteration: int,
) -> bool:
"""
A routing call, run at the initialization of the layer. If it returns `True`,
then *inspect_tensor_postquantize* for a given GEMM and tensor will be invoked.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is False
"""
return False


FakeQuant
---------

Disables FP8 GEMM. Fake quantizes chosen tensors to FP8 - using a per-tensor
scaling factor, not delayed scaling - and runs the GEMM in high precision.
.. figure:: ./img/fake_quant.svg
    :align: center

    Fig 1: Comparison of an FP8 FPROP GEMM with the same GEMM in BF16 with fake
    quantization of the activation tensor. Green tensors have the same values,
    but different dtypes.
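Put together, this behavior maps onto the API calls from the previous section roughly as follows. This is an illustrative sketch, not the actual `FakeQuant` source; the standalone class and the `tensors` config key are assumptions::

    class FakeQuantSketch:
        def fp8_gemm_enabled(self, config, layer_name, gemm, iteration):
            # Disable the FP8 GEMM so both operands stay in high precision.
            return False

        def modify_tensor_enabled(self, config, layer_name, gemm,
                                  tensor_name, iteration):
            return tensor_name in config.get("tensors", [])

        def modify_tensor(self, config, layer_name, gemm, tensor_name,
                          tensor, default_quantizer, iteration, out=None):
            # Quantize-dequantize so the returned tensor is a plain
            # torch.Tensor and the GEMM runs in high precision.
            fake = (default_quantizer(tensor).dequantize()
                    if default_quantizer is not None else tensor)
            if out is not None:
                out.copy_(fake)
                return None
            return fake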
LogTensorStats
--------------

This feature handles the logging of basic tensor statistics.
For a distributed setting, the auxiliary stats are computed for each node and gathered after the `debug_api.step()` call. Do not forget to invoke `debug_api.step()` at every step to log stats!
`LogTensorStats` supports micro-batching. If multiple forward/backward passes are invoked per `debug_api.step()`, then stats for all tensors except weights will be accumulated.
`LogTensorStats` can induce significant overhead. To mitigate this issue, logging stats with `freq > 1` is recommended. If `LogTensorStats` is not used in a given step, the overhead is smaller. Moreover, if no other feature is used for the layer, the TE layer will run as fast as it would without `debug_api` initialized.
Parameters
----------
stats: List[str]
list of statistics to log
- min
- max
- mean
- std
- l1_norm
- l2_norm
- cur_amax – maximal absolute value of a tensor,
- dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)`
tensors/tensors_struct: List[str]
list of tensors to log
- activation
- gradient
- weight
- output
- wgrad
- dgrad
freq: Optional[int], default = 1
frequency of logging stats, stats will be logged every `freq` steps
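A config.yaml entry enabling this feature might look like the sketch below; the section name and the `layers` selector follow the general structure of debug configs, but the exact keys are assumptions here::

    log_stats:                            # user-chosen section name
      enabled: True
      layers:
        layer_name_regex_pattern: .*      # apply to every layer
      transformer_engine:
        LogTensorStats:
          enabled: True
          stats: [min, max, mean, std, l2_norm]
          tensors: [activation, gradient, weight]
          freq: 2                         # log every second debug_api.step()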