Commit 7005dd1b authored by zhuyue's avatar zhuyue Committed by zhuyue
Browse files

Add random_sample python interface and tests.

parent 2286cf78
#pragma once
#include "../device.hpp"
#include "common/op.hpp"
#include "infinicore/tensor.hpp"
namespace infinicore::op {
class RandomSample {
public:
using schema = void (*)(Tensor, Tensor, float, float, int, float);
static void execute(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature);
static common::OpDispatcher<schema> &dispatcher();
};
// Out-of-place API
Tensor random_sample(Tensor logits, float random_val, float topp, int topk, float temperature);
// In-place API
void random_sample_(Tensor indices, Tensor logits, float random_val, float topp, int topk, float temperature);
} // namespace infinicore::op
......@@ -2,7 +2,7 @@ import infinicore
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
__all__ = ["causal_softmax", "rms_norm", "silu", "swiglu"]
__all__ = ["causal_softmax", "random_sample", "rms_norm", "silu", "swiglu"]
def causal_softmax(input: Tensor, out=None) -> Tensor:
......@@ -65,3 +65,37 @@ def swiglu(input: Tensor, other: Tensor, *, out=None):
_infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
return out
def random_sample(
logits: Tensor,
random_val: float,
topp: float,
topk: int,
temperature: float,
*,
out=None,
) -> Tensor:
r"""Sample an index from logits with nucleus/top-k filtering."""
if out is None:
return Tensor(
_infinicore.random_sample(
logits._underlying,
random_val,
topp,
topk,
temperature,
)
)
_infinicore.random_sample_(
out._underlying,
logits._underlying,
random_val,
topp,
topk,
temperature,
)
return out
#include "infinicore/ops/random_sample.hpp"
namespace infinicore::op {
common::OpDispatcher<RandomSample::schema> &RandomSample::dispatcher() {
static common::OpDispatcher<RandomSample::schema> dispatcher_;
return dispatcher_;
};
void RandomSample::execute(
Tensor indices, Tensor logits,
float random_val, float topp, int topk, float temperature) {
dispatcher().lookup(context::getDevice().getType())(
indices, logits, random_val, topp, topk, temperature);
}
Tensor random_sample(
Tensor logits,
float random_val,
float topp,
int topk,
float temperature) {
auto indices = Tensor::empty({}, DataType::I32, logits->device());
random_sample_(indices, logits, random_val, topp, topk, temperature);
return indices;
}
void random_sample_(
Tensor indices,
Tensor logits,
float random_val,
float topp,
int topk,
float temperature) {
RandomSample::execute(indices, logits, random_val, topp, topk, temperature);
}
} // namespace infinicore::op
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/random_sample.hpp"
#include <infiniop.h>
namespace infinicore::op::random_sample_impl::infiniop_backend {
thread_local common::OpCache<size_t, infiniopRandomSampleDescriptor_t> caches(
100, // capacity
[](infiniopRandomSampleDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyRandomSampleDescriptor(desc));
desc = nullptr;
}
});
static void calculate(
Tensor indices,
Tensor logits,
float random_val,
float topp,
int topk,
float temperature) {
// cache per (result desc + logits desc) on device
size_t seed = hash_combine(indices, logits);
auto device_type = context::getDevice().getType();
auto device_index = context::getDevice().getIndex();
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed);
infiniopRandomSampleDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor(
context::getInfiniopHandle(), &desc,
indices->desc(), logits->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}
size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetRandomSampleWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
INFINICORE_CHECK_ERROR(infiniopRandomSample(
desc,
workspace->data(), workspace_size,
indices->data(), logits->data(),
random_val, topp, topk, temperature,
context::getStream()));
}
} // namespace infinicore::op::random_sample_impl::infiniop_backend
namespace infinicore::op {
static bool registered = []() {
RandomSample::dispatcher().registerAll(&random_sample_impl::infiniop_backend::calculate, false);
return true;
}();
} // namespace infinicore::op
......@@ -7,6 +7,7 @@
#include "ops/causal_softmax.hpp"
#include "ops/matmul.hpp"
#include "ops/mul.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
#include "ops/rms_norm.hpp"
#include "ops/silu.hpp"
......@@ -20,6 +21,7 @@ inline void bind(py::module &m) {
bind_add(m);
bind_attention(m);
bind_causal_softmax(m);
bind_random_sample(m);
bind_matmul(m);
bind_mul(m);
bind_rearrange(m);
......
#pragma once
#include <pybind11/pybind11.h>
#include "infinicore/ops/random_sample.hpp"
namespace py = pybind11;
namespace infinicore::ops {
inline void bind_random_sample(py::module &m) {
m.def("random_sample",
&op::random_sample,
py::arg("logits"),
py::arg("random_val"),
py::arg("topp"),
py::arg("topk"),
py::arg("temperature"),
R"doc(Random sampling: returns an int32 scalar index.)doc");
m.def("random_sample_",
&op::random_sample_,
py::arg("indices"),
py::arg("logits"),
py::arg("random_val"),
py::arg("topp"),
py::arg("topk"),
py::arg("temperature"),
R"doc(In-place random sampling into provided int32 scalar tensor.)doc");
}
} // namespace infinicore::ops
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
import infinicore.nn.functional as F
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.tensor import TensorInitializer
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases: (voc, random_val, topp, topk, temperature)
# Aligned with test/infiniop/random_sample.py
_TEST_CASES_DATA = [
(512, 0.8, 0.8, 3, 0.5),
(4096, 0.05, 0.9, 5, 1.0),
(16384, 0.15, 0.85, 10, 2.0),
(512, 0.08, 0.0, 3, 0.5),
(4096, 0.5, 0.9, 1, 1.0),
(16384, 0.15, 0.0, 1, 2.0), # Duplicate as in infiniop test
(32000, 0.08, 0.8, 50, 1.0),
(32000, 0.08, 1.0, 25, 1.0),
# (119696, 0.01, 1.0, 100, 1.0), # Commented out in infiniop test
]
# Data types - note: infiniop random_sample supports F16/BF16/F32/F64 for logits
# But NVIDIA backend may have restrictions, adjust based on actual device support
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16]
_TOLERANCE_MAP = {
infinicore.float16: {"atol": 0, "rtol": 0},
infinicore.bfloat16: {"atol": 0, "rtol": 0},
}
def parse_test_cases():
"""
Parse test case data and return list of TestCase objects for all operation types.
Each test case contains all necessary information for execution and validation.
"""
test_cases = []
for data in _TEST_CASES_DATA:
voc, random_val, topp, topk, temperature = data
base_kwargs = {
"voc": voc,
"random_val": random_val,
"topp": topp,
"topk": topk,
"temperature": temperature,
}
for tensor_dtype in _TENSOR_DTYPES:
tolerance = _TOLERANCE_MAP.get(tensor_dtype, {"atol": 0, "rtol": 0})
# Out-of-place (return value)
test_cases.append(
TestCase(
inputs=[TensorSpec.from_tensor((voc,), dtype=tensor_dtype)],
kwargs=base_kwargs.copy(),
output_spec=None,
comparison_target=None,
tolerance=tolerance,
description=f"RandomSample - OUT_OF_PLACE",
)
)
# With explicit output tensor
test_cases.append(
TestCase(
inputs=[TensorSpec.from_tensor((voc,), dtype=tensor_dtype)],
kwargs=base_kwargs.copy(),
output_spec=TensorSpec.from_tensor(
(), dtype=infinicore.int32, init_mode=TensorInitializer.ZEROS
),
comparison_target="out",
tolerance=tolerance,
description=f"RandomSample - OUT",
)
)
return test_cases
def torch_random_sample(data, random_val, topp, topk, voc, temperature):
if topp > 0 and topk > 1:
sorted_vals, sorted_indices = torch.sort(data, descending=True)
scaled_vals = (sorted_vals - sorted_vals[0]) / temperature
try:
probs = torch.softmax(scaled_vals, dim=0)
except RuntimeError as e:
if "not implemented for 'Half'" in str(e):
scaled_vals = scaled_vals.to(torch.float32)
probs = torch.softmax(scaled_vals, dim=0)
else:
raise
cum_probs = torch.cumsum(probs, dim=0)
k_index = min(topk, voc) - 1
threshold = min(cum_probs[k_index], topp) * random_val
try:
idx = torch.searchsorted(cum_probs, threshold)
except Exception:
indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
idx = indices[0] if indices.numel() > 0 else torch.tensor(len(cum_probs) - 1, device=cum_probs.device)
return sorted_indices[idx]
return torch.argmax(data)
class OpTest(BaseOperatorTest):
"""RandomSample operator test with simplified implementation"""
def __init__(self):
super().__init__("RandomSample")
self._current_logits = None # Store logits for special comparison
def get_test_cases(self):
return parse_test_cases()
def prepare_inputs_and_kwargs(self, test_case, device):
"""Prepare inputs and kwargs, replacing TensorSpec objects with actual tensors"""
inputs, kwargs = super().prepare_inputs_and_kwargs(test_case, device)
# If we already have stored logits (from a previous call), reuse them
# to ensure consistency across multiple calls for the same test case
if self._current_logits is not None:
inputs[0] = self._current_logits
return inputs, kwargs
voc = kwargs["voc"]
from framework.devices import torch_device_map
if device not in torch_device_map:
raise ValueError(f"Unsupported device: {device}")
torch_device = torch.device(torch_device_map[device])
tensor_dtype = inputs[0].dtype
# Match infiniop test: torch.arange(voc)[_perm].float() * 0.0001
_perm = torch.randperm(voc, device=torch_device)
logits = (
torch.arange(voc, dtype=torch.float32, device=torch_device)[_perm] * 0.0001
).to(tensor_dtype)
inputs[0] = logits
self._current_logits = logits # Store for special comparison
return inputs, kwargs
def torch_operator(self, logits, out=None, **kwargs):
"""PyTorch random_sample implementation"""
idx = torch_random_sample(
logits,
kwargs["random_val"],
kwargs["topp"],
kwargs["topk"],
kwargs["voc"],
kwargs["temperature"],
).to(torch.int32)
if out is None:
return idx
out.copy_(idx)
return out
def infinicore_operator(self, logits, out=None, **kwargs):
"""InfiniCore random_sample implementation"""
if out is None:
return F.random_sample(
logits,
kwargs["random_val"],
kwargs["topp"],
kwargs["topk"],
kwargs["temperature"],
)
return F.random_sample(
logits,
kwargs["random_val"],
kwargs["topp"],
kwargs["topk"],
kwargs["temperature"],
out=out,
)
def run_test(self, device, test_case, config):
"""
Override run_test to handle random_sample's special comparison logic.
For random_sample, if the indices differ but the logits values at those
indices are equal, the result is still considered valid. This handles
cases where multiple valid indices exist due to floating-point precision.
This is necessary because random_sample can return different valid indices
when multiple positions have the same logits value, especially with
low-precision types like bfloat16 due to floating-point rounding.
"""
# Clear stored logits before test to ensure fresh generation
self._current_logits = None
try:
# Try the standard comparison first
# This will call prepare_inputs_and_kwargs which will set self._current_logits
return super().run_test(device, test_case, config)
except AssertionError:
# If standard comparison fails, check if this is a valid case where
# indices differ but logits values are equal
# Only handle if we have stored logits (from prepare_inputs_and_kwargs)
if self._current_logits is None:
raise
logits_tensor = self._current_logits
# Re-run operations with the same logits to get results for comparison
# prepare_inputs_and_kwargs will reuse self._current_logits if it exists
from framework.utils import (
infinicore_tensor_from_torch,
convert_infinicore_to_torch,
)
inputs, kwargs = self.prepare_inputs_and_kwargs(test_case, device)
# Prepare infinicore inputs
infini_inputs = []
for inp in inputs:
if isinstance(inp, torch.Tensor):
cloned_inp = inp.clone().detach()
infini_tensor = infinicore_tensor_from_torch(cloned_inp)
infini_inputs.append(infini_tensor)
else:
infini_inputs.append(inp)
infini_kwargs = kwargs.copy()
if "out" in infini_kwargs and isinstance(infini_kwargs["out"], torch.Tensor):
cloned_out = infini_kwargs["out"].clone().detach()
infini_kwargs["out"] = infinicore_tensor_from_torch(cloned_out)
# Run both operators
torch_result = self.torch_operator(*inputs, **kwargs)
infini_result = self.infinicore_operator(*infini_inputs, **infini_kwargs)
# Extract indices from results
comparison_target = test_case.comparison_target
if comparison_target == "out":
# Compare output tensor from kwargs
ref_idx = kwargs["out"].item()
torch_result_from_infini = convert_infinicore_to_torch(
infini_kwargs["out"], kwargs["out"]
)
ic_idx = torch_result_from_infini.item()
else:
# Compare return values
ref_idx = torch_result.item()
torch_result_from_infini = convert_infinicore_to_torch(
infini_result, torch_result
)
ic_idx = torch_result_from_infini.item()
# Check if indices are equal (standard case)
if ic_idx == ref_idx:
return
# Special case: indices differ but logits values are equal
# This is valid for random_sample when multiple indices have the same logits value
try:
logits_ref = logits_tensor[ref_idx].item()
logits_ic = logits_tensor[ic_idx].item()
if logits_ic == logits_ref:
# Valid: different indices but same logits value
return
except (IndexError, RuntimeError):
# If we can't access the logits, fall through to raise the original error
pass
# If we get here, the results are truly different
raise
def main():
runner = GenericTestRunner(OpTest)
runner.run_and_exit()
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment