Unverified Commit be055eb0 authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Support default process group with FP8 current scaling (#1621)



* Handle case where FP8 current scaling quantizer gets default process group
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid canonicalizing TP group since it may not be initialized
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3b1f5a11
......@@ -317,7 +317,6 @@ def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
device=device,
with_amax_reduction=True,
amax_reduction_group=tp_group,
amax_reduction_size=tp_size,
)
quantizer = quantizer_class(
fp8_dtype=fp8_dtype,
......
......@@ -149,7 +149,6 @@ class Float8CurrentScalingQuantizer : public Quantizer {
DType dtype;
bool with_amax_reduction;
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
int amax_reduction_size;
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
......
......@@ -145,24 +145,21 @@ Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& q
const at::Tensor& scale = quantizer.attr("scale").cast<at::Tensor>();
const at::Tensor& amax = quantizer.attr("amax").cast<at::Tensor>();
const DType type = quantizer.attr("dtype").cast<DType>();
// For current scaling, need several other components:
// 1. with_amax_reduction: bool
// 2. amax_reduction_group: torch.distributed.ProcessGroup or None
// 3. amax_reduction_size: int
const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
const py::object amax_reduction_group_obj = quantizer.attr("amax_reduction_group");
const c10::intrusive_ptr<dist_group_type> amax_reduction_group =
amax_reduction_group_obj.is_none()
? nullptr
: amax_reduction_group_obj.cast<c10::intrusive_ptr<dist_group_type>>();
const int amax_reduction_size = quantizer.attr("amax_reduction_size").cast<int>();
this->amax = amax;
this->scale = scale;
this->dtype = type;
// Get amax reduction group if needed
const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
c10::intrusive_ptr<dist_group_type> amax_reduction_group;
if (with_amax_reduction) {
auto group = quantizer.attr("_canonicalized_amax_reduction_group")();
NVTE_CHECK(!group.is_none(),
"Float8CurrentScalingQuantizer could not canonicalize amax reduction group");
amax_reduction_group = group.cast<c10::intrusive_ptr<dist_group_type>>();
}
this->with_amax_reduction = with_amax_reduction;
this->amax_reduction_group = amax_reduction_group;
this->amax_reduction_size = amax_reduction_size;
// fp8 current scaling specific quantization params
this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast<bool>();
......
......@@ -1416,9 +1416,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_size = self.tp_size
else:
# set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here)
self.quantizers["scaling_bwd"][
......
......@@ -1576,9 +1576,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_size = self.tp_size
else:
# grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer
self.quantizers["scaling_bwd"][
......@@ -1602,6 +1599,3 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_size = self.tp_size
......@@ -1221,9 +1221,6 @@ class Linear(TransformerEngineBaseModule):
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_group = self.tp_group
self.quantizers["scaling_fwd"][
tex.FP8FwdTensors.GEMM1_INPUT
].amax_reduction_size = self.tp_size
else:
# set grad_output_quantizer with amax epsilon and power_2_scale
self.quantizers["scaling_bwd"][
......@@ -1241,6 +1238,3 @@ class Linear(TransformerEngineBaseModule):
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_group = self.tp_group
self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1
].amax_reduction_size = self.tp_size
......@@ -11,7 +11,7 @@ import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..utils import devices_match, non_tn_fp8_gemm_supported
from ..utils import canonicalize_process_group, devices_match, non_tn_fp8_gemm_supported
from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..constants import dist_group_type
......@@ -194,7 +194,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
"""amax reduction options"""
with_amax_reduction: bool
amax_reduction_group: Optional[dist_group_type]
amax_reduction_size: Optional[int]
"""Options about how to quantize the tensor"""
force_pow_2_scales: bool
amax_epsilon: float
......@@ -208,7 +207,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
columnwise: bool = True,
with_amax_reduction: bool = False,
amax_reduction_group: Optional[dist_group_type] = None,
amax_reduction_size: Optional[int] = 1,
force_pow_2_scales: bool = False,
amax_epsilon: float = 0.0,
) -> None:
......@@ -218,7 +216,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
self.dtype = fp8_dtype
self.with_amax_reduction = with_amax_reduction
self.amax_reduction_group = amax_reduction_group
self.amax_reduction_size = amax_reduction_size
self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon
......@@ -327,6 +324,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
quantizer=self,
)
def _canonicalized_amax_reduction_group(self) -> dist_group_type:
    """Process group to use for the amax reduction.

    Delegates to ``canonicalize_process_group``, so a ``None``
    ``amax_reduction_group`` resolves to the default process group.
    """
    group = self.amax_reduction_group
    return canonicalize_process_group(group)
class Float8Tensor(Float8TensorBase, QuantizedTensor):
"""Experimental tensor class with FP8 data
......
......@@ -386,3 +386,16 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:
# Pop NVTX range
torch.cuda.nvtx.range_pop()
def canonicalize_process_group(
    group: Optional[torch.distributed.ProcessGroup],
) -> torch.distributed.ProcessGroup:
    """Convert to a concrete PyTorch process group.

    A non-``None`` *group* is returned unchanged; ``None`` resolves to
    the default (world) process group.
    """
    if group is not None:
        return group
    # NOTE(review): relies on the private torch.distributed API
    # ``_get_default_group`` — confirm there is no public equivalent
    # for the targeted torch versions.
    return torch.distributed.distributed_c10d._get_default_group()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment