Unverified Commit 384cb5bf authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #264 from InfiniTensor/issue/261_optimize_torch_implementation

Issue/261: Optimize Torch Implementation of Several Operators
parents 2f20af7e 505e0d4b
...@@ -79,8 +79,8 @@ class AddDescriptor(Structure): ...@@ -79,8 +79,8 @@ class AddDescriptor(Structure):
infiniopAddDescriptor_t = POINTER(AddDescriptor) infiniopAddDescriptor_t = POINTER(AddDescriptor)
def add(x, y): def add(ans, x, y):
return torch.add(x, y) torch.add(x, y, out=ans)
def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace): def process_tensors(c, c_strides, a, a_stride, b, b_stride, inplace):
...@@ -134,9 +134,10 @@ def test( ...@@ -134,9 +134,10 @@ def test(
a = torch.rand(shape, dtype=dtype).to(torch_device) a = torch.rand(shape, dtype=dtype).to(torch_device)
b = torch.rand(shape, dtype=dtype).to(torch_device) b = torch.rand(shape, dtype=dtype).to(torch_device)
c = torch.rand(shape, dtype=dtype).to(torch_device) c = torch.rand(shape, dtype=dtype).to(torch_device)
ans = torch.zeros(shape, dtype=dtype).to(torch_device)
a, b, c = process_tensors(c, c_stride, a, a_stride, b, b_stride, inplace) a, b, c = process_tensors(c, c_stride, a, a_stride, b, b_stride, inplace)
ans = add(a, b) add(ans, a, b)
a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]] a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
c_tensor = ( c_tensor = (
...@@ -191,7 +192,7 @@ def test( ...@@ -191,7 +192,7 @@ def test(
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
profile_operation("PyTorch", lambda: add(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation("PyTorch", lambda: add(ans, a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_add(), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_add(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on # fmt: on
check_error(lib.infiniopDestroyAddDescriptor(descriptor)) check_error(lib.infiniopDestroyAddDescriptor(descriptor))
......
...@@ -73,9 +73,8 @@ infiniopCausalSoftmaxDescriptor_t = POINTER(CausalSoftmaxDescriptor) ...@@ -73,9 +73,8 @@ infiniopCausalSoftmaxDescriptor_t = POINTER(CausalSoftmaxDescriptor)
def causal_softmax(x): def causal_softmax(x):
type = x.dtype type = x.dtype
mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1]) mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1])
y = x.clone() masked = torch.where(mask == 1, -torch.inf, x.to(torch.float32))
masked = torch.where(mask == 1, -torch.inf, y.to(torch.float32)) return torch.nn.functional.softmax(masked, dim=-1, dtype=type)
return torch.nn.functional.softmax(masked, dim=-1).to(type)
def test( def test(
......
...@@ -24,14 +24,9 @@ from libinfiniop import ( ...@@ -24,14 +24,9 @@ from libinfiniop import (
_TEST_CASES = [ _TEST_CASES = [
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None), (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None), (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)), (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)), (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None), (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None),
] ]
...@@ -61,11 +56,14 @@ infiniopGemmDescriptor_t = POINTER(GemmDescriptor) ...@@ -61,11 +56,14 @@ infiniopGemmDescriptor_t = POINTER(GemmDescriptor)
# PyTorch implementation for matrix multiplication # PyTorch implementation for matrix multiplication
def gemm(_c, beta, _a, _b, alpha): def gemm(d, _c, beta, _a, _b, alpha):
a, b, c = _a.clone(), _b.clone(), _c.clone() if _c.ndim == 2:
result_dtype = c.dtype torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32)) elif _c.ndim == 3:
return alpha * fp32_result.to(result_dtype) + beta * c torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
else:
torch.matmul(_a, _b, out=d)
d.mul_(alpha).add_(_c, alpha=beta)
# The argument list should be (lib, handle, torch_device, <param list>, dtype) # The argument list should be (lib, handle, torch_device, <param list>, dtype)
...@@ -95,9 +93,10 @@ def test( ...@@ -95,9 +93,10 @@ def test(
a = torch.rand(a_shape, dtype=dtype).to(torch_device) a = torch.rand(a_shape, dtype=dtype).to(torch_device)
b = torch.rand(b_shape, dtype=dtype).to(torch_device) b = torch.rand(b_shape, dtype=dtype).to(torch_device)
c = torch.ones(c_shape, dtype=dtype).to(torch_device) c = torch.ones(c_shape, dtype=dtype).to(torch_device)
ans = torch.zeros(c_shape, dtype=dtype).to(torch_device)
# Compute the PyTorch reference result # Compute the PyTorch reference result
ans = gemm(c, beta, a, b, alpha) gemm(ans, c, beta, a, b, alpha)
a, b, c = [ a, b, c = [
rearrange_if_needed(tensor, stride) rearrange_if_needed(tensor, stride)
...@@ -157,7 +156,7 @@ def test( ...@@ -157,7 +156,7 @@ def test(
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
profile_operation("PyTorch", lambda: gemm(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation("PyTorch", lambda: gemm(ans, c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_gemm(), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_gemm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on # fmt: on
check_error(lib.infiniopDestroyGemmDescriptor(descriptor)) check_error(lib.infiniopDestroyGemmDescriptor(descriptor))
......
...@@ -59,38 +59,19 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor) ...@@ -59,38 +59,19 @@ infiniopRandomSampleDescriptor_t = POINTER(RandomSampleDescriptor)
def random_sample(data, random_val, topp, topk, voc, temperature): def random_sample(data, random_val, topp, topk, voc, temperature):
if topp > 0 and topk > 1: if topp > 0 and topk > 1:
indices = torch.zeros([topk], dtype=torch.int64) sorted_vals, sorted_indices = torch.sort(data, descending=True)
dataNp = data.clone().detach()
sorted_indices = torch.arange(voc) scaled_vals = (sorted_vals - sorted_vals[0]) / temperature
probs = torch.softmax(scaled_vals, dim=0)
for i in range(topk): cum_probs = torch.cumsum(probs, dim=0)
for j in range(i + 1, voc):
if dataNp[i] < dataNp[j]: k_index = min(topk, voc) - 1
tmp = dataNp[i].clone().detach() threshold = min(cum_probs[k_index], topp) * random_val
dataNp[i] = dataNp[j].clone().detach()
dataNp[j] = tmp idx = torch.searchsorted(cum_probs, threshold)
return sorted_indices[idx]
tmpInd = sorted_indices[i].clone().detach()
sorted_indices[i] = sorted_indices[j].clone().detach() return torch.argmax(data)
sorted_indices[j] = tmpInd
# sorted_indices = torch.argsort(dataNp, descending=True)
indices = sorted_indices[:topk]
globalM = dataNp[0]
dataNp = (dataNp - globalM) / temperature
dataNp = torch.softmax(dataNp.float(), dim=0)
for i in range(1, voc):
dataNp[i] += dataNp[i - 1]
limit_k = dataNp[min(topk, voc) - 1]
limit_p = dataNp[voc - 1] * topp
limit = min(limit_k, limit_p) * random_val
for i in range(voc):
if limit < dataNp[i]:
return indices[i]
else:
return torch.argmax(data)
def test( def test(
......
...@@ -123,6 +123,13 @@ class RerrangeDescriptor(Structure): ...@@ -123,6 +123,13 @@ class RerrangeDescriptor(Structure):
infiniopRearrangeDescriptor_t = POINTER(RerrangeDescriptor) infiniopRearrangeDescriptor_t = POINTER(RerrangeDescriptor)
def rearrange_torch(x, x_shape, y_stride):
y_ = x.clone()
y_.set_(y_.untyped_storage(), 0, x_shape, y_stride)
y_[:] = x.view_as(y_)
return y_
def test( def test(
lib, lib,
handle, handle,
...@@ -140,6 +147,8 @@ def test( ...@@ -140,6 +147,8 @@ def test(
x = torch.rand(shape, dtype=dtype).to(torch_device) x = torch.rand(shape, dtype=dtype).to(torch_device)
y = torch.zeros(shape, dtype=dtype).to(torch_device) y = torch.zeros(shape, dtype=dtype).to(torch_device)
rearrange_torch(x, shape, y_stride)
x, y = [ x, y = [
rearrange_if_needed(tensor, stride) rearrange_if_needed(tensor, stride)
for tensor, stride in zip([x, y], [x_stride, y_stride]) for tensor, stride in zip([x, y], [x_stride, y_stride])
...@@ -177,7 +186,7 @@ def test( ...@@ -177,7 +186,7 @@ def test(
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
profile_operation("PyTorch", lambda: rearrange_tensor(y, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation("PyTorch", lambda: rearrange_torch(x, shape, y_stride), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_rearrange(), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_rearrange(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on # fmt: on
......
...@@ -53,12 +53,13 @@ class RMSNormDescriptor(Structure): ...@@ -53,12 +53,13 @@ class RMSNormDescriptor(Structure):
infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor) infiniopRMSNormDescriptor_t = POINTER(RMSNormDescriptor)
def rms_norm(x, w, eps): def rms_norm(ans, x, w, eps):
input_dtype = x.dtype torch.pow(x, 2, out=ans)
hidden_states = x.to(torch.float32) mean = torch.mean(ans, dim=-1, keepdim=True)
variance = hidden_states.pow(2).mean(-1, keepdim=True) mean.add_(eps)
hidden_states = hidden_states * torch.rsqrt(variance + eps) torch.rsqrt(mean, out=mean)
return (w * hidden_states).to(input_dtype) torch.mul(x, mean, out=ans)
ans.mul_(w)
def test( def test(
...@@ -82,9 +83,10 @@ def test( ...@@ -82,9 +83,10 @@ def test(
y = torch.zeros(y_shape, dtype=dtype).to(torch_device) y = torch.zeros(y_shape, dtype=dtype).to(torch_device)
x = torch.rand(x_shape, dtype=dtype).to(torch_device) x = torch.rand(x_shape, dtype=dtype).to(torch_device)
w = torch.rand(w_shape, dtype=w_dtype).to(torch_device) w = torch.rand(w_shape, dtype=w_dtype).to(torch_device)
ans = torch.zeros(y_shape, dtype=dtype).to(torch_device)
eps = 1e-5 eps = 1e-5
ans = rms_norm(x, w, eps) rms_norm(ans, x, w, eps)
x, y = [ x, y = [
rearrange_if_needed(tensor, stride) rearrange_if_needed(tensor, stride)
...@@ -141,7 +143,7 @@ def test( ...@@ -141,7 +143,7 @@ def test(
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
profile_operation("PyTorch", lambda: rms_norm(x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation("PyTorch", lambda: rms_norm(ans, x, w, eps), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_rms_norm(), torch_device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_rms_norm(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on # fmt: on
check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor)) check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor))
......
...@@ -201,7 +201,7 @@ def test( ...@@ -201,7 +201,7 @@ def test(
if PROFILE: if PROFILE:
profile_operation( profile_operation(
"PyTorch", "PyTorch",
lambda: rotary_embedding(x, pos, theta, torch_device), lambda: rotary_embedding(x, sin_table, cos_table, torch_device),
torch_device, torch_device,
NUM_PRERUN, NUM_PRERUN,
NUM_ITERATIONS, NUM_ITERATIONS,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment