Commit c2e368c6 authored by zhangyue's avatar zhangyue
Browse files

issue/207: revert rms_norm.py test cases

parent 81e426c4
...@@ -27,17 +27,17 @@ _TEST_CASES = [ ...@@ -27,17 +27,17 @@ _TEST_CASES = [
# y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype # y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
((1, 4), (1, 4), (4,), None, None, torch.float32), ((1, 4), (1, 4), (4,), None, None, torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float32), ((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
# ((16, 2048), (16, 2048), (2048,), None, None, torch.float16), ((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32), ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
# ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16), ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
] ]
# x types used for testing # x types used for testing
_TENSOR_DTYPES = [torch.float32] _TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types # Tolerance map for different data types
_TOLERANCE_MAP = { _TOLERANCE_MAP = {
torch.float32: {"atol": 1e-3, "rtol": 1e-3}, torch.float16: {"atol": 1e-3, "rtol": 1e-3},
} }
DEBUG = False DEBUG = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment