Unverified Commit 105065e2 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #289 from InfiniTensor/issue/288_improve_torch_implementation_compatibility

issue/288: Improve the Compatibility of the Torch Implementations
parents a0abcb2c c132b4cf
...@@ -57,11 +57,14 @@ infiniopGemmDescriptor_t = POINTER(GemmDescriptor) ...@@ -57,11 +57,14 @@ infiniopGemmDescriptor_t = POINTER(GemmDescriptor)
# PyTorch implementation for matrix multiplication # PyTorch implementation for matrix multiplication
def gemm(d, _c, beta, _a, _b, alpha): def gemm(d, _c, beta, _a, _b, alpha):
if _c.ndim == 2: try:
torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d) if _c.ndim == 2:
elif _c.ndim == 3: torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d) elif _c.ndim == 3:
else: torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
else:
raise
except Exception:
torch.matmul(_a, _b, out=d) torch.matmul(_a, _b, out=d)
d.mul_(alpha).add_(_c, alpha=beta) d.mul_(alpha).add_(_c, alpha=beta)
......
...@@ -67,8 +67,13 @@ def random_sample(data, random_val, topp, topk, voc, temperature): ...@@ -67,8 +67,13 @@ def random_sample(data, random_val, topp, topk, voc, temperature):
k_index = min(topk, voc) - 1 k_index = min(topk, voc) - 1
threshold = min(cum_probs[k_index], topp) * random_val threshold = min(cum_probs[k_index], topp) * random_val
idx = torch.searchsorted(cum_probs, threshold) try:
idx = torch.searchsorted(cum_probs, threshold)
except Exception:
# Fallback for manual search if torch.searchsorted is not supported
indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
idx = indices[0] if indices.numel() > 0 else torch.tensor(len(cum_probs)-1, device=cum_probs.device)
return sorted_indices[idx] return sorted_indices[idx]
return torch.argmax(data) return torch.argmax(data)
......
...@@ -116,11 +116,11 @@ NUM_PRERUN = 10 ...@@ -116,11 +116,11 @@ NUM_PRERUN = 10
NUM_ITERATIONS = 1000 NUM_ITERATIONS = 1000
class RerrangeDescriptor(Structure): class RearrangeDescriptor(Structure):
_fields_ = [("device", c_int32)] _fields_ = [("device", c_int32)]
infiniopRearrangeDescriptor_t = POINTER(RerrangeDescriptor) infiniopRearrangeDescriptor_t = POINTER(RearrangeDescriptor)
def rearrange_torch(x, x_shape, y_stride): def rearrange_torch(x, x_shape, y_stride):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment