Commit f49235e3 authored by Zimin Li's avatar Zimin Li
Browse files

issue/288: improve the compatibility of the torch implementations of gemm and random sample

parent a0abcb2c
......@@ -57,11 +57,14 @@ infiniopGemmDescriptor_t = POINTER(GemmDescriptor)
# PyTorch implementation for matrix multiplication
def gemm(d, _c, beta, _a, _b, alpha):
    """PyTorch reference GEMM: write ``alpha * (_a @ _b) + beta * _c`` into ``d``.

    Tries the fused kernels first (``addmm`` for 2-D, ``baddbmm`` for batched
    3-D); if ``_c`` has another rank, or the fused call itself fails (e.g. an
    older torch build without support for the given dtype/device), falls back
    to a generic ``matmul`` followed by in-place scale-and-add.

    Args:
        d: output tensor, pre-allocated with the result shape (written in place).
        _c: additive term, scaled by ``beta``.
        beta: scalar multiplier for ``_c``.
        _a, _b: matrix-multiply operands.
        alpha: scalar multiplier for the ``_a @ _b`` product.
    """
    try:
        # Fast path: fused kernels for the two ranks they support.
        if _c.ndim == 2:
            torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
            return
        if _c.ndim == 3:
            torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
            return
    except Exception:
        # Fused kernel unavailable/unsupported — fall through to generic path.
        pass
    # Generic path: matmul then in-place d = alpha * d + beta * _c, matching
    # addmm/baddbmm semantics.
    torch.matmul(_a, _b, out=d)
    d.mul_(alpha).add_(_c, alpha=beta)
......
......@@ -68,7 +68,12 @@ def random_sample(data, random_val, topp, topk, voc, temperature):
k_index = min(topk, voc) - 1
threshold = min(cum_probs[k_index], topp) * random_val
try:
idx = torch.searchsorted(cum_probs, threshold)
except Exception:
# Fallback for manual search if torch.searchsorted is not supported
indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
idx = indices[0] if indices.numel() > 0 else torch.tensor(len(cum_probs)-1, device=cum_probs.device)
return sorted_indices[idx]
return torch.argmax(data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment