"...AutoBuildImmortalWrt.git" did not exist on "15cb06f860328b2ff3f15afb826af3ba99cfa116"
Unverified commit 78b4e933, authored by Kirthi Shankar Sivamani and committed via GitHub
Browse files

Bug fixes from PR 22 (#65)



* Bug fixes from PR 22
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add FP8 tests to ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* bundle unittests for ci
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 14198f20
......@@ -7,5 +7,4 @@ set -e
: ${TE_PATH:=/opt/transformerengine}
pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/test_transformerengine.py
pytest -v -s $TE_PATH/tests/test_onnx_export.py
pytest -v -s $TE_PATH/tests/*.py
......@@ -18,6 +18,10 @@ from transformer_engine.pytorch import (
)
from transformer_engine.common import recipe
# Only run FP8 tests on H100.
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9:
pytest.skip(allow_module_level=True)
def custom_amax_to_scale(
amax: torch.Tensor,
......
......@@ -314,7 +314,7 @@ def _default_get_amax(
if amax_compute_algo == "max":
amax = torch.max(amax_history, dim=0).values
else: # amax_compute_algo == "most_recent"
amax = amax_history[0]
amax = amax_history[0].clone()
amax_history = update_amax_history(amax_history)
return amax_history, amax
......
......@@ -181,13 +181,14 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Needed for calculation of scale inverses to
# preserve scale_inv when caching FP8 weights
if fwd:
# [True, False]: -> [input, weight]
# [True, False, True]: -> [input, weight, output]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True, False] * self.fp8_meta["num_gemms"]
[True, False, True] * self.fp8_meta["num_gemms"]
).cuda()
else:
# [True, True]: -> [grad_output, grad_input]
self.fp8_meta[fp8_meta_tensor_key + "_non_weight_mask"] = torch.BoolTensor(
[True] * self.fp8_meta["num_gemms"]
[True, True] * self.fp8_meta["num_gemms"]
).cuda()
def init_fp8_meta_tensors(self) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment