Commit 505e0d4b authored by Zimin Li's avatar Zimin Li
Browse files

issue/261: optimize the torch implementation of add, causal softmax, gemm,...

issue/261: optimize the torch implementation of add, causal softmax, gemm, random sample, rearrange, rms norm, rope
parent 3329034a
...@@ -79,8 +79,8 @@ class AddDescriptor(Structure): ...@@ -79,8 +79,8 @@ class AddDescriptor(Structure):
infiniopAddDescriptor_t = POINTER(AddDescriptor) infiniopAddDescriptor_t = POINTER(AddDescriptor)
def add(ans, x, y):
    """Reference elementwise addition: write x + y into the preallocated out tensor `ans`."""
    torch.add(x, y, out=ans)
def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace): def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace):
...@@ -134,9 +134,10 @@ def test( ...@@ -134,9 +134,10 @@ def test(
a = torch.rand(shape, dtype=dtype).to(torch_device) a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device) b = torch.rand(shape, dtype=dtype).to(torch_device)
c = torch.rand(shape, dtype=dtype).to(torch_device) c = torch.rand(shape, dtype=dtype).to(torch_device)
ans = torch.zeros(shape, dtype=dtype).to(torch_device)
a, b, c = process_tensors(c, c_stride, a, a_stride, b, b_stride, inplace) a, b, c = process_tensors(c, c_stride, a, a_stride, b, b_stride, inplace)
ans = add(a, b) add(ans, a, b)
a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]] a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
c_tensor = ( c_tensor = (
...@@ -191,7 +192,7 @@ def test( ...@@ -191,7 +192,7 @@ def test(
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
profile_operation("PyTorch", lambda: add(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation("PyTorch", lambda: add(ans, a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_add(), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_add(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on # fmt: on
check_error(lib.infiniopDestroyAddDescriptor(descriptor)) check_error(lib.infiniopDestroyAddDescriptor(descriptor))
......
...@@ -73,9 +73,8 @@ infiniopCausalSoftmaxDescriptor_t = POINTER(CausalSoftmaxDescriptor) ...@@ -73,9 +73,8 @@ infiniopCausalSoftmaxDescriptor_t = POINTER(CausalSoftmaxDescriptor)
def causal_softmax(x):
    """Reference causal softmax over the last dimension of `x`.

    Positions in the "future" (relative to the bottom-right-aligned causal
    layout) are masked to -inf before the softmax, so each row's probability
    mass is restricted to its allowed prefix.
    """
    out_dtype = x.dtype
    # Strictly-lower-triangular ones, flipped on the last two dims so the
    # masked region sits in the upper-right corner (rows aligned to the end
    # of the sequence, as the library kernel expects).
    mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])
    # Compute in float32 for accuracy. NOTE: softmax's dtype= argument casts
    # the input BEFORE the op, which would undo the float32 upcast for fp16
    # inputs — so cast back only after the softmax.
    masked = torch.where(mask == 1, -torch.inf, x.to(torch.float32))
    return torch.nn.functional.softmax(masked, dim=-1).to(out_dtype)
def test( def test(
......
...@@ -24,14 +24,9 @@ from libinfiniop import ( ...@@ -24,14 +24,9 @@ from libinfiniop import (
_TEST_CASES = [ _TEST_CASES = [
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None), (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None), (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)), (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)), (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None), (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
] ]
...@@ -61,11 +56,14 @@ infiniopGemmDescriptor_t = POINTER(GemmDescriptor) ...@@ -61,11 +56,14 @@ infiniopGemmDescriptor_t = POINTER(GemmDescriptor)
# PyTorch implementation for matrix multiplication # PyTorch implementation for matrix multiplication
def gemm(d, _c, beta, _a, _b, alpha):
    """Reference GEMM: d = alpha * (_a @ _b) + beta * _c, written into `d`.

    Uses torch's fused multiply-add kernels for the common 2-D and batched
    3-D cases; any other rank falls back to a matmul followed by an in-place
    scale-and-accumulate.
    """
    fused = {2: torch.addmm, 3: torch.baddbmm}.get(_c.ndim)
    if fused is not None:
        # addmm / baddbmm compute beta*c + alpha*(a@b) in a single kernel.
        fused(_c, _a, _b, beta=beta, alpha=alpha, out=d)
    else:
        torch.matmul(_a, _b, out=d)
        d.mul_(alpha).add_(_c, alpha=beta)
# The argument list should be (lib, handle, torch_device, <param list>, dtype) # The argument list should be (lib, handle, torch_device, <param list>, dtype)
...@@ -95,9 +93,10 @@ def test( ...@@ -95,9 +93,10 @@ def test(
a = torch.rand(a_shape, dtype=dtype).to(torch_device) a = torch.rand(a_shape, dtype=dtype).to(torch_device)
b = torch.rand(b_shape, dtype=dtype).to(torch_device) b = torch.rand(b_shape, dtype=dtype).to(torch_device)
c = torch.ones(c_shape, dtype=dtype).to(torch_device) c = torch.ones(c_shape, dtype=dtype).to(torch_device)
ans = torch.zeros(c_shape, dtype=dtype).to(torch_device)
# Compute the PyTorch reference result # Compute the PyTorch reference result
ans = gemm(c, beta, a, b, alpha) gemm(ans, c, beta, a, b, alpha)
a, b, c = [ a, b, c = [
rearrange_if_needed(tensor, stride) rearrange_if_needed(tensor, stride)
...@@ -157,7 +156,7 @@ def test( ...@@ -157,7 +156,7 @@ def test(
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
profile_operation("PyTorch", lambda: gemm(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation("PyTorch", lambda: gemm(ans, c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_gemm(), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_gemm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on # fmt: on
check_error(lib.infiniopDestroyGemmDescriptor(descriptor)) check_error(lib.infiniopDestroyGemmDescriptor(descriptor))
......
...@@ -59,37 +59,18 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor) ...@@ -59,37 +59,18 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def random_sample(data, random_val, topp, topk, voc, temperature):
    """Reference top-k / top-p (nucleus) sampling over the logits `data`.

    `random_val` in [0, 1) selects a point on the truncated cumulative
    distribution; when sampling is effectively disabled (topp <= 0 or
    topk <= 1) the result degenerates to greedy argmax.
    Returns the index (into `data`) of the sampled token.
    """
    if not (topp > 0 and topk > 1):
        # Degenerate parameters: plain greedy selection.
        return torch.argmax(data)

    values, order = torch.sort(data, descending=True)
    # Temperature-scaled softmax, shifted by the max logit for stability.
    probabilities = torch.softmax((values - values[0]) / temperature, dim=0)
    cdf = torch.cumsum(probabilities, dim=0)

    # Truncate by both top-k mass and top-p, then scale by the random draw
    # to pick a point on the (unnormalized) truncated CDF.
    top_k_mass = cdf[min(topk, voc) - 1]
    pick = float(min(top_k_mass, topp)) * random_val

    position = torch.searchsorted(cdf, pick)
    return order[position]
......
...@@ -123,6 +123,13 @@ class RerrangeDescriptor(Structure): ...@@ -123,6 +123,13 @@ class RerrangeDescriptor(Structure):
infiniopRearrangeDescriptor_t = POINTER(RerrangeDescriptor) infiniopRearrangeDescriptor_t = POINTER(RerrangeDescriptor)
def rearrange_torch(x, x_shape, y_stride):
    """Reference rearrange: return a copy of `x` laid out with `y_stride`.

    NOTE(review): relies on Tensor.set_ to reinterpret the cloned storage
    with the target strides; presumably every offset addressed by
    (x_shape, y_stride) fits inside x's storage — TODO confirm for padded
    layouts.
    """
    # Clone first so x and its storage are never mutated.
    y_ = x.clone()
    # Rebind the clone's view in place: same storage, offset 0, but with the
    # requested shape and strides.
    y_.set_(y_.untyped_storage(), 0, x_shape, y_stride)
    # Element-wise copy of x's values into the re-strided view.
    y_[:] = x.view_as(y_)
    return y_
def test( def test(
lib, lib,
handle, handle,
...@@ -140,6 +147,8 @@ def test( ...@@ -140,6 +147,8 @@ def test(
x = torch.rand(shape, dtype=dtype).to(torch_device) x = torch.rand(shape, dtype=dtype).to(torch_device)
y = torch.zeros(shape, dtype=dtype).to(torch_device) y = torch.zeros(shape, dtype=dtype).to(torch_device)
rearrange_torch(x, shape, y_stride)
x, y = [ x, y = [
rearrange_if_needed(tensor, stride) rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride]) for tensor, stride in zip([x, y], [x_stride, y_stride])
...@@ -177,7 +186,7 @@ def test( ...@@ -177,7 +186,7 @@ def test(
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
profile_operation("PyTorch", lambda: rearrange_tensor(y, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation("PyTorch", lambda: rearrange_torch(x, shape, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_rearrange(), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_rearrange(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on # fmt: on
......
...@@ -53,12 +53,13 @@ class RMSNormDescriptor(Structure): ...@@ -53,12 +53,13 @@ class RMSNormDescriptor(Structure):
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor) infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)
def rms_norm(ans, x, w, eps):
    """Reference RMS normalization, written into the preallocated `ans`.

    ans = w * x / sqrt(mean(x**2, dim=-1) + eps)

    The mean of squares is computed in float32: in fp16/bf16, x**2 can
    overflow and the last-dim reduction loses precision, which would make
    the reference itself inaccurate. The result is cast back to ans's
    dtype by the final copy_.
    """
    hidden = x.to(torch.float32)
    variance = hidden.pow(2).mean(-1, keepdim=True)
    hidden = hidden * torch.rsqrt(variance + eps)
    # copy_ casts float32 back down to ans.dtype in one step.
    ans.copy_(w * hidden)
def test( def test(
...@@ -82,9 +83,10 @@ def test( ...@@ -82,9 +83,10 @@ def test(
y = torch.zeros(y_shape, dtype=dtype).to(torch_device) y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
x = torch.rand(x_shape, dtype=dtype).to(torch_device) x = torch.rand(x_shape, dtype=dtype).to(torch_device)
w = torch.rand(w_shape, dtype=w_dtype).to(torch_device) w = torch.rand(w_shape, dtype=w_dtype).to(torch_device)
ans = torch.zeros(y_shape, dtype=dtype).to(torch_device)
eps = 1e-5 eps = 1e-5
ans = rms_norm(x, w, eps) rms_norm(ans, x, w, eps)
x, y = [ x, y = [
rearrange_if_needed(tensor, stride) rearrange_if_needed(tensor, stride)
...@@ -141,7 +143,7 @@ def test( ...@@ -141,7 +143,7 @@ def test(
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
profile_operation("PyTorch", lambda: rms_norm(x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation("PyTorch", lambda: rms_norm(ans, x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_rms_norm(), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_rms_norm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on # fmt: on
check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor)) check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
......
...@@ -201,7 +201,7 @@ def test( ...@@ -201,7 +201,7 @@ def test(
if PROFILE: if PROFILE:
profile_operation( profile_operation(
"PyTorch", "PyTorch",
lambda: rotary_embedding(x, pos, theta, torch_device), lambda: rotary_embedding(x, sin_table, cos_table, torch_device),
torch_device, torch_device,
NUM_PRERUN, NUM_PRERUN,
NUM_ITERATIONS, NUM_ITERATIONS,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment