Commit 5f3d9ada authored by Mitchell Wortsman's avatar Mitchell Wortsman
Browse files

triton-v1

parent 51f8bb71
import torch
import json
from bitsandbytes.nn.triton_based_modules import SwitchBackGlobalMLP, SwitchBackGlobalLinear, MyLinear
import time
if __name__ == '__main__':
print('Startin')
for dim in [1024, 1280, 1408, 1664, 2048]:
for batch in [2**14, 2**15, 2**16, 2**17]:
x1 = torch.randn(batch, dim).cuda().requires_grad_(True)
d = 2
standard = torch.nn.Sequential(
torch.nn.LayerNorm(dim),
torch.nn.Linear(dim, 4 * dim),
torch.nn.GELU(),
torch.nn.Linear(4 * dim, dim),
).cuda()
my_standard = torch.nn.Sequential(
torch.nn.LayerNorm(dim),
MyLinear(dim, 4 * dim),
torch.nn.GELU(),
MyLinear(4 * dim, dim),
).cuda()
fused_mlp = SwitchBackGlobalMLP(dim, 4 * dim).cuda()
sb = torch.nn.Sequential(
torch.nn.LayerNorm(dim),
SwitchBackGlobalLinear(dim, 4 * dim),
torch.nn.GELU(),
SwitchBackGlobalLinear(4 * dim, dim),
).cuda()
standard_compiled = torch.compile(standard)
print('Model part 2')
repeat = 32
info = {'repeat' : repeat, 'batch_size' : batch, 'dim' : dim}
k = 'standard'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_standard = standard(x1)
((2 ** 16) * out_standard).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_standard = standard(x1)
((2 ** 16) * out_standard).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'my_standard'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_my_standard = my_standard(x1)
((2 ** 16) * out_my_standard).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_my_standard = my_standard(x1)
((2 ** 16) * out_my_standard).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'standard_compiled'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_standard_compiled = standard_compiled(x1)
((2 ** 16) * out_standard_compiled).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_standard_compiled = standard_compiled(x1)
((2 ** 16) * out_standard_compiled).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
x1.grad.zero_()
k = 'sb'
for _ in range(repeat // 2):
with torch.cuda.amp.autocast():
out_sb = sb(x1)
((2 ** 16) * out_sb).abs().mean().backward()
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
with torch.cuda.amp.autocast():
out_sb = sb(x1)
((2 ** 16) * out_sb).abs().mean().backward()
torch.cuda.synchronize()
end = time.time()
ms = (end - start) / repeat * 1000
print(f"time {k}: {ms:.3f} ms")
info[k] = ms
info_json = json.dumps(info)
with open("tests/triton_tests/info_mlp_autocast_ln.jsonl", "a") as file:
file.write(info_json + "\n")
#exit()
# err_fused = (out_standard - out_fused).abs().mean()
# err_sb = (out_standard - out_sb).abs().mean()
# print('OUT', err_fused, err_sb)
# err_fused = (standard[d].weight.grad - fused_mlp.linear2.weight.grad).abs().mean()
# err_sb = (standard[d].weight.grad - sb[d].weight.grad).abs().mean()
# print('GW2', err_fused, err_sb)
# err_fused = (standard[0].weight.grad - fused_mlp.linear1.weight.grad).abs().mean()
# err_sb = (standard[0].weight.grad - sb[0].weight.grad).abs().mean()
# print('GW1', err_fused, err_sb)
# err_fused = (x1.grad - x2.grad).abs().mean()
# err_sb = (x1.grad - x3.grad).abs().mean()
# print('GX1', err_fused, err_sb)
# import pdb; pdb.set_trace()
# # NO GELU, ST GRADIENTS, EVERYTHING FINE.
\ No newline at end of file
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import matplotlib.gridspec as gridspec
cmap=plt.get_cmap('cool')
if __name__ == '__main__':
fig = plt.figure(tight_layout=True, figsize=(6,3.5))
gs = gridspec.GridSpec(1, 1)
rdf = pd.read_json('tests/triton_tests/info.jsonl', lines=True)
ax = fig.add_subplot(gs[0, 0])
# now plot the % speedup for different batch sizes
for j, batch_size in enumerate([2**14, 2**15, 2**16, 2**17]):
all_xs, all_ys = [], []
for k, marker, ls, color, name in [
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose+standard_gw+global_fwd+global_bwd', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
('x_quantize_rowwise+g_quantize_rowwise+w_quantize_global+w_quantize_global_transpose', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
]:
xs, ys = [], []
df = rdf[rdf.batch_size == batch_size]
for embed_dim in [1024, 1280, 1408, 1664, 2048, 4096]:
df_ = df[df.dim_in == embed_dim]
df_ = df_[df_.dim_out == embed_dim * 4]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
df_ = df[df.dim_in == embed_dim * 4]
df_ = df_[df_.dim_out == embed_dim]
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_ * 0.5)
all_xs.append(xs)
all_ys.append(ys)
color = cmap(j * 0.25)
real_ys = [100 * all_ys[1][i] / all_ys[0][i] for i in range(len(all_ys[0]))]
markers = ['^', 'v', 'P', 'o']
ax.plot(all_xs[0], real_ys, color=color, label=f'batch * sequence length = {batch_size}', marker=markers[j], markersize=5 if marker=='s' else 5)
ax.legend()
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.grid()
ax.set_ylabel(r'% time occupied by quantize ops', fontsize=12)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks([1024, 2048, 4096])
ax.set_xticklabels([1024, 2048, 4096])
ax.set_xticks([], minor=True)
#ax.set_title(' Linear layer summary, varying dimensions', fontsize=10, loc='left', y=1.05, pad=-20)
plt.savefig('tests/triton_tests/plot2.pdf', bbox_inches='tight')
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import matplotlib.lines as mlines
import matplotlib.gridspec as gridspec
cmap=plt.get_cmap('cool')
if __name__ == '__main__':
fig = plt.figure(tight_layout=True, figsize=(12,3.5))
gs = gridspec.GridSpec(1, 3)
rdf1 = pd.read_json('tests/triton_tests/info_mlp_autocast_ln.jsonl', lines=True)
ax = fig.add_subplot(gs[0, 0])
# now plot the % speedup for different batch sizes
for j, batch_size in enumerate([2**15, 2**17]):#, 2**15, 2**17, 2**17]):
all_xs, all_ys = {}, {}
for k, marker, ls, color, name in [
('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)'),
#('standard', 'o', '-', 'C1', 'standard (total time)'),
('my_standard', 'o', '-', 'C2', 'my standard (total time)'),
('sb', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
]:
xs, ys = [], []
df = rdf1[rdf1.batch_size == batch_size]
for embed_dim in [1024, 1280, 1408, 1664, 2048]:
df_ = df[df.dim == embed_dim]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_)
all_xs[k] = xs
all_ys[k] = ys
#ax.plot(xs, ys, color=color, label=f'batch * sequence length = {batch_size}', marker=marker, markersize=5 if marker=='s' else 5)
color= cmap(float(j))
speedup_over_my_standard = [-100 * (all_ys['sb'][i] - all_ys['my_standard'][i]) / all_ys['my_standard'][i] for i in range(len(all_ys['my_standard']))]
speedup_over_compile = [-100 * (all_ys['sb'][i] - all_ys['standard_compiled'][i]) / all_ys['standard_compiled'][i] for i in range(len(all_ys['standard_compiled']))]
ax.plot(xs, speedup_over_my_standard, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5)
ax.plot(xs, speedup_over_compile, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5, linestyle='--')
#ax.legend()
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.grid()
ax.set_ylabel(r'% speedup', fontsize=12)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks([1024, 2048])
ax.set_xticklabels([1024, 2048])
ax.set_xticks([], minor=True)
ax.set_title('MLP Block', fontsize=10, loc='left', y=1.07, pad=-20)
##########################################
rdf2 = pd.read_json('tests/triton_tests/attn_info_ln.jsonl', lines=True)
ax = fig.add_subplot(gs[0, 1])
for j, batch_size in enumerate([2**15, 2**17]):#, 2**15, 2**17, 2**17]):
all_xs, all_ys = {}, {}
for k, marker, ls, color, name in [
('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)'),
#('standard', 'o', '-', 'C1', 'standard (total time)'),
('my_standard', 'o', '-', 'C2', 'my standard (total time)'),
('sb', 'o', '-', 'C4', 'SwitchBack int8 (total time)'),
]:
xs, ys = [], []
df = rdf2[rdf2.batch_size == batch_size]
for embed_dim in [1024, 1280, 1408, 1664, 2048]:
df_ = df[df.dim == embed_dim]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_)
all_xs[k] = xs
all_ys[k] = ys
#ax.plot(xs, ys, color=color, label=f'batch * sequence length = {batch_size}', marker=marker, markersize=5 if marker=='s' else 5)
color= cmap(float(j))
speedup_over_my_standard = [-100 * (all_ys['sb'][i] - all_ys['my_standard'][i]) / all_ys['my_standard'][i] for i in range(len(all_ys['my_standard']))]
speedup_over_compile = [-100 * (all_ys['sb'][i] - all_ys['standard_compiled'][i]) / all_ys['standard_compiled'][i] for i in range(len(all_ys['standard_compiled']))]
ax.plot(xs, speedup_over_my_standard, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5)
ax.plot(xs, speedup_over_compile, color=color, label=f'batch * sequence length = {batch_size}', marker='o', markersize=5 if marker=='s' else 5, linestyle='--')
speedup_compiled = mlines.Line2D([], [], linestyle='--', color='gray', label='speedup over compiled')
speedup_baseline = mlines.Line2D([], [], linestyle='-', color='gray', label='speedup over baseline')
batch_size_4 = mlines.Line2D([], [], linestyle='-', color=cmap(0.), label=f'batch = {int(2**15 // 256)}, sequence = {256}')
batch_size_8 = mlines.Line2D([], [], linestyle='-', color=cmap(1.), label=f'batch = {int(2**17 / 256)} sequence = {256}')
# Create the legend with the proxy artists
# adjust plots so that they dont get squished by putting the legend under both
plt.subplots_adjust(left=0.2)
plt.subplots_adjust(right=0.8)
fig.legend(handles=[speedup_compiled, speedup_baseline, batch_size_4, batch_size_8], ncol=2, loc='upper center', bbox_to_anchor=(0.35, 0.255))
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.grid()
ax.set_ylabel(r'% speedup', fontsize=12)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks([1024, 2048])
ax.set_xticklabels([1024, 2048])
ax.set_xticks([], minor=True)
ax.set_title('Attention Block', fontsize=10, loc='left', y=1.07, pad=-20)
##########################################
ax = fig.add_subplot(gs[0, 2])
for j, batch_size in enumerate([2**15]):#, 2**15, 2**17, 2**17]):
all_xs, all_ys = {}, {}
for k, marker, ls, color, name, b in [
('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)', False),
('standard_compiled', 'o', '-', 'C0', 'standard compiled (total time)', True),
#('standard', 'o', '-', 'C1', 'standard (total time)'),
#('my_standard', 'o', '-', 'C2', 'my standard (total time)'),
('attn', 'o', '-', 'C4', 'SwitchBack int8 (total time)', True),
]:
rdf = rdf2 if b else rdf1
xs, ys = [], []
df = rdf[rdf.batch_size == batch_size]
for embed_dim in [1024, 1280, 1408, 1664, 2048]:
df_ = df[df.dim == embed_dim]
xs.append(embed_dim)
y_ = 0
for k_ in k.split('+'):
y_ += df_[k_].values[0]
ys.append(y_)
all_xs[k + str(int(b))] = xs
all_ys[k + str(int(b))] = ys
#ax.plot(xs, ys, color=color, label=f'batch * sequence length = {batch_size}', marker=marker, markersize=5 if marker=='s' else 5)
print(all_ys.keys())
all_ys['standard_compiled'] = [x + y for x, y in zip(all_ys['standard_compiled0'], all_ys['standard_compiled1'])]
speedup_over_my_standard = [100 * all_ys['attn1'][i] / (all_ys['standard_compiled'][i] + all_ys['attn1'][i]) for i in range(len(all_ys['standard_compiled']))]
ax.plot(xs, speedup_over_my_standard, color='gold', label=r'% time occupied by attention', marker='H', markersize=8)
speedup_over_my_standard = [100 * all_ys['standard_compiled1'][i] / (all_ys['standard_compiled0'][i] + all_ys['standard_compiled1'][i]) for i in range(len(all_ys['standard_compiled']))]
ax.plot(xs, speedup_over_my_standard, color='indianred', label=r'% time occupied by attention block', marker='P', markersize=8)
ax.legend(bbox_to_anchor=(1.02, -0.27))
ax.set_xlabel('dim', fontsize=13)
ax.set_xscale('log')
ax.grid()
ax.set_ylabel(r'% time', fontsize=12)
ax.tick_params(axis='x', labelsize=11)
ax.tick_params(axis='y', labelsize=11)
ax.set_xticks([1024, 2048])
ax.set_xticklabels([1024, 2048])
ax.set_xticks([], minor=True)
plt.savefig('tests/triton_tests/plot3.pdf', bbox_inches='tight')
import time
import torch
import torch
import torch.nn as nn
import bitsandbytes.nn as bnn
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear, SwitchBackGlobalLinear
from bitsandbytes.nn.triton_utils.v0.quantize_rowwise_nogroup import quantize_rowwise_nogroup
# 256 * 256 * 4096 _> 0.7
# 256 * 128 * 8192 -> 10
if __name__ == '__main__':
torch.manual_seed(0)
# hparams
repeat = 16
dim=8192
layers = 4
batch_size = 256 * 128
# simulate forward pass
x = torch.randn(batch_size, dim, dtype=torch.float16).cuda()
for _ in range(repeat // 2):
quantize_rowwise_nogroup(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(repeat):
quantize_rowwise_nogroup(x)
torch.cuda.synchronize()
end = time.time()
print(f"time: {(end - start) / repeat * 1000:.3f} ms")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment