Commit ff84910c authored by zhuyue
Browse files

Issue/714 - feat(random_sample): add batch processing interface.

parent 82c3e836
...@@ -15,6 +15,12 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize( ...@@ -15,6 +15,12 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
infiniopRandomSampleDescriptor_t desc, infiniopRandomSampleDescriptor_t desc,
size_t *size); size_t *size);
/**
 * Create a descriptor for the batched random-sample operator.
 *
 * The descriptor is returned through `desc_ptr` and has the same type as the
 * non-batched one (infiniopRandomSampleDescriptor_t). The accompanying Python
 * test queries it with infiniopGetRandomSampleWorkspaceSize and releases it
 * with infiniopDestroyRandomSampleDescriptor.
 *
 * @param handle    Library handle the descriptor is created for.
 * @param desc_ptr  Out-parameter that receives the new descriptor.
 * @param result    Tensor descriptor of the sampled indices. The Python test
 *                  passes a 1-D I32 tensor of length batch_size.
 *                  NOTE(review): confirm which other dtypes/shapes are valid.
 * @param probs     Tensor descriptor of the input scores. The Python test
 *                  passes a 2-D [batch_size, voc] tensor in F16 or BF16.
 *                  NOTE(review): confirm stride requirements.
 * @return          An infiniStatus_t status code.
 */
__C __export infiniStatus_t infiniopCreateRandomSampleBatchDescriptor(
infiniopHandle_t handle,
infiniopRandomSampleDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t result,
infiniopTensorDescriptor_t probs);
__C __export infiniStatus_t infiniopRandomSample( __C __export infiniStatus_t infiniopRandomSample(
infiniopRandomSampleDescriptor_t desc, infiniopRandomSampleDescriptor_t desc,
void *workspace, void *workspace,
...@@ -27,6 +33,19 @@ __C __export infiniStatus_t infiniopRandomSample( ...@@ -27,6 +33,19 @@ __C __export infiniStatus_t infiniopRandomSample(
float temperature, float temperature,
void *stream); void *stream);
/**
 * Run random sampling for a whole batch in a single call.
 *
 * Unlike infiniopRandomSample, the sampling parameters are passed as arrays
 * rather than scalars. The Python test passes arrays holding exactly
 * `batch_size` elements each, i.e. presumably one entry per batch row.
 * TODO confirm the indexing contract against the implementation.
 *
 * @param desc            Descriptor to run; the Python test passes one created
 *                        by infiniopCreateRandomSampleBatchDescriptor.
 * @param workspace       Scratch buffer.
 * @param workspace_size  Size of `workspace`; the Python test passes the value
 *                        reported by infiniopGetRandomSampleWorkspaceSize.
 * @param result          Output buffer for the sampled indices.
 * @param probs           Input buffer with the scores to sample from.
 * @param random_val      Array of random values.
 * @param topp            Array of top-p values.
 * @param topk            Array of top-k values.
 * @param temperature     Array of temperatures.
 * @param batch_size      Batch size.
 * @param stream          Execution stream; the Python test passes NULL.
 * @return                An infiniStatus_t status code.
 *
 * NOTE(review): the Python test passes the four parameter arrays as host
 * (ctypes) memory. Confirm whether device pointers are also accepted.
 */
__C __export infiniStatus_t infiniopRandomSampleBatch(
infiniopRandomSampleDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *result,
const void *probs,
const float *random_val,
const float *topp,
const int *topk,
const float *temperature,
int batch_size,
void *stream);
__C __export infiniStatus_t infiniopDestroyRandomSampleDescriptor( __C __export infiniStatus_t infiniopDestroyRandomSampleDescriptor(
infiniopRandomSampleDescriptor_t desc); infiniopRandomSampleDescriptor_t desc);
......
...@@ -36,6 +36,14 @@ _TEST_CASES = [ ...@@ -36,6 +36,14 @@ _TEST_CASES = [
# (119696, 0.01, 1.0, 100, 1.0), # (119696, 0.01, 1.0, 100, 1.0),
] ]
# Batched test cases. Every entry is (batch_size, voc, params), where params
# holds one (random_val, topp, topk, temperature) tuple per row of the batch,
# so len(params) must equal batch_size.
_BATCH_TEST_CASES = [
    (
        4,
        512,
        [
            (0.8, 0.8, 3, 0.5),
            (0.05, 0.9, 5, 1.0),
            (0.15, 0.85, 10, 2.0),
            (0.08, 0, 3, 0.5),
        ],
    ),
    (
        8,
        4096,
        [
            (0.5, 0.9, 1, 1.0),
            (0.15, 0, 1, 2.0),
            (0.08, 0.8, 50, 1.0),
            (0.08, 1.0, 25, 1.0),
            (0.8, 0.8, 3, 0.5),
            (0.05, 0.9, 5, 1.0),
            (0.15, 0.85, 10, 2.0),
            (0.08, 0, 3, 0.5),
        ],
    ),
    (
        2,
        16384,
        [
            (0.15, 0.85, 10, 2.0),
            (0.5, 0.9, 1, 1.0),
        ],
    ),
]
# Data types used for testing # Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16] _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16]
...@@ -183,6 +191,131 @@ def test( ...@@ -183,6 +191,131 @@ def test(
check_error(LIBINFINIOP.infiniopDestroyRandomSampleDescriptor(descriptor)) check_error(LIBINFINIOP.infiniopDestroyRandomSampleDescriptor(descriptor))
def test_batch(
    handle,
    device,
    batch_size,
    voc,
    params_list,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Check infiniopRandomSampleBatch against a per-row torch reference.

    Builds a [batch_size, voc] logits tensor, computes the expected index for
    every row with ``random_sample``, runs the library's batched kernel once
    and compares the results row by row.

    Args:
        handle: Library handle passed to the descriptor constructor.
        device: Device id; used for tensor/workspace creation and logging.
        batch_size: Number of rows sampled in one call.
        voc: Number of entries per row.
        params_list: One ``(random_val, topp, topk, temperature)`` tuple per
            row; its length must equal ``batch_size``.
        dtype: Element type of the logits tensor.
        sync: Optional callable invoked before descriptor creation and after
            the kernel launch.

    Raises:
        AssertionError: If ``params_list`` has the wrong length or a row's
            result matches the reference neither by index nor by value.

    If the batch descriptor cannot be created, a note is printed and the
    function returns without testing anything (the batch interface is
    treated as "not implemented yet").
    """
    print(
        f"Testing RandomSampleBatch on {InfiniDeviceNames[device]} with batch_size:{batch_size} voc:{voc} dtype:{InfiniDtypeNames[dtype]}"
    )
    assert len(params_list) == batch_size, (
        f"expected {batch_size} parameter tuples, got {len(params_list)}"
    )

    # Every row is an independently shuffled copy of arange(voc) * 0.0001.
    logits_batch = torch.stack(
        [
            torch.arange(voc)[torch.randperm(voc)].float() * 0.0001
            for _ in range(batch_size)
        ]
    )
    logits = TestTensor.from_torch(logits_batch, dtype, device)

    # Reference answer: sample every row separately with its own parameters.
    ans_batch = torch.stack(
        [
            random_sample(
                logits.torch_tensor()[i], random_val, topp, topk, voc, temperature
            ).to(torch.int32)
            for i, (random_val, topp, topk, temperature) in enumerate(params_list)
        ]
    )

    indices = TestTensor([batch_size], None, InfiniDtype.I32, device, mode="zeros")

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    try:
        check_error(
            LIBINFINIOP.infiniopCreateRandomSampleBatchDescriptor(
                handle,
                ctypes.byref(descriptor),
                indices.descriptor,
                logits.descriptor,
            )
        )
    except Exception as e:
        # Deliberate best-effort: backends without the batch interface are
        # reported and skipped instead of failing the whole run.
        print(f"\033[93mNote: Batch descriptor creation not implemented yet: {e}\033[0m")
        print(" This is expected - batch interface implementation is pending")
        return

    # The descriptor now exists; release it even if a check below fails so a
    # failing case does not leak it.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them
        # from being directly used by the kernel
        for tensor in [logits, indices]:
            tensor.destroy_desc()

        workspace_size = c_uint64(0)
        check_error(
            LIBINFINIOP.infiniopGetRandomSampleWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = TestWorkspace(workspace_size.value, device)

        # Per-row parameters are handed to the library as C arrays of length
        # batch_size.
        random_val_array = (ctypes.c_float * batch_size)(*[p[0] for p in params_list])
        topp_array = (ctypes.c_float * batch_size)(*[p[1] for p in params_list])
        topk_array = (ctypes.c_int * batch_size)(*[p[2] for p in params_list])
        temperature_array = (ctypes.c_float * batch_size)(*[p[3] for p in params_list])

        def lib_random_sample_batch():
            check_error(
                LIBINFINIOP.infiniopRandomSampleBatch(
                    descriptor,
                    workspace.data(),
                    workspace_size.value,
                    indices.data(),
                    logits.data(),
                    random_val_array,
                    topp_array,
                    topk_array,
                    temperature_array,
                    batch_size,
                    None,
                )
            )

        lib_random_sample_batch()

        if sync is not None:
            sync()

        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            rows = torch.arange(batch_size)
            debug_all(
                (indices.actual_tensor(), logits.actual_tensor()[rows, indices.actual_tensor()]),
                (ans_batch, logits.torch_tensor()[rows, ans_batch]),
                "or",
                atol=atol,
                rtol=rtol,
            )

        actual_indices = indices.actual_tensor()
        for i in range(batch_size):
            # A row passes when either the index matches or the selected
            # value matches the reference's selected value.
            assert (
                actual_indices[i] == ans_batch[i]
                or logits.actual_tensor()[i, actual_indices[i]] == logits.torch_tensor()[i, ans_batch[i]]
            ), (
                f"row {i}: got index {actual_indices[i]}, expected {ans_batch[i]} "
                f"(params={params_list[i]})"
            )

        # Profiling workflow
        if PROFILE:
            # fmt: off
            def pytorch_batch():
                return torch.stack([
                    random_sample(
                        logits.torch_tensor()[i], random_val, topp, topk, voc, temperature
                    )
                    for i, (random_val, topp, topk, temperature) in enumerate(params_list)
                ])
            profile_operation("PyTorch", pytorch_batch, device, NUM_PRERUN, NUM_ITERATIONS)
            profile_operation(" lib", lib_random_sample_batch, device, NUM_PRERUN, NUM_ITERATIONS)
            # fmt: on
    finally:
        check_error(LIBINFINIOP.infiniopDestroyRandomSampleDescriptor(descriptor))
if __name__ == "__main__": if __name__ == "__main__":
args = get_args() args = get_args()
...@@ -195,4 +328,11 @@ if __name__ == "__main__": ...@@ -195,4 +328,11 @@ if __name__ == "__main__":
for device in get_test_devices(args): for device in get_test_devices(args):
test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
print(f"\n\033[93mRunning batch tests on {InfiniDeviceNames[device]}...\033[0m")
try:
test_operator(device, test_batch, _BATCH_TEST_CASES, _TENSOR_DTYPES)
except Exception as e:
print(f"\033[91mBatch test failed (not implemented yet): {e}\033[0m")
print(f" This is expected - batch interface implementation is pending")
print("\033[92mTest passed!\033[0m") print("\033[92mTest passed!\033[0m")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment