Unverified Commit e3e0375d authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix fp8_buf for Linear and LayerNormLinear (#1633)



* Fix fp8_buf for Linear and LayerNormLinear
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 31f5c2d8
...@@ -1262,6 +1262,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1262,6 +1262,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
inp: torch.Tensor, inp: torch.Tensor,
is_first_microbatch: Optional[bool] = None, is_first_microbatch: Optional[bool] = None,
fp8_output: Optional[bool] = False, fp8_output: Optional[bool] = False,
fp8_grad: Optional[bool] = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
""" """
Apply layer normalization to the input followed by a linear transformation. Apply layer normalization to the input followed by a linear transformation.
...@@ -1292,6 +1293,13 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1292,6 +1293,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
if skip_fp8_weight_update is not None: if skip_fp8_weight_update is not None:
is_first_microbatch = False is_first_microbatch = False
if self.ub_overlap_rs_fprop:
if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
fp8_grad = True
with self.prepare_forward( with self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp: ) as inp:
...@@ -1319,7 +1327,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1319,7 +1327,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
output_quantizer, output_quantizer,
grad_output_quantizer, grad_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
) = self._get_quantizers(fp8_output) ) = self._get_quantizers(fp8_output, fp8_grad)
if torch.is_grad_enabled(): if torch.is_grad_enabled():
fwd_fn = _LayerNormLinear.apply fwd_fn = _LayerNormLinear.apply
...@@ -1384,7 +1392,7 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1384,7 +1392,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
return out, ln_out return out, ln_out
return out return out
def _get_quantizers(self, fp8_output): def _get_quantizers(self, fp8_output, fp8_grad):
if not self.fp8: if not self.fp8:
return [None] * 5 return [None] * 5
grad_input_quantizer = None grad_input_quantizer = None
...@@ -1399,6 +1407,8 @@ class LayerNormLinear(TransformerEngineBaseModule): ...@@ -1399,6 +1407,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
if torch.is_grad_enabled(): if torch.is_grad_enabled():
grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
grad_output_quantizer.internal = True grad_output_quantizer.internal = True
if fp8_grad:
grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
return ( return (
input_quantizer, input_quantizer,
......
...@@ -1436,6 +1436,11 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1436,6 +1436,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
if skip_fp8_weight_update is not None: if skip_fp8_weight_update is not None:
is_first_microbatch = False is_first_microbatch = False
fp8_output = False
if self.ub_overlap_rs:
if get_ub("fc2_fprop").is_fp8_ubuf():
fp8_output = True
with self.prepare_forward(inp, num_gemms=2) as inp: with self.prepare_forward(inp, num_gemms=2) as inp:
# Get quantizers # Get quantizers
( (
...@@ -1447,7 +1452,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1447,7 +1452,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
grad_fc1_output_quantizer, grad_fc1_output_quantizer,
grad_fc2_output_quantizer, grad_fc2_output_quantizer,
grad_input_quantizer, grad_input_quantizer,
) = self._get_quantizers() ) = self._get_quantizers(fp8_output)
# Get weight tensors # Get weight tensors
fc1_weight = self.fc1_weight fc1_weight = self.fc1_weight
...@@ -1533,7 +1538,7 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1533,7 +1538,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
return out, ln_out return out, ln_out
return out return out
def _get_quantizers(self): def _get_quantizers(self, fp8_output):
( (
fc1_input_quantizer, fc1_input_quantizer,
fc1_weight_quantizer, fc1_weight_quantizer,
...@@ -1555,6 +1560,8 @@ class LayerNormMLP(TransformerEngineBaseModule): ...@@ -1555,6 +1560,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
) )
fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT] fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
fc2_weight_quantizer.internal = True fc2_weight_quantizer.internal = True
if fp8_output:
output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT]
if torch.is_grad_enabled(): if torch.is_grad_enabled():
grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][ grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][
tex.FP8BwdTensors.GRAD_OUTPUT1 tex.FP8BwdTensors.GRAD_OUTPUT1
......
...@@ -1104,6 +1104,13 @@ class Linear(TransformerEngineBaseModule): ...@@ -1104,6 +1104,13 @@ class Linear(TransformerEngineBaseModule):
if skip_fp8_weight_update is not None: if skip_fp8_weight_update is not None:
is_first_microbatch = False is_first_microbatch = False
if self.ub_overlap_rs_fprop:
if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
fp8_output = True
if self.ub_overlap_rs_dgrad:
if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
fp8_grad = True
with self.prepare_forward( with self.prepare_forward(
inp, inp,
allow_non_contiguous=isinstance(inp, QuantizedTensor), allow_non_contiguous=isinstance(inp, QuantizedTensor),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment