Commit c0811ed4 authored by xgqdut2016's avatar xgqdut2016
Browse files

issue/66: modified random_sample, swiglu, rms_norm, test

parent 08a29c28
...@@ -188,13 +188,9 @@ def test( ...@@ -188,13 +188,9 @@ def test(
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
if topp > 0 and topk > 1: profile_operation("PyTorch", lambda: random_sample(
profile_operation("PyTorch", lambda: random_sample( data, random_val, topp, topk, voc, temperature, torch_device
data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu"
), torch_device, NUM_PRERUN, NUM_ITERATIONS) ), torch_device, NUM_PRERUN, NUM_ITERATIONS)
else:
profile_operation("PyTorch", lambda: random_sample_0(data), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_random_sample(), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_random_sample(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on # fmt: on
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor)) check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))
......
...@@ -133,6 +133,7 @@ def test( ...@@ -133,6 +133,7 @@ def test(
if DEBUG: if DEBUG:
debug(y, ans, atol=atol, rtol=rtol) debug(y, ans, atol=atol, rtol=rtol)
assert torch.allclose(y, ans, atol=atol, rtol=rtol) assert torch.allclose(y, ans, atol=atol, rtol=rtol)
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
......
...@@ -22,50 +22,29 @@ from enum import Enum, auto ...@@ -22,50 +22,29 @@ from enum import Enum, auto
# Configuration (Internal Use Only) # Configuration (Internal Use Only)
# ============================================================================== # ==============================================================================
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES_ = [
((13, 4), None, None, None),
((13, 4), (10, 1), (10, 1), (10, 1)),
((13, 4, 4), None, None, None),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1)),
((16, 5632), None, None, None),
((16, 5632), (13312, 1), (13312, 1), (13312, 1)),
((4, 4, 5632), None, None, None),
((4, 4, 5632), (45056, 5632, 1), (45056, 5632, 1), (45056, 5632, 1)),
]
# Inplace options applied for each test case in _TEST_CASES_
_INPLACE = [
"Inplace.OUT_OF_PLACE",
"Inplace.INPLACE_A",
"Inplace.INPLACE_B",
]
# Form the test cases by appending each element of _INPLACE to each tuple in _TEST_CASES_
_TEST_CASES = [ _TEST_CASES = [
# shape, a_stride, b_stride, c_stride, inplace test_case + (inplace_item,)
((13, 4), None, None, None, Inplace.OUT_OF_PLACE), for test_case in _TEST_CASES_
((13, 4), None, None, None, Inplace.INPLACE_A), for inplace_item in _INPLACE
((13, 4), None, None, None, Inplace.INPLACE_B),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.OUT_OF_PLACE),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_A),
((13, 4), (10, 1), (10, 1), (10, 1), Inplace.INPLACE_B),
((13, 4, 4), None, None, None, Inplace.OUT_OF_PLACE),
((13, 4, 4), None, None, None, Inplace.INPLACE_A),
((13, 4, 4), None, None, None, Inplace.INPLACE_B),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.OUT_OF_PLACE),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_A),
((13, 4, 4), (20, 4, 1), (20, 4, 1), (20, 4, 1), Inplace.INPLACE_B),
((16, 5632), None, None, None, Inplace.OUT_OF_PLACE),
((16, 5632), None, None, None, Inplace.INPLACE_A),
((16, 5632), None, None, None, Inplace.INPLACE_B),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.OUT_OF_PLACE),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_A),
((16, 5632), (13312, 1), (13312, 1), (13312, 1), Inplace.INPLACE_B),
((4, 4, 5632), None, None, None, Inplace.OUT_OF_PLACE),
((4, 4, 5632), None, None, None, Inplace.INPLACE_A),
((4, 4, 5632), None, None, None, Inplace.INPLACE_B),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.OUT_OF_PLACE,
),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.INPLACE_A,
),
(
(4, 4, 5632),
(45056, 5632, 1),
(45056, 5632, 1),
(45056, 5632, 1),
Inplace.INPLACE_B,
),
] ]
# Data types used for testing # Data types used for testing
...@@ -166,7 +145,6 @@ def test( ...@@ -166,7 +145,6 @@ def test(
if DEBUG: if DEBUG:
debug(c, ans, atol=atol, rtol=rtol) debug(c, ans, atol=atol, rtol=rtol)
assert torch.allclose(c, ans, atol=atol, rtol=rtol) assert torch.allclose(c, ans, atol=atol, rtol=rtol)
print("out-of-place Test passed!")
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment