Unverified Commit 384cb5bf authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #264 from InfiniTensor/issue/261_optimize_torch_implementation

Issue/261: Optimize Torch Implementation of Several Operators
parents 2f20af7e 505e0d4b
......@@ -79,8 +79,8 @@ class AddDescriptor(Structure):
infiniopAddDescriptor_t = POINTER(AddDescriptor)
def add(x, y):
    """Reference element-wise addition: returns a new tensor x + y."""
    return x + y
def add(ans, x, y):
    """Reference element-wise addition written into the preallocated
    output tensor `ans` (same shape/dtype as x and y)."""
    ans.copy_(x + y)
def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace):
......@@ -134,9 +134,10 @@ def test(
a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device)
c = torch.rand(shape, dtype=dtype).to(torch_device)
ans = torch.zeros(shape, dtype=dtype).to(torch_device)
a, b, c = process_tensors(c, c_stride, a, a_stride, b, b_stride, inplace)
ans = add(a, b)
add(ans, a, b)
a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
c_tensor = (
......@@ -191,7 +192,7 @@ def test(
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: add(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation("PyTorch", lambda: add(ans, a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_add(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(lib.infiniopDestroyAddDescriptor(descriptor))
......
......@@ -73,9 +73,8 @@ infiniopCausalSoftmaxDescriptor_t = POINTER(CausalSoftmaxDescriptor)
def causal_softmax(x):
    """Reference causal softmax.

    Positions strictly above the causal diagonal (aligned to the
    bottom-right corner, so the last query row attends to every key) are
    masked to -inf before a row-wise softmax over the last dimension.

    Computed in float32 for accuracy, then cast back to x's dtype.
    NOTE: softmax's `dtype=` argument casts the *input* before the
    reduction, which would discard the fp32 upcast — so the cast back is
    done on the result instead.
    """
    original_dtype = x.dtype
    # 1s strictly above the bottom-right-aligned causal diagonal.
    mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])
    masked = torch.where(mask == 1, -torch.inf, x.to(torch.float32))
    return torch.nn.functional.softmax(masked, dim=-1).to(original_dtype)
def test(
......
......@@ -24,14 +24,9 @@ from libinfiniop import (
_TEST_CASES = [
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
]
......@@ -61,11 +56,14 @@ infiniopGemmDescriptor_t = POINTER(GemmDescriptor)
# PyTorch implementation for matrix multiplication
# PyTorch reference for matrix multiplication.
def gemm(_c, beta, _a, _b, alpha):
    """Reference GEMM: returns alpha * (_a @ _b) + beta * _c.

    The matmul is accumulated in float32 and cast back to _c's dtype
    before the alpha/beta combination.
    """
    out_dtype = _c.dtype
    product = torch.matmul(_a.to(torch.float32), _b.to(torch.float32))
    return alpha * product.to(out_dtype) + beta * _c.clone()
def gemm(d, _c, beta, _a, _b, alpha):
    """Reference GEMM written into the preallocated output `d`:
    d = alpha * (_a @ _b) + beta * _c.

    Uses the fused addmm/baddbmm kernels for the 2-D/3-D cases.
    """
    if _c.ndim == 2:
        torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
    elif _c.ndim == 3:
        torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
    else:
        # Fallback for other ranks: plain matmul, then apply alpha/beta
        # manually. This scaling must stay inside this branch —
        # addmm/baddbmm above already fold alpha and beta in, so running
        # it unconditionally would double-apply them.
        torch.matmul(_a, _b, out=d)
        d.mul_(alpha).add_(_c, alpha=beta)
# The argument list should be (lib, handle, torch_device, <param list>, dtype)
......@@ -95,9 +93,10 @@ def test(
a = torch.rand(a_shape, dtype=dtype).to(torch_device)
b = torch.rand(b_shape, dtype=dtype).to(torch_device)
c = torch.ones(c_shape, dtype=dtype).to(torch_device)
ans = torch.zeros(c_shape, dtype=dtype).to(torch_device)
# Compute the PyTorch reference result
ans = gemm(c, beta, a, b, alpha)
gemm(ans, c, beta, a, b, alpha)
a, b, c = [
rearrange_if_needed(tensor, stride)
......@@ -157,7 +156,7 @@ def test(
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: gemm(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation("PyTorch", lambda: gemm(ans, c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_gemm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(lib.infiniopDestroyGemmDescriptor(descriptor))
......
......@@ -59,38 +59,19 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def random_sample(data, random_val, topp, topk, voc, temperature):
    """Reference top-k / top-p (nucleus) sampling over the 1-D logits `data`.

    Falls back to greedy argmax when neither top-k nor top-p filtering is
    requested. Returns the selected vocabulary index as a 0-dim tensor.

    NOTE(review): assumes random_val is in [0, 1) so the threshold stays
    within cum_probs' range — confirm against the caller.
    """
    if topp > 0 and topk > 1:
        # One descending sort replaces the previous O(voc^2) selection sort.
        sorted_vals, sorted_indices = torch.sort(data, descending=True)
        # Shift by the max logit for numerical stability, then apply the
        # temperature; softmax in fp32 for accuracy with half-precision logits.
        scaled_vals = (sorted_vals - sorted_vals[0]) / temperature
        probs = torch.softmax(scaled_vals.float(), dim=0)
        cum_probs = torch.cumsum(probs, dim=0)
        # Effective cutoff: the tighter of the top-k mass and the top-p
        # mass, scaled by the random draw.
        k_index = min(topk, voc) - 1
        threshold = min(cum_probs[k_index], topp) * random_val
        # First position whose cumulative mass reaches the threshold.
        idx = torch.searchsorted(cum_probs, threshold)
        return sorted_indices[idx]
    return torch.argmax(data)
def test(
......
......@@ -123,6 +123,13 @@ class RerrangeDescriptor(Structure):
infiniopRearrangeDescriptor_t = POINTER(RerrangeDescriptor)
def rearrange_torch(x, x_shape, y_stride):
    """Reference rearrange: return a copy of x whose underlying storage is
    laid out according to `y_stride` (logical shape stays `x_shape`).

    NOTE(review): rebinds the clone's storage via set_(); assumes every
    address reachable through `y_stride` fits inside x.clone()'s storage —
    confirm for strides that expand beyond x's contiguous footprint.
    """
    # Clone to obtain a private storage the same size as x's layout.
    y_ = x.clone()
    # Reinterpret the clone: same storage, offset 0, target shape/strides.
    y_.set_(y_.untyped_storage(), 0, x_shape, y_stride)
    # Copy x's values element-wise into the restrided view.
    y_[:] = x.view_as(y_)
    return y_
def test(
lib,
handle,
......@@ -140,6 +147,8 @@ def test(
x = torch.rand(shape, dtype=dtype).to(torch_device)
y = torch.zeros(shape, dtype=dtype).to(torch_device)
rearrange_torch(x, shape, y_stride)
x, y = [
rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride])
......@@ -177,7 +186,7 @@ def test(
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: rearrange_tensor(y, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation("PyTorch", lambda: rearrange_torch(x, shape, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_rearrange(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
......
......@@ -53,12 +53,13 @@ class RMSNormDescriptor(Structure):
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)
def rms_norm(x, w, eps):
    """Reference RMSNorm: x * rsqrt(mean(x^2, last dim) + eps) * w.

    Normalization is computed in float32 and the result is cast back to
    x's original dtype.
    """
    out_dtype = x.dtype
    x32 = x.to(torch.float32)
    inv_rms = torch.rsqrt(torch.mean(x32 * x32, dim=-1, keepdim=True) + eps)
    return (w * (x32 * inv_rms)).to(out_dtype)
def rms_norm(ans, x, w, eps):
    """In-place reference RMSNorm: ans <- (x * rsqrt(mean(x^2) + eps)) * w.

    `ans` is used as scratch for the squared values, so it must share x's
    shape and dtype. All arithmetic happens in x's dtype.
    """
    # Square into ans (scratch); x * x equals x ** 2 exactly for floats.
    torch.mul(x, x, out=ans)
    # Per-row inverse RMS, built up in place on the reduced tensor.
    scale = ans.mean(dim=-1, keepdim=True).add_(eps).rsqrt_()
    # Normalize, then apply the weight — both written into ans.
    torch.mul(x, scale, out=ans)
    ans.mul_(w)
def test(
......@@ -82,9 +83,10 @@ def test(
y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
x = torch.rand(x_shape, dtype=dtype).to(torch_device)
w = torch.rand(w_shape, dtype=w_dtype).to(torch_device)
ans = torch.zeros(y_shape, dtype=dtype).to(torch_device)
eps = 1e-5
ans = rms_norm(x, w, eps)
rms_norm(ans, x, w, eps)
x, y = [
rearrange_if_needed(tensor, stride)
......@@ -141,7 +143,7 @@ def test(
# Profiling workflow
if PROFILE:
# fmt: off
profile_operation("PyTorch", lambda: rms_norm(x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation("PyTorch", lambda: rms_norm(ans, x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_rms_norm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on
check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
......
......@@ -201,7 +201,7 @@ def test(
if PROFILE:
profile_operation(
"PyTorch",
lambda: rotary_embedding(x, pos, theta, torch_device),
lambda: rotary_embedding(x, sin_table, cos_table, torch_device),
torch_device,
NUM_PRERUN,
NUM_ITERATIONS,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment