Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
ff84910c
Commit
ff84910c
authored
Dec 04, 2025
by
zhuyue
Browse files
Issue/714 - feat(random_sample): add batch processing interface.
parent
82c3e836
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
159 additions
and
0 deletions
+159
-0
include/infiniop/ops/random_sample.h
include/infiniop/ops/random_sample.h
+19
-0
test/infiniop/random_sample.py
test/infiniop/random_sample.py
+140
-0
No files found.
include/infiniop/ops/random_sample.h
View file @
ff84910c
...
...
@@ -15,6 +15,12 @@ __C __export infiniStatus_t infiniopGetRandomSampleWorkspaceSize(
infiniopRandomSampleDescriptor_t
desc
,
size_t
*
size
);
/**
 * Create a RandomSample descriptor through the batch entry point.
 *
 * The descriptor is written to *desc_ptr and has the same
 * infiniopRandomSampleDescriptor_t type that the other RandomSample
 * functions in this header take.
 *
 * `result` and `probs` are tensor descriptors for the output and input.
 * The accompanying Python test passes a 1-D I32 tensor of length
 * batch_size for `result` and a stacked (batch_size, voc) tensor for
 * `probs`.
 * NOTE(review): the accepted shapes/dtypes are not stated in this header;
 * confirm against the implementation.
 */
__C __export infiniStatus_t infiniopCreateRandomSampleBatchDescriptor(
    infiniopHandle_t handle,
    infiniopRandomSampleDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t result,
    infiniopTensorDescriptor_t probs);
__C
__export
infiniStatus_t
infiniopRandomSample
(
infiniopRandomSampleDescriptor_t
desc
,
void
*
workspace
,
...
...
@@ -27,6 +33,19 @@ __C __export infiniStatus_t infiniopRandomSample(
float
temperature
,
void
*
stream
);
/**
 * Batched variant of infiniopRandomSample.
 *
 * Unlike the single-sample call, the sampling parameters (random_val,
 * topp, topk, temperature) are passed as pointers rather than scalars.
 * The accompanying Python test passes one host array of length batch_size
 * for each of them, i.e. one value per batch row.
 * NOTE(review): whether these pointers must be host or device memory is
 * not stated here; confirm against the implementation.
 *
 * `workspace`/`workspace_size` are the scratch buffer and its size; the
 * test obtains the size from infiniopGetRandomSampleWorkspaceSize.
 * `result` and `probs` are the raw data pointers of the output and input
 * tensors. `stream` is an opaque pointer; the test passes NULL.
 */
__C __export infiniStatus_t infiniopRandomSampleBatch(
    infiniopRandomSampleDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *result,
    const void *probs,
    const float *random_val,
    const float *topp,
    const int *topk,
    const float *temperature,
    int batch_size,
    void *stream);
/**
 * Destroy a RandomSample descriptor.
 *
 * The accompanying Python test also calls this on descriptors obtained
 * from infiniopCreateRandomSampleBatchDescriptor, so a single destroy
 * function serves both creation paths.
 */
__C __export infiniStatus_t infiniopDestroyRandomSampleDescriptor(
    infiniopRandomSampleDescriptor_t desc);
...
...
test/infiniop/random_sample.py
View file @
ff84910c
...
...
@@ -36,6 +36,14 @@ _TEST_CASES = [
# (119696, 0.01, 1.0, 100, 1.0),
]
# Batch test cases: (batch_size, voc, list of (random_val, topp, topk, temperature))
_BATCH_TEST_CASES
=
[
# batch_size, voc, [(random_val, topp, topk, temperature), ...]
(
4
,
512
,
[(
0.8
,
0.8
,
3
,
0.5
),
(
0.05
,
0.9
,
5
,
1.0
),
(
0.15
,
0.85
,
10
,
2.0
),
(
0.08
,
0
,
3
,
0.5
)]),
(
8
,
4096
,
[(
0.5
,
0.9
,
1
,
1.0
),
(
0.15
,
0
,
1
,
2.0
),
(
0.08
,
0.8
,
50
,
1.0
),
(
0.08
,
1.0
,
25
,
1.0
),
(
0.8
,
0.8
,
3
,
0.5
),
(
0.05
,
0.9
,
5
,
1.0
),
(
0.15
,
0.85
,
10
,
2.0
),
(
0.08
,
0
,
3
,
0.5
)]),
(
2
,
16384
,
[(
0.15
,
0.85
,
10
,
2.0
),
(
0.5
,
0.9
,
1
,
1.0
)]),
]
# Element types every test case is run with.
_TENSOR_DTYPES = [
    InfiniDtype.F16,
    InfiniDtype.BF16,
]
...
...
@@ -183,6 +191,131 @@ def test(
check_error
(
LIBINFINIOP
.
infiniopDestroyRandomSampleDescriptor
(
descriptor
))
def test_batch(
    handle,
    device,
    batch_size,
    voc,
    params_list,
    dtype=InfiniDtype.F16,
    sync=None,
):
    """Run one batched RandomSample case and compare it with the reference.

    Builds ``batch_size`` rows of shuffled logits, computes the expected
    index for each row with the ``random_sample`` reference, then calls
    ``infiniopRandomSampleBatch`` once for the whole batch and checks each
    returned index.

    Args:
        handle: Library handle forwarded to descriptor creation.
        device: Device id used for tensors, workspace and profiling.
        batch_size: Number of rows sampled in one call.
        voc: Number of logits per row.
        params_list: ``batch_size`` tuples of
            ``(random_val, topp, topk, temperature)``, one per row.
        dtype: Element type of the logits tensor.
        sync: Optional callable invoked to synchronize before descriptor
            creation and after the library call.

    Raises:
        AssertionError: If ``params_list`` does not have ``batch_size``
            entries, or a returned index disagrees with the reference.

    If batch descriptor creation fails, a note is printed and the case is
    skipped (the batch implementation is still pending).
    """
    print(
        f"Testing RandomSampleBatch on {InfiniDeviceNames[device]} with batch_size:{batch_size} voc:{voc} dtype:{InfiniDtypeNames[dtype]}"
    )

    assert len(params_list) == batch_size

    # One independently shuffled ramp of distinct values per row.
    logits_batch = torch.stack(
        [
            torch.arange(voc)[torch.randperm(voc)].float() * 0.0001
            for _ in range(batch_size)
        ]
    )
    logits = TestTensor.from_torch(logits_batch, dtype, device)

    # Reference answers, computed row by row with each row's own parameters.
    ans_batch = torch.stack(
        [
            random_sample(
                logits.torch_tensor()[i], random_val, topp, topk, voc, temperature
            ).to(torch.int32)
            for i, (random_val, topp, topk, temperature) in enumerate(params_list)
        ]
    )

    indices = TestTensor([batch_size], None, InfiniDtype.I32, device, mode="zeros")

    if sync is not None:
        sync()

    descriptor = infiniopOperatorDescriptor_t()
    try:
        check_error(
            LIBINFINIOP.infiniopCreateRandomSampleBatchDescriptor(
                handle,
                ctypes.byref(descriptor),
                indices.descriptor,
                logits.descriptor,
            )
        )
    except Exception as e:
        # Deliberate best-effort: backends without the batch interface are
        # reported and skipped rather than failing the whole run.
        print(f"\033[93mNote: Batch descriptor creation not implemented yet: {e}\033[0m")
        print(" This is expected - batch interface implementation is pending")
        return

    # From here on the descriptor exists, so it must be destroyed even when
    # the library call or a check below raises.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them
        # from being directly used by the kernel
        for tensor in [logits, indices]:
            tensor.destroy_desc()

        workspace_size = c_uint64(0)
        check_error(
            LIBINFINIOP.infiniopGetRandomSampleWorkspaceSize(
                descriptor, ctypes.byref(workspace_size)
            )
        )
        workspace = TestWorkspace(workspace_size.value, device)

        # Per-row sampling parameters, transposed into one C array each.
        random_vals, topps, topks, temperatures = zip(*params_list)
        random_val_array = (ctypes.c_float * batch_size)(*random_vals)
        topp_array = (ctypes.c_float * batch_size)(*topps)
        topk_array = (ctypes.c_int * batch_size)(*topks)
        temperature_array = (ctypes.c_float * batch_size)(*temperatures)

        def lib_random_sample_batch():
            check_error(
                LIBINFINIOP.infiniopRandomSampleBatch(
                    descriptor,
                    workspace.data(),
                    workspace_size.value,
                    indices.data(),
                    logits.data(),
                    random_val_array,
                    topp_array,
                    topk_array,
                    temperature_array,
                    batch_size,
                    None,
                )
            )

        lib_random_sample_batch()

        if sync is not None:
            sync()

        atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
        if DEBUG:
            rows = torch.arange(batch_size)
            debug_all(
                (
                    indices.actual_tensor(),
                    logits.actual_tensor()[rows, indices.actual_tensor()],
                ),
                (ans_batch, logits.torch_tensor()[rows, ans_batch]),
                "or",
                atol=atol,
                rtol=rtol,
            )

        actual_indices = indices.actual_tensor()
        actual_logits = logits.actual_tensor()
        expected_logits = logits.torch_tensor()
        for i in range(batch_size):
            # A row passes if the index matches, or if a different index was
            # returned whose logit equals the reference's logit.
            assert (
                actual_indices[i] == ans_batch[i]
                or actual_logits[i, actual_indices[i]]
                == expected_logits[i, ans_batch[i]]
            ), f"row {i}: got index {actual_indices[i]}, expected {ans_batch[i]}"

        # Profiling workflow
        if PROFILE:

            def pytorch_batch():
                return torch.stack(
                    [
                        random_sample(
                            logits.torch_tensor()[i],
                            random_val,
                            topp,
                            topk,
                            voc,
                            temperature,
                        )
                        for i, (random_val, topp, topk, temperature) in enumerate(
                            params_list
                        )
                    ]
                )

            # fmt: off
            profile_operation("PyTorch", pytorch_batch, device, NUM_PRERUN, NUM_ITERATIONS)
            profile_operation(" lib", lib_random_sample_batch, device, NUM_PRERUN, NUM_ITERATIONS)
            # fmt: on
    finally:
        check_error(LIBINFINIOP.infiniopDestroyRandomSampleDescriptor(descriptor))
if
__name__
==
"__main__"
:
args
=
get_args
()
...
...
@@ -195,4 +328,11 @@ if __name__ == "__main__":
for
device
in
get_test_devices
(
args
):
test_operator
(
device
,
test
,
_TEST_CASES
,
_TENSOR_DTYPES
)
print
(
f
"
\n\033
[93mRunning batch tests on
{
InfiniDeviceNames
[
device
]
}
...
\033
[0m"
)
try
:
test_operator
(
device
,
test_batch
,
_BATCH_TEST_CASES
,
_TENSOR_DTYPES
)
except
Exception
as
e
:
print
(
f
"
\033
[91mBatch test failed (not implemented yet):
{
e
}
\033
[0m"
)
print
(
f
" This is expected - batch interface implementation is pending"
)
print
(
"
\033
[92mTest passed!
\033
[0m"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment