Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5fd24ec0
Unverified
Commit
5fd24ec0
authored
Jan 16, 2025
by
Varun Sundar Rabindranath
Committed by
GitHub
Jan 16, 2025
Browse files
[misc] Add LoRA kernel micro benchmarks (#11579)
parent
874f7c29
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1357 additions
and
0 deletions
+1357
-0
benchmarks/kernels/benchmark_lora.py
benchmarks/kernels/benchmark_lora.py
+1147
-0
benchmarks/kernels/utils.py
benchmarks/kernels/utils.py
+210
-0
No files found.
benchmarks/kernels/benchmark_lora.py
0 → 100644
View file @
5fd24ec0
This diff is collapsed.
Click to expand it.
benchmarks/kernels/utils.py
0 → 100644
View file @
5fd24ec0
import
dataclasses
from
typing
import
Any
,
Callable
,
Iterable
,
Optional
import
torch
import
torch.utils.benchmark
as
TBenchmark
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
@
dataclasses
.
dataclass
class
CudaGraphBenchParams
:
num_ops_in_cuda_graph
:
int
@
dataclasses
.
dataclass
class
ArgPool
:
"""
When some argument of the benchmarking function is annotated with this type,
the benchmarking class (BenchMM) will collapse the argument to a pick a
single value from the given list of values, during function invocation.
For every invocation during a benchmarking run, it will choose a
different value from the list.
"""
values
:
Iterable
[
Any
]
def
__getitem__
(
self
,
index
):
return
self
.
values
[
index
]
class
Bench
:
class
ArgsIterator
:
def
__init__
(
self
,
args_list
,
kwargs_list
):
assert
len
(
args_list
)
==
len
(
kwargs_list
)
self
.
args_list
=
args_list
self
.
kwargs_list
=
kwargs_list
self
.
n
=
len
(
self
.
args_list
)
self
.
idx
=
0
def
__next__
(
self
):
while
True
:
yield
(
self
.
args_list
[
self
.
idx
],
self
.
kwargs_list
[
self
.
idx
])
self
.
idx
+=
1
self
.
idx
=
self
.
idx
%
self
.
n
def
reset
(
self
):
self
.
idx
=
0
@
property
def
n_args
(
self
):
return
self
.
n
def
__init__
(
self
,
cuda_graph_params
:
Optional
[
CudaGraphBenchParams
],
label
:
str
,
sub_label
:
str
,
description
:
str
,
fn
:
Callable
,
*
args
,
**
kwargs
):
self
.
cuda_graph_params
=
cuda_graph_params
self
.
use_cuda_graph
=
self
.
cuda_graph_params
is
not
None
self
.
label
=
label
self
.
sub_label
=
sub_label
self
.
description
=
description
self
.
fn
=
fn
# Process args
self
.
_args
=
args
self
.
_kwargs
=
kwargs
self
.
args_list
,
self
.
kwargs_list
=
self
.
collapse_argpool
(
*
args
,
**
kwargs
)
self
.
args_iterator
=
self
.
ArgsIterator
(
self
.
args_list
,
self
.
kwargs_list
)
# Cudagraph runner
self
.
g
=
None
if
self
.
use_cuda_graph
:
self
.
g
=
self
.
get_cuda_graph_runner
()
# benchmark run params
self
.
min_run_time
=
1
def
collapse_argpool
(
self
,
*
args
,
**
kwargs
):
argpool_args
=
[
arg
for
arg
in
args
if
isinstance
(
arg
,
ArgPool
)]
+
[
arg
for
arg
in
kwargs
.
values
()
if
isinstance
(
arg
,
ArgPool
)
]
if
len
(
argpool_args
)
==
0
:
return
[
args
],
[
kwargs
]
# Make sure all argpools are of the same size
argpool_size
=
len
(
argpool_args
[
0
].
values
)
assert
all
([
argpool_size
==
len
(
arg
.
values
)
for
arg
in
argpool_args
])
# create copies of the args
args_list
=
[]
kwargs_list
=
[]
for
_
in
range
(
argpool_size
):
args_list
.
append
(
args
)
kwargs_list
.
append
(
kwargs
.
copy
())
for
i
in
range
(
argpool_size
):
# collapse args; Just pick the ith value
args_list
[
i
]
=
tuple
([
arg
[
i
]
if
isinstance
(
arg
,
ArgPool
)
else
arg
for
arg
in
args_list
[
i
]
])
# collapse kwargs
kwargs_i
=
kwargs_list
[
i
]
arg_pool_keys
=
[
k
for
k
,
v
in
kwargs_i
.
items
()
if
isinstance
(
v
,
ArgPool
)
]
for
k
in
arg_pool_keys
:
# again just pick the ith value
kwargs_i
[
k
]
=
kwargs_i
[
k
][
i
]
kwargs_list
[
i
]
=
kwargs_i
return
args_list
,
kwargs_list
def
get_cuda_graph_runner
(
self
):
assert
self
.
use_cuda_graph
assert
self
.
args_iterator
is
not
None
num_graph_ops
=
self
.
cuda_graph_params
.
num_ops_in_cuda_graph
# warmup
args_it
=
self
.
args_iterator
.
__next__
()
for
_
in
range
(
2
):
args
,
kwargs
=
next
(
args_it
)
self
.
fn
(
*
args
,
**
kwargs
)
self
.
args_iterator
.
reset
()
args_it
=
self
.
args_iterator
.
__next__
()
stream
=
torch
.
cuda
.
Stream
()
with
torch
.
cuda
.
stream
(
stream
):
g
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
g
):
for
_
in
range
(
num_graph_ops
):
args
,
kwargs
=
next
(
args_it
)
self
.
fn
(
*
args
,
**
kwargs
)
return
g
def
run_cudagrah
(
self
)
->
TMeasurement
:
assert
self
.
use_cuda_graph
globals
=
{
'g'
:
self
.
g
}
return
TBenchmark
.
Timer
(
stmt
=
"g.replay()"
,
globals
=
globals
,
label
=
(
f
"
{
self
.
label
}
"
f
" | cugraph
{
self
.
cuda_graph_params
.
num_ops_in_cuda_graph
}
ops"
),
sub_label
=
self
.
sub_label
,
description
=
self
.
description
,
).
blocked_autorange
(
min_run_time
=
self
.
min_run_time
)
def
run_eager
(
self
)
->
TMeasurement
:
setup
=
None
stmt
=
None
globals
=
None
has_arg_pool
=
self
.
args_iterator
.
n_args
>
1
if
has_arg_pool
:
setup
=
'''
args_iterator.reset()
args_it = args_iterator.__next__()
'''
stmt
=
'''
args, kwargs = next(args_it)
fn(*args, **kwargs)
'''
globals
=
{
'fn'
:
self
.
fn
,
'args_iterator'
:
self
.
args_iterator
}
else
:
# no arg pool. Just use the args and kwargs directly
self
.
args_iterator
.
reset
()
args_it
=
self
.
args_iterator
.
__next__
()
args
,
kwargs
=
next
(
args_it
)
setup
=
""
stmt
=
'''
fn(*args, **kwargs)
'''
globals
=
{
'fn'
:
self
.
fn
,
'args'
:
args
,
'kwargs'
:
kwargs
}
return
TBenchmark
.
Timer
(
stmt
=
stmt
,
setup
=
setup
,
globals
=
globals
,
label
=
self
.
label
,
sub_label
=
self
.
sub_label
,
description
=
self
.
description
,
).
blocked_autorange
(
min_run_time
=
self
.
min_run_time
)
def
run
(
self
)
->
TMeasurement
:
timer
=
None
if
self
.
use_cuda_graph
:
# noqa SIM108
timer
=
self
.
run_cudagrah
()
else
:
timer
=
self
.
run_eager
()
if
not
timer
.
meets_confidence
()
or
timer
.
has_warnings
:
print
(
"Doesn't meet confidence - re-running bench ..."
)
return
self
.
run
()
return
timer
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
if
exc_type
:
print
(
f
"exc type
{
exc_type
}
"
)
print
(
f
"exc value
{
exc_value
}
"
)
print
(
f
"exc traceback
{
traceback
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment