Unverified Commit 78b4e933 authored by Kirthi Shankar Sivamani, committed by GitHub

Bug fixes from PR 22 (#65)



* Bug fixes from PR 22
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FP8 tests to CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Bundle unit tests for CI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 14198f20
@@ -7,5 +7,4 @@ set -e
 : ${TE_PATH:=/opt/transformerengine}
 pip install pytest==6.2.5 onnxruntime==1.13.1
-pytest -v -s $TE_PATH/tests/test_transformerengine.py
-pytest -v -s $TE_PATH/tests/test_onnx_export.py
+pytest -v -s $TE_PATH/tests/*.py
@@ -18,6 +18,10 @@ from transformer_engine.pytorch import (
 )
 from transformer_engine.common import recipe

+# Only run FP8 tests on H100.
+if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
+    pytest.skip(allow_module_level=True)
+
 def custom_amax_to_scale(
     amax: torch.Tensor,
...
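Note on the guard above: torch.cuda.get_device_properties(...).major reports the device's compute-capability major version (9 corresponds to Hopper/H100), and allow_module_level=True is what allows pytest.skip() to be called at import time, skipping every test in the module. A minimal standalone sketch of the same pattern (the reason strings are illustrative, not from this commit):

import pytest
import torch

# Skip the whole module unless an FP8-capable (sm_90+) GPU is present.
# Without allow_module_level=True, pytest.skip() outside a test raises an error.
if not torch.cuda.is_available():
    pytest.skip("CUDA device required", allow_module_level=True)
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
    pytest.skip("FP8 needs compute capability >= 9.0", allow_module_level=True)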
@@ -314,7 +314,7 @@ def _default_get_amax(
     if amax_compute_algo == "max":
         amax = torch.max(amax_history, dim=0).values
     else:  # amax_compute_algo == "most_recent"
-        amax = amax_history[0]
+        amax = amax_history[0].clone()

     amax_history = update_amax_history(amax_history)
     return amax_history, amax
...
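Why the .clone() matters: with the "most_recent" algorithm, amax_history[0] is a view into the rolling history buffer, so if update_amax_history later modifies that buffer in place, the returned amax would silently change. Cloning gives amax its own storage. A small sketch of the aliasing hazard, using a hypothetical in-place stand-in for update_amax_history (the real implementation may differ):

import torch

def update_amax_history(hist: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: roll the window in place and clear row 0
    # so it can receive the next iteration's amax.
    hist.copy_(torch.roll(hist, -1, dims=0))
    hist[0].zero_()
    return hist

hist = torch.tensor([[4.0], [2.0], [3.0]])
amax_view = hist[0]          # aliases row 0 of the buffer
amax_copy = hist[0].clone()  # owns its own storage

update_amax_history(hist)
print(amax_view)  # tensor([0.]) -- clobbered by the in-place update
print(amax_copy)  # tensor([4.]) -- still holds the value that was read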
@@ -181,13 +181,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Needed for calculation of scale inverses to
         # preserve scale_inv when caching FP8 weights
         if fwd:
-            # [True, False]: -> [input, weight]
+            # [True, False, True]: -> [input, weight, output]
             self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
-                [True, False] * self.fp8_meta["num_gemms"]
+                [True, False, True] * self.fp8_meta["num_gemms"]
             ).cuda()
         else:
+            # [True, True]: -> [grad_output, grad_input]
             self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
-                [True] * self.fp8_meta["num_gemms"]
+                [True, True] * self.fp8_meta["num_gemms"]
             ).cuda()

     def init_fp8_meta_tensors(self) -> None:
...
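For context, the mask marks which FP8 tensors in each GEMM are not weights, so that when FP8 weights are cached their scale_inv entries can be left untouched while the rest are refreshed. A minimal sketch of how such a mask could gate the update (hypothetical usage on CPU with illustrative values; the .cuda() call from the diff is dropped here):

import torch

num_gemms = 2
# Forward FP8 tensors per GEMM: [input, weight, output].
# False marks the weight slots, whose cached scale_inv must survive.
non_weight_mask = torch.BoolTensor([True, False, True] * num_gemms)

scale = torch.full((3 * num_gemms,), 2.0)
scale_inv = torch.full((3 * num_gemms,), 7.0)  # pretend: cached values

# Refresh scale_inv only where the tensor is not a cached weight.
scale_inv = torch.where(non_weight_mask, 1.0 / scale, scale_inv)
print(scale_inv)  # weight slots keep 7.0, all other slots become 0.5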