Commit ee787b22 authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.4'

parents b32741e2 bc2d9697
......@@ -41,7 +41,7 @@ from transformer_engine.pytorch import (
Fp8Unpadding,
)
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm, batchgemm
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
......
......@@ -10,7 +10,6 @@ import transformer_engine_torch as tex
from transformer_engine.pytorch.optimizers import MultiTensorApply
from references.quantize_scale_calc import scale_from_amax_tensor
from torch.utils.cpp_extension import IS_HIP_EXTENSION
input_size_pairs = [
......@@ -258,10 +257,4 @@ def test_multi_tensor_compute_scale_and_scale_inv(
scale_ref, scale_inv_ref, _ = scale_from_amax_tensor(
torch.float32, amax, fp8_dtype, eps=epsilon, pow_2_scales=pow_2_scales
)
if(IS_HIP_EXTENSION):
torch.testing.assert_close(scale, scale_ref, rtol=1e-7, atol=0)
torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=1.3e-7, atol=0)
else:
torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment