Commit 505e0d4b authored by Zimin Li

issue/261: optimize the torch implementation of add, causal softmax, gemm, random sample, rearrange, rms norm, rope
parent 3329034a
......@@ -79,8 +79,8 @@ class AddDescriptor(Structure):
infiniopAddDescriptor_t = POINTER(AddDescriptor)
def add(x, y):
    return torch.add(x, y)
def add(ans, x, y):
    # Write the sum into the preallocated output instead of allocating a new tensor.
    torch.add(x, y, out=ans)
def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace):
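
The `out=` variant matters mostly for the profiling loop, where the same reference is evaluated many times: the output buffer is allocated once and reused, so allocation cost drops out of the measurement. A minimal standalone sketch of the pattern (toy shapes, not the test's inputs):

import torch

x = torch.rand(4, 4)
y = torch.rand(4, 4)
ans = torch.empty_like(x)  # preallocate once

for _ in range(3):
    torch.add(x, y, out=ans)  # reuses ans's storage on every iteration

assert torch.equal(ans, x + y)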
......@@ -134,9 +134,10 @@ def test(
    a = torch.rand(shape, dtype=dtype).to(torch_device)
    b = torch.rand(shape, dtype=dtype).to(torch_device)
    c = torch.rand(shape, dtype=dtype).to(torch_device)
    ans = torch.zeros(shape, dtype=dtype).to(torch_device)
    a, b, c = process_tensors(c, c_stride, a, a_stride, b, b_stride, inplace)
    ans = add(a, b)
    add(ans, a, b)
    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
    c_tensor = (
......@@ -191,7 +192,7 @@ def test(
    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: add(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("PyTorch", lambda: add(ans, a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_add(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(lib.infiniopDestroyAddDescriptor(descriptor))
......
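
`profile_operation` itself is not shown in this diff. A plausible shape for such a helper, hypothetical and for illustration only, is a warm-up loop followed by timed iterations, with a CUDA synchronize so GPU kernels are actually measured rather than just enqueued:

import time
import torch

def profile_operation(name, fn, device, num_prerun, num_iterations):
    # Warm up allocator/caches before timing.
    for _ in range(num_prerun):
        fn()
    if str(device).startswith("cuda"):
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iterations):
        fn()
    if str(device).startswith("cuda"):
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / num_iterations
    print(f"{name}: {elapsed * 1e6:.2f} us/iter")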
......@@ -73,9 +73,8 @@ infiniopCausalSoftmaxDescriptor_t = POINTER(CausalSoftmaxDescriptor)
def causal_softmax(x):
    type = x.dtype
    # tril(..., -1).flip(...) yields a bottom-right-aligned mask of the positions to hide.
    mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])
    y = x.clone()
    masked = torch.where(mask == 1, -torch.inf, y.to(torch.float32))
    return torch.nn.functional.softmax(masked, dim=-1).to(type)
    masked = torch.where(mask == 1, -torch.inf, x.to(torch.float32))
    return torch.nn.functional.softmax(masked, dim=-1, dtype=type)
def test(
......
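
The `tril(...).flip(...)` trick builds a bottom-right-aligned causal mask, so the last query row can see every key even when the row count is smaller than the column count. A quick standalone check against an explicit index-based mask (my own verification, not part of the commit):

import torch

rows, cols = 3, 5
x = torch.zeros(rows, cols)
mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])

# Bottom-right aligned causal rule: row i may attend to columns j <= i + (cols - rows).
i = torch.arange(rows).unsqueeze(1)
j = torch.arange(cols).unsqueeze(0)
expected = (j > i + (cols - rows)).float()

assert torch.equal(mask, expected)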
......@@ -24,14 +24,9 @@ from libinfiniop import (
_TEST_CASES = [
    # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
    (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
    (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
    (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
    (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
    (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
    (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
    (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
]
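
For reference, the stride tuples describe memory layout in elements: stride (2048, 1) for a (6, 2048) operand is plain row-major, (1, 2048) for a (2048, 2560) operand is the column-major (transposed) layout of the same data, and (4096, 1) for a (1, 2048) tensor leaves padding between rows. A small illustration of how such layouts arise (toy shapes, not from the repo):

import torch

a = torch.arange(6.0).reshape(2, 3)      # contiguous: stride (3, 1)
t = a.t()                                # same storage viewed with stride (1, 3)
assert t.stride() == (1, 3)

padded = torch.empty(2, 5)[:, :3]        # rows padded: stride (5, 1)
padded.copy_(a)
assert padded.stride() == (5, 1) and torch.equal(padded, a)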
......@@ -61,11 +56,14 @@ infiniopGemmDescriptor_t = POINTER(GemmDescriptor)
# PyTorch implementation for matrix multiplication
def gemm(_c, beta, _a, _b, alpha):
    a, b, c = _a.clone(), _b.clone(), _c.clone()
    result_dtype = c.dtype
    fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    return alpha * fp32_result.to(result_dtype) + beta * c
def gemm(d, _c, beta, _a, _b, alpha):
    if _c.ndim == 2:
        # Fused beta * c + alpha * (a @ b), written directly into d.
        torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
    elif _c.ndim == 3:
        torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
    else:
        torch.matmul(_a, _b, out=d)
        d.mul_(alpha).add_(_c, alpha=beta)
# The argument list should be (lib, handle, torch_device, <param list>, dtype)
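
`torch.addmm(c, a, b, beta=beta, alpha=alpha)` computes `beta * c + alpha * (a @ b)` in a single fused call, and `baddbmm` is the batched analogue, so both branches match the fallback's `matmul` + `mul_`/`add_` sequence. A quick standalone equivalence check (toy shapes):

import torch

a, b = torch.rand(4, 3), torch.rand(3, 5)
c = torch.ones(4, 5)
alpha, beta = 0.5, 2.0

fused = torch.addmm(c, a, b, beta=beta, alpha=alpha)

manual = torch.matmul(a, b)
manual.mul_(alpha).add_(c, alpha=beta)

assert torch.allclose(fused, manual)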
......@@ -95,9 +93,10 @@ def test(
    a = torch.rand(a_shape, dtype=dtype).to(torch_device)
    b = torch.rand(b_shape, dtype=dtype).to(torch_device)
    c = torch.ones(c_shape, dtype=dtype).to(torch_device)
    ans = torch.zeros(c_shape, dtype=dtype).to(torch_device)
    # Compute the PyTorch reference result
    ans = gemm(c, beta, a, b, alpha)
    gemm(ans, c, beta, a, b, alpha)
    a, b, c = [
        rearrange_if_needed(tensor, stride)
......@@ -157,7 +156,7 @@ def test(
    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: gemm(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("PyTorch", lambda: gemm(ans, c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_gemm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(lib.infiniopDestroyGemmDescriptor(descriptor))
......
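
The tests lean on a `rearrange_if_needed(tensor, stride)` helper that is not part of this diff; presumably it re-lays-out a tensor to the requested strides when a stride tuple is given and passes it through otherwise. A hypothetical sketch of such a helper (name reused for illustration, behavior assumed, not taken from the repo):

import torch

def rearrange_if_needed(tensor, stride):
    # No stride requested: keep the default contiguous layout.
    if stride is None:
        return tensor
    # Allocate storage large enough for the strided layout, then copy values in.
    numel = 1 + sum((s - 1) * st for s, st in zip(tensor.shape, stride))
    storage = torch.empty(numel, dtype=tensor.dtype, device=tensor.device)
    out = storage.as_strided(tensor.shape, stride)
    out.copy_(tensor)
    return out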
......@@ -59,38 +59,19 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def random_sample(data, random_val, topp, topk, voc, temperature):
    if topp > 0 and topk > 1:
        indices = torch.zeros([topk], dtype=torch.int64)
        dataNp = data.clone().detach()
        sorted_indices = torch.arange(voc)
        for i in range(topk):
            for j in range(i + 1, voc):
                if dataNp[i] < dataNp[j]:
                    tmp = dataNp[i].clone().detach()
                    dataNp[i] = dataNp[j].clone().detach()
                    dataNp[j] = tmp
                    tmpInd = sorted_indices[i].clone().detach()
                    sorted_indices[i] = sorted_indices[j].clone().detach()
                    sorted_indices[j] = tmpInd
        # sorted_indices = torch.argsort(dataNp, descending=True)
        indices = sorted_indices[:topk]
        globalM = dataNp[0]
        dataNp = (dataNp - globalM) / temperature
        dataNp = torch.softmax(dataNp.float(), dim=0)
        for i in range(1, voc):
            dataNp[i] += dataNp[i - 1]
        limit_k = dataNp[min(topk, voc) - 1]
        limit_p = dataNp[voc - 1] * topp
        limit = min(limit_k, limit_p) * random_val
        for i in range(voc):
            if limit < dataNp[i]:
                return indices[i]
    else:
        return torch.argmax(data)
        # Sort once instead of the O(voc * topk) selection sort above.
        sorted_vals, sorted_indices = torch.sort(data, descending=True)
        scaled_vals = (sorted_vals - sorted_vals[0]) / temperature
        probs = torch.softmax(scaled_vals, dim=0)
        cum_probs = torch.cumsum(probs, dim=0)
        # The sampling threshold is capped by both top-k and top-p.
        k_index = min(topk, voc) - 1
        threshold = min(cum_probs[k_index], topp) * random_val
        # Binary search over the CDF replaces the linear scan.
        idx = torch.searchsorted(cum_probs, threshold)
        return sorted_indices[idx]
    return torch.argmax(data)
def test(
......
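
The rewrite replaces the O(voc * topk) selection sort with `torch.sort`, builds the cumulative distribution with `torch.cumsum`, and locates the sampled token with `torch.searchsorted`. A tiny standalone driver showing the same pipeline end to end (toy logits and constants, not the test's inputs):

import torch

logits = torch.tensor([0.1, 2.0, 0.5, 1.5])
sorted_vals, sorted_idx = torch.sort(logits, descending=True)
probs = torch.softmax((sorted_vals - sorted_vals[0]) / 0.7, dim=0)
cum_probs = torch.cumsum(probs, dim=0)

topk, topp, random_val = 3, 0.8, 0.42
threshold = min(cum_probs[topk - 1], topp) * random_val
token = sorted_idx[torch.searchsorted(cum_probs, threshold)]
print(int(token))  # index of the sampled token in the original logits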
......@@ -123,6 +123,13 @@ class RerrangeDescriptor(Structure):
infiniopRearrangeDescriptor_t = POINTER(RerrangeDescriptor)
def rearrange_torch(x, x_shape, y_stride):
    y_ = x.clone()
    # Reinterpret the clone's storage with the target shape and strides...
    y_.set_(y_.untyped_storage(), 0, x_shape, y_stride)
    # ...then copy x's values into that layout.
    y_[:] = x.view_as(y_)
    return y_
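
`set_` points the clone at its own storage but with the target shape and strides, and the assignment then fills that layout with x's values. A quick way to convince yourself it worked is to compare strides and values; my own check, assuming a permuted layout that fits the same storage:

import torch

x = torch.arange(6.0).reshape(2, 3)
y = rearrange_torch(x, (2, 3), (1, 2))  # column-major layout of the same shape

assert y.stride() == (1, 2)
assert torch.equal(y, x)  # same logical values, different memory order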
def test(
    lib,
    handle,
......@@ -140,6 +147,8 @@ def test(
    x = torch.rand(shape, dtype=dtype).to(torch_device)
    y = torch.zeros(shape, dtype=dtype).to(torch_device)
    rearrange_torch(x, shape, y_stride)
    x, y = [
        rearrange_if_needed(tensor, stride)
        for tensor, stride in zip([x, y], [x_stride, y_stride])
......@@ -177,7 +186,7 @@ def test(
    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: rearrange_tensor(y, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("PyTorch", lambda: rearrange_torch(x, shape, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_rearrange(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
......
......@@ -53,12 +53,13 @@ class RMSNormDescriptor(Structure):
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)
def rms_norm(x, w, eps):
    input_dtype = x.dtype
    hidden_states = x.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return (w * hidden_states).to(input_dtype)
def rms_norm(ans, x, w, eps):
    # In-place pipeline: ans first holds x^2, then the normalized result.
    torch.pow(x, 2, out=ans)
    mean = torch.mean(ans, dim=-1, keepdim=True)
    mean.add_(eps)
    torch.rsqrt(mean, out=mean)
    torch.mul(x, mean, out=ans)
    ans.mul_(w)
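
Note that the rewritten reference runs entirely in the input dtype rather than upcasting to float32, so for fp16 inputs the intermediate precision differs from the old version. A quick standalone comparison of the in-place pipeline against the fp32 formula (toy shapes, my own check, with loose tolerances to absorb the dtype difference):

import torch

x = torch.rand(4, 8, dtype=torch.float16)
w = torch.rand(8, dtype=torch.float16)
eps = 1e-5

ans = torch.empty_like(x)
rms_norm(ans, x, w, eps)  # the in-place variant above

ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps) * w.float()).half()
assert torch.allclose(ans, ref, atol=1e-2, rtol=1e-2)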
def test(
......@@ -82,9 +83,10 @@ def test(
    y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
    x = torch.rand(x_shape, dtype=dtype).to(torch_device)
    w = torch.rand(w_shape, dtype=w_dtype).to(torch_device)
    ans = torch.zeros(y_shape, dtype=dtype).to(torch_device)
    eps = 1e-5
    ans = rms_norm(x, w, eps)
    rms_norm(ans, x, w, eps)
    x, y = [
        rearrange_if_needed(tensor, stride)
......@@ -141,7 +143,7 @@ def test(
    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: rms_norm(x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("PyTorch", lambda: rms_norm(ans, x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation(" lib", lambda: lib_rms_norm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
......
......@@ -201,7 +201,7 @@ def test(
    if PROFILE:
        profile_operation(
            "PyTorch",
            lambda: rotary_embedding(x, pos, theta, torch_device),
            lambda: rotary_embedding(x, sin_table, cos_table, torch_device),
            torch_device,
            NUM_PRERUN,
            NUM_ITERATIONS,
......
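
The reference now consumes precomputed `sin_table`/`cos_table` instead of recomputing angles from `pos` and `theta` on every call. The tables are built elsewhere in the test; a typical construction following the standard RoPE formula (helper name and shapes assumed for illustration) looks like:

import torch

def build_rope_tables(pos, theta, head_dim, dtype=torch.float32):
    # freqs[k] = theta^(-2k / head_dim), one frequency per coordinate pair.
    freqs = theta ** (-torch.arange(0, head_dim, 2, dtype=dtype) / head_dim)
    angles = pos.to(dtype).unsqueeze(-1) * freqs  # (seq_len, head_dim // 2)
    return torch.sin(angles), torch.cos(angles)

sin_table, cos_table = build_rope_tables(torch.arange(16), theta=10000.0, head_dim=64)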