Commit c0811ed4 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/66: modified random_sample, swiglu, rms_norm, test

parent 08a29c28
......@@ -188,13 +188,9 @@ def test(
# Profiling workflow
if PROFILE:
# fmt: off
if topp > 0 and topk > 1:
profile_operation("PyTorch", lambda: random_sample(
data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
data, random_val, topp, topk, voc, temperature, torch_device
), torch_device, NUM_PRERUN, NUM_ITERATIONS)
else:
profile_operation("PyTorch", lambda: random_sample_0(data), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_random_sample(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
......
......@@ -133,6 +133,7 @@ def test(
if DEBUG:
debug(y, ans, atol=atol, rtol=rtol)
assert torch.allclose(y, ans, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
# fmt: off
......
......@@ -22,50 +22,29 @@ from enum import Enum, auto
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), (10, 1)),
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
((4, 4, 5632), None, None, None),
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
]
# Inplace options applied for each test case in _TEST_CASES_
_INPLACE = [
"Inplace.OUT_OF_PLACE",
"Inplace.INPLACE_A",
"Inplace.INPLACE_B",
]
# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_
_TEST_CASES = [
# shape, a_stride, b_stride, c_stride, inplace
((13, 4), None, None, None, Inplace.OUT_OF_PLACE),
((13, 4), None, None, None, Inplace.INPLACE_A),
((13, 4), None, None, None, Inplace.INPLACE_B),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.OUT_OF_PLACE),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_A),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_B),
((13, 4, 4), None, None, None, Inplace.OUT_OF_PLACE),
((13, 4, 4), None, None, None, Inplace.INPLACE_A),
((13, 4, 4), None, None, None, Inplace.INPLACE_B),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.OUT_OF_PLACE),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_A),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_B),
((16, 5632), None, None, None, Inplace.OUT_OF_PLACE),
((16, 5632), None, None, None, Inplace.INPLACE_A),
((16, 5632), None, None, None, Inplace.INPLACE_B),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.OUT_OF_PLACE),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_A),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_B),
((4, 4, 5632), None, None, None, Inplace.OUT_OF_PLACE),
((4, 4, 5632), None, None, None, Inplace.INPLACE_A),
((4, 4, 5632), None, None, None, Inplace.INPLACE_B),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.OUT_OF_PLACE,
),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.INPLACE_A,
),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.INPLACE_B,
),
test_case + (inplace_item,)
for test_case in _TEST_CASES_
for inplace_item in _INPLACE
]
# Data types used for testing
......@@ -166,7 +145,6 @@ def test(
if DEBUG:
debug(c, ans, atol=atol, rtol=rtol)
assert torch.allclose(c, ans, atol=atol, rtol=rtol)
print("out-of-place Test passed!")
# Profiling workflow
if PROFILE:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment