"...git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "9090bf02e74334a8020b454814e0d00fa780fd79"
Unverified commit be055eb0 authored by Tim Moon, committed by GitHub

[PyTorch] Support default process group with FP8 current scaling (#1621)

* Handle case where FP8 current scaling quantizer gets default process group

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warning

Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid canonicalizing TP group since it may not be initialized

Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3b1f5a11
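
In effect, an FP8 current-scaling quantizer can now be created with amax reduction enabled but no explicit process group, and the group is resolved to PyTorch's default (world) group when it is needed. A minimal sketch of that behavior, assuming torch.distributed is initialized, a CUDA device is available, and the import path below is correct (the init_process_group parameters are purely illustrative, not part of this change):

import torch
import torch.distributed as dist
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

# Illustrative single-process setup; in practice the launcher (e.g. torchrun) does this.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

# No explicit amax_reduction_group: with this change the default process group is used.
quantizer = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    device=torch.device("cuda"),
    with_amax_reduction=True,
    amax_reduction_group=None,
)

# The group is resolved lazily via the new helper on the Python side.
group = quantizer._canonicalized_amax_reduction_group()
assert group is dist.distributed_c10d._get_default_group()
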
@@ -317,7 +317,6 @@ def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size):
             device=device,
             with_amax_reduction=True,
             amax_reduction_group=tp_group,
-            amax_reduction_size=tp_size,
         )
     quantizer = quantizer_class(
         fp8_dtype=fp8_dtype,
@@ -149,7 +149,6 @@ class Float8CurrentScalingQuantizer : public Quantizer {
   DType dtype;
   bool with_amax_reduction;
   c10::intrusive_ptr<dist_group_type> amax_reduction_group;
-  int amax_reduction_size;
   bool force_pow_2_scales = false;
   float amax_epsilon = 0.0;

@@ -145,24 +145,21 @@ Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& q
   const at::Tensor& scale = quantizer.attr("scale").cast<at::Tensor>();
   const at::Tensor& amax = quantizer.attr("amax").cast<at::Tensor>();
   const DType type = quantizer.attr("dtype").cast<DType>();

-  // For current scaling, need several other components:
-  // 1. with_amax_reduction: bool
-  // 2. amax_reduction_group: torch.distributed.ProcessGroup or None
-  // 3. amax_reduction_size: int
-  const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
-  const py::object amax_reduction_group_obj = quantizer.attr("amax_reduction_group");
-  const c10::intrusive_ptr<dist_group_type> amax_reduction_group =
-      amax_reduction_group_obj.is_none()
-          ? nullptr
-          : amax_reduction_group_obj.cast<c10::intrusive_ptr<dist_group_type>>();
-  const int amax_reduction_size = quantizer.attr("amax_reduction_size").cast<int>();
   this->amax = amax;
   this->scale = scale;
   this->dtype = type;
+  // Get amax reduction group if needed
+  const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast<bool>();
+  c10::intrusive_ptr<dist_group_type> amax_reduction_group;
+  if (with_amax_reduction) {
+    auto group = quantizer.attr("_canonicalized_amax_reduction_group")();
+    NVTE_CHECK(!group.is_none(),
+               "Float8CurrentScalingQuantizer could not canonicalize amax reduction group");
+    amax_reduction_group = group.cast<c10::intrusive_ptr<dist_group_type>>();
+  }
   this->with_amax_reduction = with_amax_reduction;
   this->amax_reduction_group = amax_reduction_group;
-  this->amax_reduction_size = amax_reduction_size;

   // fp8 current scaling specific quantization params
   this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast<bool>();
@@ -1416,9 +1416,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 self.quantizers["scaling_fwd"][
                     tex.FP8FwdTensors.GEMM1_INPUT
                 ].amax_reduction_group = self.tp_group
-                self.quantizers["scaling_fwd"][
-                    tex.FP8FwdTensors.GEMM1_INPUT
-                ].amax_reduction_size = self.tp_size
         else:
             # set grad_output_quantizer with amax epsilon and power_2_scale (no amax reduction here)
             self.quantizers["scaling_bwd"][
@@ -1576,9 +1576,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 self.quantizers["scaling_fwd"][
                     tex.FP8FwdTensors.GEMM1_INPUT
                 ].amax_reduction_group = self.tp_group
-                self.quantizers["scaling_fwd"][
-                    tex.FP8FwdTensors.GEMM1_INPUT
-                ].amax_reduction_size = self.tp_size
         else:
             # grad_fc2_output_quantizer: set configs about amax epsilon and power_2_scale for grad_fc2_output_quantizer
             self.quantizers["scaling_bwd"][
@@ -1602,6 +1599,3 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 self.quantizers["scaling_bwd"][
                     tex.FP8BwdTensors.GRAD_OUTPUT1
                 ].amax_reduction_group = self.tp_group
-                self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
-                ].amax_reduction_size = self.tp_size
@@ -1221,9 +1221,6 @@ class Linear(TransformerEngineBaseModule):
                 self.quantizers["scaling_fwd"][
                     tex.FP8FwdTensors.GEMM1_INPUT
                 ].amax_reduction_group = self.tp_group
-                self.quantizers["scaling_fwd"][
-                    tex.FP8FwdTensors.GEMM1_INPUT
-                ].amax_reduction_size = self.tp_size
         else:
             # set grad_output_quantizer with amax epsilon and power_2_scale
             self.quantizers["scaling_bwd"][
@@ -1241,6 +1238,3 @@ class Linear(TransformerEngineBaseModule):
                 self.quantizers["scaling_bwd"][
                     tex.FP8BwdTensors.GRAD_OUTPUT1
                 ].amax_reduction_group = self.tp_group
-                self.quantizers["scaling_bwd"][
-                    tex.FP8BwdTensors.GRAD_OUTPUT1
-                ].amax_reduction_size = self.tp_size
@@ -11,7 +11,7 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

-from ..utils import devices_match, non_tn_fp8_gemm_supported
+from ..utils import canonicalize_process_group, devices_match, non_tn_fp8_gemm_supported
 from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
 from ..constants import dist_group_type
@@ -194,7 +194,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
     """amax reduction options"""
     with_amax_reduction: bool
     amax_reduction_group: Optional[dist_group_type]
-    amax_reduction_size: Optional[int]
     """Options about how to quantize the tensor"""
     force_pow_2_scales: bool
     amax_epsilon: float
@@ -208,7 +207,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
         columnwise: bool = True,
         with_amax_reduction: bool = False,
         amax_reduction_group: Optional[dist_group_type] = None,
-        amax_reduction_size: Optional[int] = 1,
         force_pow_2_scales: bool = False,
         amax_epsilon: float = 0.0,
     ) -> None:
@@ -218,7 +216,6 @@ class Float8CurrentScalingQuantizer(Quantizer):
         self.dtype = fp8_dtype
         self.with_amax_reduction = with_amax_reduction
         self.amax_reduction_group = amax_reduction_group
-        self.amax_reduction_size = amax_reduction_size
         self.force_pow_2_scales = force_pow_2_scales
         self.amax_epsilon = amax_epsilon

@@ -327,6 +324,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
             quantizer=self,
         )

+    def _canonicalized_amax_reduction_group(self) -> dist_group_type:
+        """Get process group for amax reduction"""
+        return canonicalize_process_group(self.amax_reduction_group)
+

 class Float8Tensor(Float8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data
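
Since the C++ constructor earlier in this diff only calls _canonicalized_amax_reduction_group when with_amax_reduction is set, a quantizer built without amax reduction never touches torch.distributed at construction time. A hedged sketch, reusing the illustrative import path from the example above (a CUDA device is still assumed for the scale/amax buffers):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

# Without amax reduction, the group is never canonicalized, so torch.distributed
# does not need to be initialized when the quantizer is constructed.
q = Float8CurrentScalingQuantizer(
    fp8_dtype=tex.DType.kFloat8E4M3,
    device=torch.device("cuda"),
    with_amax_reduction=False,
)
assert q.amax_reduction_group is None  # stored as given; only resolved if reduction is enabled
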
@@ -386,3 +386,16 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:

     # Pop NVTX range
     torch.cuda.nvtx.range_pop()
+
+
+def canonicalize_process_group(
+    group: Optional[torch.distributed.ProcessGroup],
+) -> torch.distributed.ProcessGroup:
+    """Convert to PyTorch process group
+
+    If `None`, returns default process group.
+
+    """
+    if group is None:
+        return torch.distributed.distributed_c10d._get_default_group()
+    return group
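
A short illustration of the helper's contract, assuming torch.distributed has already been initialized (for example by torchrun) and that the module is importable as transformer_engine.pytorch.utils:

import torch.distributed as dist

from transformer_engine.pytorch.utils import canonicalize_process_group

# None falls back to the default (world) process group.
world_group = canonicalize_process_group(None)
assert world_group is dist.distributed_c10d._get_default_group()

# An explicit group is returned unchanged.
assert canonicalize_process_group(world_group) is world_group
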