sglang / Commits / a0592c05

Commit a0592c05 (unverified), authored Dec 15, 2024 by Xiaoyu Zhang, committed by GitHub Dec 15, 2024

[Benchmark] add a benchmark for hf/vllm/sglang rmsnorm (#2486)

Parent: e8dbdf75
Showing 1 changed file with 231 additions and 0 deletions.

benchmark/kernels/rmsnorm/benchmark_rmsnorm.py (new file, mode 100644): +231 −0
import itertools
from typing import Optional, Tuple, Union

import torch
import triton
import triton.language as tl
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn
from vllm import _custom_ops as vllm_ops
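# Reference semantics shared by all three implementations benchmarked below:
#   RMSNorm(x) = x / sqrt(mean(x^2, dim=-1) + eps) * weight
# When a residual is passed, the input is first updated as x = x + residual,
# and the updated residual is returned alongside the normalized output.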
class HuggingFaceRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        if residual is not None:
            x = x + residual.to(torch.float32)
            residual = x.to(orig_dtype)

        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.variance_epsilon)
        x = x.to(orig_dtype) * self.weight
        if residual is None:
            return x
        else:
            return x, residual
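# The three wrappers below share one calling convention: flatten the input to
# 2D (tokens, hidden_size), run the backend's kernel, then restore the original
# shape. With a residual they return a (output, updated_residual) tuple.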
def rmsnorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
    naive_norm.weight = nn.Parameter(weight)
    naive_norm = naive_norm.to(x.device)

    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
        residual = residual.view(-1, residual.shape[-1])

    output = naive_norm(x, residual)

    if isinstance(output, tuple):
        output = (output[0].view(orig_shape), output[1].view(orig_shape))
    else:
        output = output.view(orig_shape)
    return output
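# Note: unlike the naive path, the fused residual paths below mutate their
# arguments in place -- fused_add_rmsnorm / fused_add_rms_norm write the
# normalized output into x and the pre-norm sum into residual, which is why
# every caller passes x.clone() and residual.clone().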
def rmsnorm_flashinfer(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
        residual = residual.view(-1, residual.shape[-1])

    if residual is not None:
        fused_add_rmsnorm(x, residual, weight, eps)
        output = (x, residual)
    else:
        output = rmsnorm(x, weight, eps)

    if isinstance(output, tuple):
        output = (output[0].view(orig_shape), output[1].view(orig_shape))
    else:
        output = output.view(orig_shape)
    return output
def rmsnorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])
    if residual is not None:
        residual = residual.view(-1, residual.shape[-1])

    if residual is not None:
        vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
        output = (x, residual)
    else:
        out = torch.empty_like(x)
        vllm_ops.rms_norm(out, x, weight, eps)
        output = out

    if isinstance(output, tuple):
        output = (output[0].view(orig_shape), output[1].view(orig_shape))
    else:
        output = output.view(orig_shape)
    return output
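# calculate_diff cross-checks the three backends on a single shape. The loose
# atol/rtol of 1e-2 accommodates bfloat16 inputs: kernels that reduce in fp32
# and round back to bf16 can legitimately differ in the low bits.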
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
    dtype = torch.bfloat16
    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
    weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x) if use_residual else None

    output_naive = rmsnorm_naive(
        x.clone(), weight, residual.clone() if residual is not None else None
    )
    output_flashinfer = rmsnorm_flashinfer(
        x.clone(), weight, residual.clone() if residual is not None else None
    )
    output_vllm = rmsnorm_vllm(
        x.clone(), weight, residual.clone() if residual is not None else None
    )

    if use_residual:
        output_naive = output_naive[0]
        output_flashinfer = output_flashinfer[0]
        output_vllm = output_vllm[0]

    print(f"Naive output={output_naive}")
    print(f"FlashInfer output={output_flashinfer}")
    print(f"VLLM output={output_vllm}")

    if torch.allclose(
        output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
    ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48]
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
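# Resulting sweep: batch sizes {1, 4, 16, 64}, sequence lengths {64, 128, 256,
# 512, 1024}, head counts {32, 48}, i.e. hidden sizes 4096 and 6144 at head_dim=128.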
def get_benchmark(use_residual):
    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["head_num", "batch_size", "seq_len"],
            x_vals=[list(_) for _ in configs],
            line_arg="provider",
            line_vals=["huggingface", "flashinfer", "vllm"],
            line_names=["HuggingFace", "FlashInfer", "vLLM"],
            styles=[("blue", "-"), ("green", "-"), ("red", "-")],
            ylabel="us",
            plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
            args={},
        )
    )
    def benchmark(head_num, batch_size, seq_len, provider):
        dtype = torch.bfloat16
        hidden_size = head_num * 128  # assuming head_dim = 128

        x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
        weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
        residual = torch.randn_like(x) if use_residual else None

        quantiles = [0.5, 0.2, 0.8]

        if provider == "huggingface":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_naive(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )
        elif provider == "flashinfer":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_flashinfer(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )
        else:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: rmsnorm_vllm(
                    x.clone(),
                    weight,
                    residual.clone() if residual is not None else None,
                ),
                quantiles=quantiles,
            )

        # Convert milliseconds to microseconds to match the "us" ylabel.
        return 1000 * ms, 1000 * max_ms, 1000 * min_ms

    return benchmark
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--use_residual", action="store_true", help="Whether to use residual connection"
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/rmsnorm/",
        help="Path to save rmsnorm benchmark results",
    )
    args = parser.parse_args()

    # Run correctness test
    calculate_diff(
        batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
    )

    # Get the benchmark function with proper use_residual setting
    benchmark = get_benchmark(args.use_residual)
    # Run performance benchmark
    benchmark.run(print_data=True, save_path=args.save_path)
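To run the benchmark (a minimal sketch, assuming a CUDA device with flashinfer and vllm installed):

    python benchmark/kernels/rmsnorm/benchmark_rmsnorm.py
    python benchmark/kernels/rmsnorm/benchmark_rmsnorm.py --use_residual

The first command checks correctness and times the residual-free path; --use_residual switches both the correctness check and the sweep to the fused add+norm path. Results are printed and saved under --save_path (default ./configs/benchmark_ops/rmsnorm/).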