Commit c2e368c6 authored by zhangyue's avatar zhangyue
Browse files

issue/207: revert rms_norm.py test cases

parent 81e426c4
......@@ -27,17 +27,17 @@ _TEST_CASES = [
# y_shape, x_shape, w_shape, y_stride, x_stride, w_dtype
((1, 4), (1, 4), (4,), None, None, torch.float32),
((16, 2048), (16, 2048), (2048,), None, None, torch.float32),
# ((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
((16, 2048), (16, 2048), (2048,), None, None, torch.float16),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float32),
# ((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
((16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), torch.float16),
]
# x types used for testing
_TENSOR_DTYPES = [torch.float32]
_TENSOR_DTYPES = [torch.float16]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float32: {"atol": 1e-3, "rtol": 1e-3},
torch.float16: {"atol": 1e-3, "rtol": 1e-3},
}
DEBUG = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment