Commit f49235e3 authored by Zimin Li's avatar Zimin Li
Browse files

issue/288: improve the compatibility of the torch implementations of gemm and random sample

parent a0abcb2c
...@@ -57,11 +57,14 @@ infiniopGemmDescriptor_t = POINTER(GemmDescriptor) ...@@ -57,11 +57,14 @@ infiniopGemmDescriptor_t = POINTER(GemmDescriptor)
# PyTorch implementation for matrix multiplication # PyTorch implementation for matrix multiplication
def gemm(d, _c, beta, _a, _b, alpha): def gemm(d, _c, beta, _a, _b, alpha):
if _c.ndim == 2: try:
torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d) if _c.ndim == 2:
elif _c.ndim == 3: torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d) elif _c.ndim == 3:
else: torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
else:
raise
except Exception:
torch.matmul(_a, _b, out=d) torch.matmul(_a, _b, out=d)
d.mul_(alpha).add_(_c, alpha=beta) d.mul_(alpha).add_(_c, alpha=beta)
......
...@@ -67,8 +67,13 @@ def random_sample(data, random_val, topp, topk, voc, temperature): ...@@ -67,8 +67,13 @@ def random_sample(data, random_val, topp, topk, voc, temperature):
k_index = min(topk, voc) - 1 k_index = min(topk, voc) - 1
threshold = min(cum_probs[k_index], topp) * random_val threshold = min(cum_probs[k_index], topp) * random_val
idx = torch.searchsorted(cum_probs, threshold) try:
idx = torch.searchsorted(cum_probs, threshold)
except Exception:
# Fallback for manual search if torch.searchsorted is not supported
indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
idx = indices[0] if indices.numel() > 0 else torch.tensor(len(cum_probs)-1, device=cum_probs.device)
return sorted_indices[idx] return sorted_indices[idx]
return torch.argmax(data) return torch.argmax(data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment