Unverified Commit e3e0375d authored by Kirthi Shankar Sivamani, committed by GitHub

Fix fp8_buf for Linear and LayerNormLinear (#1633)



* Fix fp8_buf for Linear and LayerNormLinear
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 31f5c2d8
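
In short: each module now checks whether the userbuffer used for the overlapped reduce-scatter is an FP8 buffer, and if so forces FP8 output (fprop) and/or FP8 gradient (dgrad) so the GEMM result can be written straight into that buffer. Below is a minimal, self-contained sketch of that pattern; the class and function names are hypothetical stand-ins, not the Transformer Engine API.

class FakeUserBuffer:
    # Hypothetical stand-in for a communication userbuffer; is_fp8_ubuf()
    # mirrors the real check used in the diff below.
    def __init__(self, fp8):
        self._fp8 = fp8

    def is_fp8_ubuf(self):
        return self._fp8


def resolve_fp8_flags(ub_fprop=None, ub_dgrad=None, fp8_output=False, fp8_grad=False):
    # Force FP8 output/grad whenever the corresponding overlap buffer is FP8,
    # regardless of what the caller requested.
    if ub_fprop is not None and ub_fprop.is_fp8_ubuf():
        fp8_output = True
    if ub_dgrad is not None and ub_dgrad.is_fp8_ubuf():
        fp8_grad = True
    return fp8_output, fp8_grad


# Example: FP8 fprop buffer, non-FP8 dgrad buffer -> (True, False)
print(resolve_fp8_flags(FakeUserBuffer(True), FakeUserBuffer(False)))
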
@@ -1262,6 +1262,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         inp: torch.Tensor,
         is_first_microbatch: Optional[bool] = None,
         fp8_output: Optional[bool] = False,
+        fp8_grad: Optional[bool] = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply layer normalization to the input followed by a linear transformation.
@@ -1292,6 +1293,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False

+        if self.ub_overlap_rs_fprop:
+            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+                fp8_output = True
+        if self.ub_overlap_rs_dgrad:
+            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+                fp8_grad = True
+
         with self.prepare_forward(
             inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
         ) as inp:
@@ -1319,7 +1327,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 output_quantizer,
                 grad_output_quantizer,
                 grad_input_quantizer,
-            ) = self._get_quantizers(fp8_output)
+            ) = self._get_quantizers(fp8_output, fp8_grad)

             if torch.is_grad_enabled():
                 fwd_fn = _LayerNormLinear.apply
@@ -1384,7 +1392,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             return out, ln_out
         return out

-    def _get_quantizers(self, fp8_output):
+    def _get_quantizers(self, fp8_output, fp8_grad):
         if not self.fp8:
             return [None] * 5
         grad_input_quantizer = None
@@ -1399,6 +1407,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
         if torch.is_grad_enabled():
             grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
             grad_output_quantizer.internal = True
+            if fp8_grad:
+                grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]

         return (
             input_quantizer,
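
A hedged sketch of the backward-side selection above, using plain dictionaries as hypothetical stand-ins for the module's quantizer tables: the grad-input quantizer is looked up only when fp8_grad is set (for example because the dgrad userbuffer is FP8), and stays None otherwise.

def pick_backward_quantizers(quantizers, fp8_grad):
    # Mirrors the conditional lookup in _get_quantizers, with hypothetical string keys.
    grad_output_q = quantizers["scaling_bwd"]["GRAD_OUTPUT1"]
    grad_input_q = quantizers["scaling_bwd"]["GRAD_INPUT1"] if fp8_grad else None
    return grad_output_q, grad_input_q


table = {"scaling_bwd": {"GRAD_OUTPUT1": "q_dgrad_out", "GRAD_INPUT1": "q_dgrad_in"}}
print(pick_backward_quantizers(table, fp8_grad=True))   # ('q_dgrad_out', 'q_dgrad_in')
print(pick_backward_quantizers(table, fp8_grad=False))  # ('q_dgrad_out', None)
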
@@ -1436,6 +1436,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False

+        fp8_output = False
+        if self.ub_overlap_rs:
+            if get_ub("fc2_fprop").is_fp8_ubuf():
+                fp8_output = True
+
         with self.prepare_forward(inp, num_gemms=2) as inp:
             # Get quantizers
             (
@@ -1447,7 +1452,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 grad_fc1_output_quantizer,
                 grad_fc2_output_quantizer,
                 grad_input_quantizer,
-            ) = self._get_quantizers()
+            ) = self._get_quantizers(fp8_output)

             # Get weight tensors
             fc1_weight = self.fc1_weight
@@ -1533,7 +1538,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             return out, ln_out
         return out

-    def _get_quantizers(self):
+    def _get_quantizers(self, fp8_output):
         (
             fc1_input_quantizer,
             fc1_weight_quantizer,
@@ -1555,6 +1560,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             )
             fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
             fc2_weight_quantizer.internal = True
+            if fp8_output:
+                output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT]
             if torch.is_grad_enabled():
                 grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][
                     tex.FP8BwdTensors.GRAD_OUTPUT1
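
The forward side in LayerNormMLP has the same shape: the FC2 output quantizer (GEMM2_OUTPUT) is selected only when fp8_output was forced on by an FP8 fc2_fprop userbuffer. A tiny sketch with hypothetical keys:

def pick_fc2_output_quantizer(quantizers, fp8_output):
    # Return an FP8 output quantizer only when FP8 output was requested;
    # otherwise the GEMM produces a higher-precision output (quantizer None).
    return quantizers["scaling_fwd"]["GEMM2_OUTPUT"] if fp8_output else None


table = {"scaling_fwd": {"GEMM2_OUTPUT": "q_fc2_out"}}
print(pick_fc2_output_quantizer(table, fp8_output=True))   # 'q_fc2_out'
print(pick_fc2_output_quantizer(table, fp8_output=False))  # None
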
@@ -1104,6 +1104,13 @@ class Linear(TransformerEngineBaseModule):
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False

+        if self.ub_overlap_rs_fprop:
+            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+                fp8_output = True
+        if self.ub_overlap_rs_dgrad:
+            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+                fp8_grad = True
+
         with self.prepare_forward(
             inp,
             allow_non_contiguous=isinstance(inp, QuantizedTensor),
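
For context, a hedged usage sketch of driving these modules under FP8. It omits the tensor-parallel userbuffer overlap setup that actually reaches the code path above (that requires multi-GPU initialization and the ub_overlap_* options, which are configuration dependent), assumes an FP8-capable GPU, and only shows the basic FP8 forward/backward.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Basic FP8 forward/backward; the userbuffer overlap that drives
# fp8_output/fp8_grad above needs additional multi-GPU setup and is omitted.
layer = te.LayerNormLinear(1024, 1024, bias=True).cuda()
inp = torch.randn(16, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    out = layer(inp)

out.sum().backward()
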