"git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "df43ce9675a7f68b24a8d0a3a2b80cce0b57de17"
Unverified Commit 509bf877 authored by Kirthi Shankar Sivamani, committed by GitHub

Ensure contiguous inputs (#38)



ensure contiguous inputs
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 89f94ba2
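For context (not part of the commit itself): PyTorch operations such as transpose(), permute(), and slicing return strided views whose memory is no longer laid out contiguously, while custom CUDA kernels typically assume a dense, row-major buffer. Calling .contiguous() materializes such a view into a dense copy and is a no-op for tensors that are already contiguous. A minimal plain-PyTorch sketch of the behavior the patch relies on:

```python
import torch

x = torch.randn(4, 8)
assert x.is_contiguous()

# A transpose is a view with swapped strides; the underlying memory
# is unchanged, so the result is no longer contiguous.
xt = x.t()
assert not xt.is_contiguous()

# .contiguous() copies the data into a dense, row-major buffer.
xt_dense = xt.contiguous()
assert xt_dense.is_contiguous()

# For an already-contiguous tensor, .contiguous() returns the same
# tensor; no copy is made.
assert x.contiguous() is x
```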
@@ -329,7 +329,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Activation recomputation is used and this is the second forward phase.
         if self.fp8 and in_fp8_activation_recompute_phase():
             get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
-            return
+            return inp.contiguous()
 
         assert inp.is_cuda, "TransformerEngine needs CUDA."
@@ -371,6 +371,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         ):
             copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)
 
+        return inp.contiguous()
+
     def post_forward(self) -> None:
         """This is needed because there isn't a way for a module to know
         if it's the last FP8 module in the forward autocast. It is useful
@@ -1089,7 +1091,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                       produced)
         """
-        self.pre_forward(inp)
+        inp = self.pre_forward(inp)
 
         bias_tensor = bias if bias is not None else self.bias
@@ -1615,7 +1617,7 @@ class Linear(TransformerEngineBaseModule):
                       produced)
         """
-        self.pre_forward(inp)
+        inp = self.pre_forward(inp)
 
         bias_tensor = bias if bias is not None else self.bias
@@ -2418,7 +2420,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                       produced)
         """
-        self.pre_forward(inp, num_gemms=2)
+        inp = self.pre_forward(inp, num_gemms=2)
 
         out = _LayerNormMLP.apply(
             inp,
...
@@ -1004,6 +1004,8 @@ class TransformerLayer(torch.nn.Module):
             backprop.
         """
+        hidden_states = hidden_states.contiguous()
+
         # For AMP
         if torch.is_autocast_enabled():
             hidden_states = cast_if_needed(
...
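A minimal sketch (hypothetical names, not the actual TransformerEngine code) of the call pattern these hunks introduce: pre_forward now returns the contiguous input, so every caller must rebind it, otherwise the contiguous copy would be discarded and downstream kernels would still receive the strided tensor.

```python
import torch

class ExampleModule:
    """Illustrative stand-in for a TransformerEngineBaseModule subclass."""

    def pre_forward(self, inp: torch.Tensor) -> torch.Tensor:
        # ... existing checks elided (CUDA assert, FP8 meta handling) ...
        # New behavior: always hand back a contiguous tensor.
        return inp.contiguous()

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        # Rebind to the returned value; the rest of forward() uses it.
        inp = self.pre_forward(inp)
        return inp  # ... GEMMs etc. would follow here ...

# A transposed input is non-contiguous; the module now normalizes it.
m = ExampleModule()
out = m.forward(torch.randn(4, 8).t())
assert out.is_contiguous()
```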