Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5fd24ec0
Unverified
Commit
5fd24ec0
authored
Jan 16, 2025
by
Varun Sundar Rabindranath
Committed by
GitHub
Jan 16, 2025
Browse files
[misc] Add LoRA kernel micro benchmarks (#11579)
parent
874f7c29
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1357 additions
and
0 deletions
+1357
-0
benchmarks/kernels/benchmark_lora.py
benchmarks/kernels/benchmark_lora.py
+1147
-0
benchmarks/kernels/utils.py
benchmarks/kernels/utils.py
+210
-0
No files found.
benchmarks/kernels/benchmark_lora.py
0 → 100644
View file @
5fd24ec0
import
argparse
import
copy
import
json
import
pickle
import
time
from
dataclasses
import
dataclass
from
enum
import
Enum
,
auto
from
itertools
import
product
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.utils.benchmark
as
TBenchmark
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
from
utils
import
ArgPool
,
Bench
,
CudaGraphBenchParams
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm.lora.ops.triton_ops.bgmv_expand
import
bgmv_expand
from
vllm.lora.ops.triton_ops.bgmv_expand_slice
import
bgmv_expand_slice
from
vllm.lora.ops.triton_ops.bgmv_shrink
import
bgmv_shrink
from
vllm.lora.ops.triton_ops.sgmv_expand
import
sgmv_expand
from
vllm.lora.ops.triton_ops.sgmv_shrink
import
sgmv_shrink
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_TP_SIZES
=
[
1
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
192
,
256
,
320
,
384
,
448
,
512
,
640
,
768
,
896
,
1024
,
2048
,
3072
,
4096
,
5120
,
6144
,
7168
,
8192
]
DEFAULT_HIDDEN_SIZES
=
[
1024
,
2048
,
4096
,
8192
,
16384
]
DEFAULT_LORA_RANKS
=
[
16
]
DEFAULT_NUM_LORAS
=
[
1
,
2
,
3
,
4
]
DEFAULT_SORT_BY_LORA_IDS
=
[
False
,
True
]
DEFAULT_SEQ_LENGTHS
=
[
1
]
DEFAULT_EXPAND_FN_ADD_INPUTS
=
[
True
,
False
]
# Utilities
def
dtype_to_str
(
dtype
:
torch
.
dtype
):
if
dtype
==
torch
.
float16
:
return
"f16"
if
dtype
==
torch
.
bfloat16
:
return
"bf16"
if
dtype
==
torch
.
float32
:
return
"f32"
raise
ValueError
(
f
"Unsupported dtype
{
dtype
}
"
)
def
make_rand_lora_weight_tensor
(
k
:
int
,
n
:
int
,
num_loras
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
=
"cuda"
)
->
torch
.
Tensor
:
# LoRA weights column major
return
torch
.
rand
((
num_loras
,
n
,
k
),
dtype
=
dtype
).
to
(
device
)
def
make_rand_tensors
(
a_shape
:
Tuple
[
int
],
b_shape
:
Tuple
[
int
],
c_shape
:
Tuple
[
int
],
a_dtype
:
torch
.
dtype
,
b_dtype
:
torch
.
dtype
,
c_dtype
:
torch
.
dtype
,
num_slices
:
int
,
device
:
str
=
"cuda"
,
)
->
Tuple
[
torch
.
Tensor
,
List
[
torch
.
Tensor
],
torch
.
Tensor
]:
"""
Make LoRA input/output matrices.
"""
A
=
torch
.
rand
(
a_shape
,
dtype
=
a_dtype
).
to
(
device
)
# LoRA weights column major
Bs
=
[
torch
.
rand
(
b_shape
,
dtype
=
b_dtype
).
to
(
device
)
for
_
in
range
(
num_slices
)
]
C
=
torch
.
zeros
(
c_shape
,
dtype
=
c_dtype
).
to
(
device
)
return
A
,
Bs
,
C
def
make_prompt_lora_mapping
(
num_prompts
:
int
,
num_active_loras
:
int
,
sort_by_lora_id
:
bool
,
device
:
str
)
->
torch
.
Tensor
:
"""
All prompts are mapped to a Lora ID in range [0, num_active_loras).
where 0 refers to first lora, 1 refers to second lora and so on.
"""
assert
num_active_loras
>
0
if
not
sort_by_lora_id
:
return
torch
.
randint
(
0
,
num_active_loras
,
(
num_prompts
,
),
dtype
=
torch
.
long
)
# Divide LoRAs equally and in order.
part_size
=
num_prompts
//
num_active_loras
part_size
=
max
(
part_size
,
1
)
lora_id
=
0
prompt_lora_mapping
=
[]
while
len
(
prompt_lora_mapping
)
<
num_prompts
:
prompt_lora_mapping
.
extend
([
lora_id
]
*
part_size
)
lora_id
=
lora_id
+
1
if
lora_id
+
1
<
num_active_loras
else
lora_id
return
torch
.
tensor
(
prompt_lora_mapping
[:
num_prompts
],
dtype
=
torch
.
long
,
device
=
device
)
def
make_token_lora_mapping
(
num_tokens
:
int
,
num_prompts
:
int
,
prompt_lora_mapping
:
torch
.
Tensor
,
seq_len_tensor
:
torch
.
Tensor
,
device
:
str
):
"""
Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor
"""
assert
prompt_lora_mapping
.
shape
[
0
]
==
num_prompts
# token to lora index mapping
token_lora_mapping
=
[
0
]
*
num_tokens
current_offset
=
0
for
b_id
in
range
(
num_prompts
):
lora_index
=
prompt_lora_mapping
[
b_id
].
item
()
s
=
current_offset
e
=
s
+
seq_len_tensor
[
b_id
].
item
()
token_lora_mapping
[
s
:
e
]
=
[
lora_index
]
*
(
e
-
s
)
current_offset
+=
seq_len_tensor
[
b_id
].
item
()
return
torch
.
tensor
(
token_lora_mapping
,
dtype
=
torch
.
long
,
device
=
device
)
def
ref_group_gemm
(
ref_out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
lora_weights
:
List
[
torch
.
Tensor
],
seq_lens_cpu
:
torch
.
Tensor
,
prompt_lora_mapping_cpu
:
torch
.
Tensor
,
scaling
:
float
,
add_inputs
:
Optional
[
bool
]):
"""
Torch group gemm reference implementation to test correctness of
benchmarking operations.
"""
batches
=
seq_lens_cpu
.
size
(
0
)
out_list
=
[]
current_offset
=
0
for
lora_index
,
b_length
in
zip
(
range
(
batches
),
seq_lens_cpu
):
x
=
input
[
current_offset
:
b_length
+
current_offset
,
:]
current_offset
+=
b_length
w
=
lora_weights
[
prompt_lora_mapping_cpu
[
lora_index
]]
result
=
torch
.
nn
.
functional
.
linear
(
x
,
w
)
result
*=
scaling
out_list
.
append
(
result
)
torch
.
cat
(
out_list
,
dim
=
0
)
cat_result
=
torch
.
cat
(
out_list
,
dim
=
0
)
if
add_inputs
:
ref_out
+=
cat_result
else
:
ref_out
.
copy_
(
cat_result
)
class
OpType
(
Enum
):
"""
LoRA Ops to benchmark and its properties.
"""
SGMV_SHRINK
=
auto
()
BGMV_SHRINK
=
auto
()
SGMV_EXPAND
=
auto
()
BGMV_EXPAND
=
auto
()
BGMV_EXPAND_SLICE
=
auto
()
@
staticmethod
def
from_str
(
s
:
str
)
->
"OpType"
:
if
s
.
lower
()
==
'sgmv_shrink'
:
return
OpType
.
SGMV_SHRINK
if
s
.
lower
()
==
'sgmv_expand'
:
return
OpType
.
SGMV_EXPAND
if
s
.
lower
()
==
'bgmv_shrink'
:
return
OpType
.
BGMV_SHRINK
if
s
.
lower
()
==
'bgmv_expand'
:
return
OpType
.
BGMV_EXPAND
if
s
.
lower
()
==
"bgmv_expand_slice"
:
return
OpType
.
BGMV_EXPAND_SLICE
raise
ValueError
(
f
"Unrecognized str
{
s
}
to convert to OpType"
)
def
is_shrink_fn
(
self
)
->
bool
:
return
self
in
[
OpType
.
SGMV_SHRINK
,
OpType
.
BGMV_SHRINK
]
def
is_expand_fn
(
self
)
->
bool
:
return
self
in
[
OpType
.
SGMV_EXPAND
,
OpType
.
BGMV_EXPAND
]
def
is_prefill_op
(
self
)
->
bool
:
return
self
in
[
OpType
.
SGMV_SHRINK
,
OpType
.
SGMV_EXPAND
]
def
is_decode_op
(
self
)
->
bool
:
return
self
in
[
OpType
.
BGMV_SHRINK
,
OpType
.
BGMV_EXPAND
,
OpType
.
BGMV_EXPAND_SLICE
]
def
is_expand_slice_fn
(
self
)
->
bool
:
return
self
in
[
OpType
.
BGMV_EXPAND_SLICE
]
def
num_slices
(
self
)
->
List
[
int
]:
if
self
in
[
OpType
.
SGMV_EXPAND
,
OpType
.
SGMV_SHRINK
]:
# SGMV kernels supports slices
return
[
1
,
2
,
3
]
if
self
in
[
OpType
.
BGMV_SHRINK
,
OpType
.
BGMV_EXPAND
]:
return
[
1
]
if
self
in
[
OpType
.
BGMV_EXPAND_SLICE
]:
return
[
2
,
3
]
raise
ValueError
(
f
"Unrecognized OpType
{
self
}
"
)
def
mkn
(
self
,
batch_size
:
int
,
seq_length
:
int
,
hidden_size
:
int
,
lora_rank
:
int
)
->
Tuple
[
int
,
int
,
int
]:
num_tokens
=
batch_size
*
seq_length
if
self
.
is_shrink_fn
():
m
=
num_tokens
k
=
hidden_size
n
=
lora_rank
else
:
assert
self
.
is_expand_fn
()
or
self
.
is_expand_slice_fn
()
m
=
num_tokens
k
=
lora_rank
n
=
hidden_size
return
m
,
k
,
n
def
matmul_dtypes
(
self
,
op_dtype
:
torch
.
dtype
)
->
Tuple
[
torch
.
dtype
,
torch
.
dtype
,
torch
.
dtype
]:
"""
return a type, b type and c type for A x B = C
"""
if
self
.
is_shrink_fn
():
return
op_dtype
,
op_dtype
,
torch
.
float32
else
:
assert
self
.
is_expand_fn
()
or
self
.
is_expand_slice_fn
()
return
torch
.
float32
,
op_dtype
,
op_dtype
def
matmul_shapes
(
self
,
batch_size
:
int
,
seq_length
:
int
,
hidden_size
:
int
,
lora_rank
:
int
,
num_loras
:
int
,
num_slices
:
int
)
->
Tuple
[
Tuple
[
int
],
Tuple
[
int
],
Tuple
[
int
]]:
"""
Given num_slices, return the shapes of the A, B, and C matrices
in A x B = C, for the op_type
"""
m
,
k
,
n
=
self
.
mkn
(
batch_size
,
seq_length
,
hidden_size
,
lora_rank
)
b_shape
=
(
num_loras
,
n
,
k
)
# col-major
if
self
==
OpType
.
SGMV_SHRINK
:
# SGMV shrink supports num_slices inherently in the kernel
return
((
m
,
k
),
b_shape
,
(
num_slices
,
m
,
n
))
if
self
==
OpType
.
SGMV_EXPAND
:
# SGMV expand supports num_slices inherently in the kernel
return
((
num_slices
,
m
,
k
),
b_shape
,
(
m
,
n
*
num_slices
))
if
self
==
OpType
.
BGMV_SHRINK
:
return
((
m
,
k
),
b_shape
,
(
m
,
n
))
if
self
==
OpType
.
BGMV_EXPAND
:
return
((
m
,
k
),
b_shape
,
(
m
,
n
))
if
self
==
OpType
.
BGMV_EXPAND_SLICE
:
return
((
num_slices
,
m
,
k
),
b_shape
,
(
m
,
n
*
num_slices
))
raise
ValueError
(
f
"Unrecognized op_type
{
self
}
"
)
def
bench_fn
(
self
)
->
Callable
:
def
emulate_bgmv_expand_slice
(
kwargs_list
:
List
[
Dict
[
str
,
Any
]]):
for
x
in
kwargs_list
:
bgmv_expand_slice
(
**
x
)
if
self
==
OpType
.
SGMV_SHRINK
:
return
sgmv_shrink
if
self
==
OpType
.
SGMV_EXPAND
:
return
sgmv_expand
if
self
==
OpType
.
BGMV_SHRINK
:
return
bgmv_shrink
if
self
==
OpType
.
BGMV_EXPAND
:
return
bgmv_expand
if
self
==
OpType
.
BGMV_EXPAND_SLICE
:
return
emulate_bgmv_expand_slice
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
def
run_ref_group_gemm
(
self
,
output
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
lora_weights
:
List
[
torch
.
Tensor
],
**
kwargs
)
->
Callable
:
"""Each benchmark operation expected the input, lora_weights and outputs
in a slightly different format. Refer to self.matmul_shapes().
run_ref_group_gemm accounts for those differences in executing a
reference group gemm for correctness testing.
"""
w_dtype
=
lora_weights
[
0
].
dtype
num_slices
=
len
(
lora_weights
)
if
self
==
OpType
.
SGMV_SHRINK
:
for
slice_idx
in
range
(
num_slices
):
ref_group_gemm
(
ref_out
=
output
[
slice_idx
,
:],
input
=
input
,
lora_weights
=
lora_weights
[
slice_idx
],
**
kwargs
)
if
self
==
OpType
.
SGMV_EXPAND
:
hidden_size
=
lora_weights
[
0
].
shape
[
1
]
for
slice_idx
in
range
(
num_slices
):
slice_offset
=
slice_idx
*
hidden_size
ref_group_gemm
(
ref_out
=
output
[:,
slice_offset
:
slice_offset
+
hidden_size
],
input
=
input
[
slice_idx
].
clone
().
to
(
dtype
=
w_dtype
),
lora_weights
=
lora_weights
[
slice_idx
],
**
kwargs
)
if
self
==
OpType
.
BGMV_SHRINK
:
assert
num_slices
==
1
ref_group_gemm
(
ref_out
=
output
,
input
=
input
,
lora_weights
=
lora_weights
[
0
],
**
kwargs
)
if
self
==
OpType
.
BGMV_EXPAND
:
assert
num_slices
==
1
ref_group_gemm
(
ref_out
=
output
,
input
=
input
.
clone
().
to
(
dtype
=
w_dtype
),
lora_weights
=
lora_weights
[
0
],
**
kwargs
)
if
self
==
OpType
.
BGMV_EXPAND_SLICE
:
hidden_size
=
lora_weights
[
0
].
shape
[
1
]
for
slice_idx
in
range
(
num_slices
):
slice_offset
=
slice_idx
*
hidden_size
ref_group_gemm
(
ref_out
=
output
[:,
slice_offset
:
slice_offset
+
hidden_size
],
input
=
input
[
slice_idx
].
clone
().
to
(
dtype
=
w_dtype
),
lora_weights
=
lora_weights
[
slice_idx
],
**
kwargs
)
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
@
dataclass
class
BenchmarkContext
:
"""
LoRA benchmark context
"""
batch_size
:
int
hidden_size
:
int
num_loras
:
int
num_active_loras
:
int
lora_rank
:
int
sort_by_lora_id
:
bool
dtype
:
torch
.
dtype
seq_length
:
Optional
[
int
]
=
None
num_slices
:
Optional
[
int
]
=
None
# num_slices for slice based ops
def
with_seq_length
(
self
,
seq_length
:
int
)
->
"BenchmarkContext"
:
ctx
=
copy
.
copy
(
self
)
ctx
.
seq_length
=
seq_length
return
ctx
def
with_num_slices
(
self
,
num_slices
:
int
)
->
"BenchmarkContext"
:
ctx
=
copy
.
copy
(
self
)
ctx
.
num_slices
=
num_slices
return
ctx
def
bench_label
(
self
)
->
str
:
return
f
"lora-
{
self
.
dtype
}
"
def
bench_sublabel
(
self
,
op_type
:
OpType
)
->
str
:
m
,
k
,
n
=
op_type
.
mkn
(
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
,
self
.
lora_rank
)
desc
=
{
'bs'
:
self
.
batch_size
,
'sl'
:
self
.
seq_length
,
'm'
:
m
,
'k'
:
k
,
'n'
:
n
,
'num_loras'
:
self
.
num_loras
,
'sort_by_lora'
:
self
.
sort_by_lora_id
,
'num_slices'
:
self
.
num_slices
,
}
return
json
.
dumps
(
desc
)
@
dataclass
class
BenchmarkTensors
:
"""
Input/Output tensors used for benchmarks
"""
# matmul tensors
input
:
torch
.
Tensor
lora_weights_lst
:
List
[
torch
.
Tensor
]
output
:
torch
.
Tensor
# metadata tensors
seq_lens
:
torch
.
Tensor
seq_start_loc
:
torch
.
Tensor
prompt_lora_mapping
:
torch
.
Tensor
token_lora_mapping
:
torch
.
Tensor
def
io_types
(
self
)
->
str
:
return
(
f
"
{
dtype_to_str
(
self
.
input
.
dtype
)
}
x"
f
"
{
dtype_to_str
(
self
.
lora_weights_lst
[
0
].
dtype
)
}
=>"
f
"
{
dtype_to_str
(
self
.
output
.
dtype
)
}
"
)
@
staticmethod
def
make
(
ctx
:
BenchmarkContext
,
op_type
:
OpType
,
device
:
str
=
"cuda"
)
->
"BenchmarkTensors"
:
# Make input / output matmul tensors.
a_shape
,
b_shape
,
c_shape
=
op_type
.
matmul_shapes
(
ctx
.
batch_size
,
ctx
.
seq_length
,
ctx
.
hidden_size
,
ctx
.
lora_rank
,
ctx
.
num_loras
,
ctx
.
num_slices
)
a_type
,
b_type
,
c_type
=
op_type
.
matmul_dtypes
(
ctx
.
dtype
)
input_tensor
,
lora_weights
,
output_tensor
=
\
make_rand_tensors
(
a_shape
,
b_shape
,
c_shape
,
a_type
,
b_type
,
c_type
,
num_slices
=
ctx
.
num_slices
)
# Make metadata tensors.
# Keep the metadata tensors in the CPU for further processing if needed.
# The tensors get moved to the GPU before benchmarking.
assert
ctx
.
num_active_loras
<=
ctx
.
num_loras
total_tokens
=
ctx
.
batch_size
*
ctx
.
seq_length
# Prepare seq lens tensor
seq_len_tensor
=
torch
.
randint
(
ctx
.
seq_length
,
ctx
.
seq_length
+
1
,
(
ctx
.
batch_size
,
))
# Prepare seq_start_loc tensor
seq_start_loc_tensor
=
torch
.
cumsum
(
torch
.
tensor
(
[
0
]
+
seq_len_tensor
[:
-
1
].
tolist
(),
dtype
=
torch
.
long
),
dim
=
0
)
assert
total_tokens
==
seq_len_tensor
.
sum
()
# Prepare prompt lora indices tensor
prompt_lora_indices_tensor
=
make_prompt_lora_mapping
(
ctx
.
batch_size
,
ctx
.
num_active_loras
,
ctx
.
sort_by_lora_id
,
"cpu"
)
# Prepare token lora indices tensor
token_lora_indices_tensor
=
make_token_lora_mapping
(
total_tokens
,
ctx
.
batch_size
,
prompt_lora_indices_tensor
,
seq_len_tensor
,
"cpu"
)
return
BenchmarkTensors
(
input_tensor
,
lora_weights
,
output_tensor
,
seq_len_tensor
,
seq_start_loc_tensor
,
prompt_lora_indices_tensor
,
token_lora_indices_tensor
)
def
sanity_check
(
self
)
->
None
:
"""
Fails asserts when non-conformality is detected.
"""
num_tokens
=
self
.
input
.
shape
[
-
2
]
# check metadata tensors
assert
torch
.
sum
(
self
.
seq_lens
)
==
num_tokens
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
assert
self
.
seq_start_loc
.
shape
[
0
]
==
num_seqs
assert
self
.
prompt_lora_mapping
.
shape
[
0
]
==
num_seqs
assert
self
.
token_lora_mapping
.
shape
[
0
]
==
num_tokens
def
to_device
(
self
,
device
:
str
):
"""
Transfer tensors to device if the tensors aren't already on the device
"""
def
to_device
(
tensor
:
torch
.
Tensor
):
if
tensor
.
device
!=
device
:
tensor
=
tensor
.
to
(
device
=
device
)
return
tensor
self
.
input
=
to_device
(
self
.
input
)
self
.
output
=
to_device
(
self
.
output
)
self
.
seq_lens
=
to_device
(
self
.
seq_lens
)
self
.
seq_start_loc
=
to_device
(
self
.
seq_start_loc
)
self
.
prompt_lora_mapping
=
to_device
(
self
.
prompt_lora_mapping
)
self
.
token_lora_mapping
=
to_device
(
self
.
token_lora_mapping
)
for
i
in
range
(
len
(
self
.
lora_weights_lst
)):
self
.
lora_weights_lst
[
i
]
=
to_device
(
self
.
lora_weights_lst
[
i
])
def
metadata
(
self
)
->
Tuple
[
int
,
int
,
int
]:
"""
Return num_seqs, num_tokens and max_seq_len
"""
num_seqs
=
self
.
seq_lens
.
shape
[
0
]
num_tokens
=
self
.
token_lora_mapping
.
shape
[
0
]
max_seq_len
=
torch
.
max
(
self
.
seq_lens
).
item
()
num_slices
=
len
(
self
.
lora_weights_lst
)
return
num_seqs
,
num_tokens
,
max_seq_len
,
num_slices
def
convert_to_sgmv_benchmark_tensors
(
self
):
"""
For sgmv punica kernels, when consecutive sequences have the
same LoRA ID, we just merge them together.
This happens in punica.py::compute_metadata
"""
# Collapse seq_lens and seq_start_loc
_
,
seq_lens
=
torch
.
unique_consecutive
(
self
.
token_lora_mapping
,
return_counts
=
True
)
cum_result
=
torch
.
cumsum
(
seq_lens
,
dim
=
0
)
seq_start_loc
=
torch
.
zeros_like
(
seq_lens
)
seq_start_loc
[
1
:].
copy_
(
cum_result
[:
-
1
])
# Collapse prompt mapping
prompt_lora_mapping
=
torch
.
unique_consecutive
(
self
.
prompt_lora_mapping
)
assert
torch
.
sum
(
seq_lens
)
==
torch
.
sum
(
self
.
seq_lens
),
\
f
"dont match - new
{
torch
.
sum
(
seq_lens
)
}
vs
{
torch
.
sum
(
self
.
seq_lens
)
}
"
self
.
prompt_lora_mapping
=
prompt_lora_mapping
.
to
(
dtype
=
self
.
prompt_lora_mapping
.
dtype
)
self
.
seq_lens
=
seq_lens
.
to
(
dtype
=
self
.
seq_lens
.
dtype
)
self
.
seq_start_loc
=
seq_start_loc
.
to
(
dtype
=
self
.
seq_start_loc
.
dtype
)
def
as_sgmv_shrink_kwargs
(
self
)
->
Dict
[
str
,
Any
]:
self
.
convert_to_sgmv_benchmark_tensors
()
self
.
sanity_check
()
self
.
to_device
(
self
.
input
.
device
)
num_seqs
,
num_tokens
,
max_seq_len
,
num_slices
=
self
.
metadata
()
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
self
.
input
.
shape
,
self
.
lora_weights_lst
[
0
].
shape
,
self
.
output
.
shape
# Expected input shape [num_tokens, hidden_size]
assert
len
(
i_shape
)
==
2
assert
i_shape
[
0
]
==
num_tokens
hidden_size
=
i_shape
[
1
]
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
assert
len
(
lw_shape
)
==
3
assert
lw_shape
[
2
]
==
hidden_size
lora_rank
=
lw_shape
[
1
]
# Expected output shape [num_slices, num_tokens, lora_rank]
assert
len
(
o_shape
)
==
3
assert
o_shape
==
(
num_slices
,
num_tokens
,
lora_rank
)
return
{
'inputs'
:
self
.
input
,
'lora_a_weights'
:
self
.
lora_weights_lst
,
'output_tensor'
:
self
.
output
,
'b_seq_start_loc'
:
self
.
seq_start_loc
,
'seq_len_tensor'
:
self
.
seq_lens
,
'lora_indices_tensor'
:
self
.
prompt_lora_mapping
,
'batches'
:
num_seqs
,
'max_seq_length'
:
max_seq_len
,
'token_nums'
:
num_tokens
,
'scaling'
:
1.0
,
}
def
as_sgmv_expand_kwargs
(
self
,
add_inputs
:
bool
)
->
Dict
[
str
,
Any
]:
self
.
convert_to_sgmv_benchmark_tensors
()
self
.
sanity_check
()
self
.
to_device
(
self
.
input
.
device
)
num_seqs
,
num_tokens
,
max_seq_len
,
num_slices
=
self
.
metadata
()
# Sanity check matrix shapes.
i_shape
,
lw_shape
,
o_shape
=
self
.
input
.
shape
,
self
.
lora_weights_lst
[
0
].
shape
,
self
.
output
.
shape
# Expected input shape : [num_slices, num_tokens, lora_rank]
assert
len
(
i_shape
)
==
3
assert
i_shape
[
0
]
==
num_slices
assert
i_shape
[
1
]
==
num_tokens
lora_rank
=
i_shape
[
2
]
# Expected lora weight shape : [num_lora, hidden_size, lora_rank]
assert
len
(
lw_shape
)
==
3
assert
lw_shape
[
2
]
==
lora_rank
hidden_size
=
lw_shape
[
1
]
# Expected output shape : [num_tokens, hidden_size * num_slices]
assert
len
(
o_shape
)
==
2
assert
o_shape
==
(
num_tokens
,
hidden_size
*
num_slices
)
return
{
'inputs'
:
self
.
input
,
'lora_b_weights'
:
self
.
lora_weights_lst
,
'output_tensor'
:
self
.
output
,
'b_seq_start_loc'
:
self
.
seq_start_loc
,
'seq_len_tensor'
:
self
.
seq_lens
,
'lora_indices_tensor'
:
self
.
prompt_lora_mapping
,
'batches'
:
num_seqs
,
'max_seq_length'
:
max_seq_len
,
'token_nums'
:
num_tokens
,
'offset_start'
:
0
,
'add_inputs'
:
add_inputs
,
}
def
as_bgmv_shrink_kwargs
(
self
)
->
Dict
[
str
,
Any
]:
assert
len
(
self
.
lora_weights_lst
)
==
1
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
_
=
self
.
metadata
()
# Sanity check shapes
i_shape
,
lw_shape
,
o_shape
=
self
.
input
.
shape
,
self
.
lora_weights_lst
[
0
].
shape
,
self
.
output
.
shape
# Expected input shape [num_tokens, hidden_size]
assert
len
(
i_shape
)
==
2
assert
i_shape
[
0
]
==
num_tokens
hidden_size
=
i_shape
[
1
]
# Expected lora weight shape [num_loras, lora_rank, hidden_size]
assert
len
(
lw_shape
)
==
3
assert
lw_shape
[
2
]
==
hidden_size
lora_rank
=
lw_shape
[
1
]
# Expected output shape [num_tokens, lora_rank]
assert
len
(
o_shape
)
==
2
assert
o_shape
==
(
num_tokens
,
lora_rank
)
return
{
'inputs'
:
self
.
input
,
'lora_a_weights'
:
self
.
lora_weights_lst
[
0
],
'output_tensor'
:
self
.
output
,
'lora_indices_tensor'
:
self
.
token_lora_mapping
,
'scaling'
:
1.0
}
def
as_bgmv_expand_kwargs
(
self
,
add_inputs
:
bool
):
assert
len
(
self
.
lora_weights_lst
)
==
1
self
.
to_device
(
self
.
input
.
device
)
_
,
num_tokens
,
_
,
_
=
self
.
metadata
()
# Sanity check shapes
i_shape
,
lw_shape
,
o_shape
=
self
.
input
.
shape
,
self
.
lora_weights_lst
[
0
].
shape
,
self
.
output
.
shape
# Expected input shape [num_tokens, lora_rank]
assert
len
(
i_shape
)
==
2
assert
i_shape
[
0
]
==
num_tokens
lora_rank
=
i_shape
[
1
]
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
assert
len
(
lw_shape
)
==
3
assert
lw_shape
[
2
]
==
lora_rank
hidden_size
=
lw_shape
[
1
]
# Expected output shape [num_tokens, hidden_size]
assert
len
(
o_shape
)
==
2
assert
o_shape
==
(
num_tokens
,
hidden_size
)
return
{
'inputs'
:
self
.
input
,
'lora_b_weights'
:
self
.
lora_weights_lst
[
0
],
'output_tensor'
:
self
.
output
,
'lora_indices_tensor'
:
self
.
token_lora_mapping
,
'add_inputs'
:
add_inputs
}
def
as_bgmv_expand_slice_kwargs
(
self
,
add_inputs
:
bool
)
->
Dict
[
str
,
Any
]:
_
,
num_tokens
,
_
,
num_slices
=
self
.
metadata
()
# Sanity check shapes
i_shape
,
lw_shape
,
o_shape
=
self
.
input
.
shape
,
self
.
lora_weights_lst
[
0
].
shape
,
self
.
output
.
shape
# Expected input shape [num_slices, num_tokens, lora_rank]
assert
len
(
i_shape
)
==
3
assert
i_shape
[
0
]
==
num_slices
assert
i_shape
[
1
]
==
num_tokens
lora_rank
=
i_shape
[
2
]
# Expected lora weight shape [num_loras, hidden_size, lora_rank]
assert
len
(
lw_shape
)
==
3
assert
lw_shape
[
2
]
==
lora_rank
hidden_size
=
lw_shape
[
1
]
# Expected output shape [num_tokens, hidden_size * num_slices]
assert
len
(
o_shape
)
==
2
assert
o_shape
==
(
num_tokens
,
hidden_size
*
num_slices
)
self
.
to_device
(
self
.
input
.
device
)
kwargs_list
=
[]
for
i
in
range
(
num_slices
):
kwargs_list
.
append
({
'inputs'
:
self
.
input
[
i
],
'lora_b_weights'
:
self
.
lora_weights_lst
[
i
],
'output_tensor'
:
self
.
output
,
'lora_indices_tensor'
:
self
.
token_lora_mapping
,
'slice_offset'
:
i
*
hidden_size
,
'slice_size'
:
hidden_size
,
'add_inputs'
:
add_inputs
,
})
return
{
'kwargs_list'
:
kwargs_list
}
def
bench_fn_kwargs
(
self
,
op_type
:
OpType
,
add_inputs
:
Optional
[
bool
]
=
None
)
->
Dict
[
str
,
Any
]:
if
op_type
.
is_shrink_fn
():
assert
add_inputs
is
None
else
:
assert
add_inputs
is
not
None
if
op_type
==
OpType
.
SGMV_SHRINK
:
return
self
.
as_sgmv_shrink_kwargs
()
if
op_type
==
OpType
.
SGMV_EXPAND
:
return
self
.
as_sgmv_expand_kwargs
(
add_inputs
)
if
op_type
==
OpType
.
BGMV_SHRINK
:
return
self
.
as_bgmv_shrink_kwargs
()
if
op_type
==
OpType
.
BGMV_EXPAND
:
return
self
.
as_bgmv_expand_kwargs
(
add_inputs
)
if
op_type
==
OpType
.
BGMV_EXPAND_SLICE
:
return
self
.
as_bgmv_expand_slice_kwargs
(
add_inputs
)
raise
ValueError
(
f
"Unrecognized optype
{
self
}
"
)
def
test_correctness
(
self
,
op_type
:
OpType
,
expand_fn_add_inputs
:
Optional
[
bool
])
->
bool
:
"""
Test correctness of op_type implementation against a grouped gemm
reference implementation.
"""
seq_lens_cpu
=
self
.
seq_lens
.
to
(
device
=
"cpu"
)
prompt_lora_mapping_cpu
=
self
.
prompt_lora_mapping
.
to
(
device
=
"cpu"
)
ref_output
=
self
.
output
.
clone
()
self
.
output
.
zero_
()
op_type
.
bench_fn
()(
**
self
.
bench_fn_kwargs
(
op_type
,
expand_fn_add_inputs
))
op_type
.
run_ref_group_gemm
(
ref_output
,
self
.
input
,
self
.
lora_weights_lst
,
seq_lens_cpu
=
seq_lens_cpu
,
prompt_lora_mapping_cpu
=
prompt_lora_mapping_cpu
,
scaling
=
1.0
,
add_inputs
=
expand_fn_add_inputs
)
rtol
,
atol
=
{
torch
.
float16
:
(
6e-2
,
6e-2
),
torch
.
bfloat16
:
(
6e-2
,
6e-2
),
torch
.
float32
:
(
1e-2
,
1e-2
),
}[
self
.
output
.
dtype
]
return
torch
.
allclose
(
ref_output
,
self
.
output
,
rtol
=
rtol
,
atol
=
atol
)
def
bench_optype
(
ctx
:
BenchmarkContext
,
arg_pool_size
:
int
,
op_type
:
OpType
,
cuda_graph_nops
:
Optional
[
int
]
=
None
,
expand_fn_add_inputs
:
Optional
[
bool
]
=
None
,
test_correctness
:
bool
=
False
)
->
TMeasurement
:
assert
arg_pool_size
>=
1
if
op_type
.
is_shrink_fn
():
assert
expand_fn_add_inputs
is
None
else
:
assert
expand_fn_add_inputs
is
not
None
# BenchmarkContext -> BenchmarkTensors
bench_tensors
:
List
[
BenchmarkTensors
]
=
\
[
BenchmarkTensors
.
make
(
ctx
,
op_type
)
for
_
in
range
(
arg_pool_size
)]
for
bt
in
bench_tensors
:
bt
.
sanity_check
()
# Test correctness of our implementation.
if
test_correctness
:
assert
all
([
bt
.
test_correctness
(
op_type
,
expand_fn_add_inputs
)
for
bt
in
bench_tensors
])
# BenchmarkTensors -> Dict (kwargs)
kwargs_list
=
[
bt
.
bench_fn_kwargs
(
op_type
,
add_inputs
=
expand_fn_add_inputs
)
for
bt
in
bench_tensors
]
# Clear LoRA optimization hash-maps.
_LORA_A_PTR_DICT
.
clear
()
_LORA_B_PTR_DICT
.
clear
()
# Run bench function so that _LORA_A_PTR_DICT and _LORA_B_PTR_DICT are setup
for
kwargs
in
kwargs_list
:
op_type
.
bench_fn
()(
**
kwargs
)
torch
.
cuda
.
synchronize
()
# Merge into a single kwargs and qualify arguments as ArgPool
kwargs
=
{
k
:
ArgPool
([])
for
k
in
kwargs_list
[
0
]}
for
_kwargs
in
kwargs_list
:
for
k
,
v
in
_kwargs
.
items
():
kwargs
[
k
].
values
.
append
(
v
)
describe_args
=
(
f
"add_inputs=
{
expand_fn_add_inputs
}
"
if
expand_fn_add_inputs
is
not
None
else
""
)
description
=
(
f
"
{
op_type
.
name
}
(
{
describe_args
}
) (
{
bench_tensors
[
0
].
io_types
()
}
)"
)
cuda_graph_params
=
None
if
cuda_graph_nops
:
cuda_graph_params
=
CudaGraphBenchParams
(
cuda_graph_nops
)
timer
=
None
with
Bench
(
cuda_graph_params
,
ctx
.
bench_label
(),
ctx
.
bench_sublabel
(
op_type
),
description
,
op_type
.
bench_fn
(),
**
kwargs
)
as
bench
:
timer
=
bench
.
run
()
return
timer
def
bench_torch_mm
(
ctx
:
BenchmarkContext
,
arg_pool_size
:
int
,
op_type
:
OpType
,
cuda_graph_nops
:
Optional
[
int
]
=
None
)
->
TMeasurement
:
"""
Benchmark basic torch.mm as a roofline.
When all the input tokens have the same LoRA ID, the LoRA kernels are just
a matmul. This torch.mm benchmark serves as a roofline for that case.
input op_type is used in determining the m, k, n dimensions for the matmul.
"""
batch_size
,
hidden_size
,
lora_rank
,
seq_length
,
dtype
=
(
ctx
.
batch_size
,
ctx
.
hidden_size
,
ctx
.
lora_rank
,
ctx
.
seq_length
,
ctx
.
dtype
)
m
,
k
,
n
=
op_type
.
mkn
(
batch_size
,
seq_length
,
hidden_size
,
lora_rank
)
# For a fairer comparison.
n
=
n
*
ctx
.
num_slices
# Get matmul input and output tensors for A x B = C
As
,
Bs
,
Cs
=
[],
[],
[]
for
_
in
range
(
arg_pool_size
):
As
.
append
(
torch
.
rand
((
m
,
k
),
dtype
=
dtype
).
to
(
"cuda"
))
Bs
.
append
(
torch
.
rand
((
n
,
k
),
dtype
=
dtype
).
to
(
"cuda"
).
t
())
Cs
.
append
(
torch
.
rand
((
m
,
n
),
dtype
=
dtype
).
to
(
"cuda"
))
# Make torch.mm kwargs
mm_kwargs
=
{
'input'
:
ArgPool
(
As
),
'mat2'
:
ArgPool
(
Bs
),
'out'
:
ArgPool
(
Cs
)}
description
=
(
f
"single-lora roofline using torch.mm (
{
dtype_to_str
(
dtype
)
}
"
f
"x
{
dtype_to_str
(
dtype
)
}
"
f
"=>
{
dtype_to_str
(
dtype
)
}
)"
)
cuda_graph_params
=
None
if
cuda_graph_nops
:
cuda_graph_params
=
CudaGraphBenchParams
(
cuda_graph_nops
)
with
Bench
(
cuda_graph_params
,
ctx
.
bench_label
(),
ctx
.
bench_sublabel
(
op_type
),
description
,
torch
.
mm
,
**
mm_kwargs
)
as
bench
:
return
bench
.
run
()
# runner
def
use_cuda_graph_recommendation
()
->
str
:
return
"""
Triton kernels have a significant launch overhead with
launched directly via python. This overhead is more noticeable
for small the problem sizes. For these cases, it is recommended
to use the script with `--cuda-graph-nops N` to benchmark N
consecutive invocations of the benchmarking operations from
inside a CUDA Graph. Note that the returned measurement is for N
invocations of the operation.
"""
def
print_timers
(
timers
:
List
[
TMeasurement
],
args
:
Optional
[
argparse
.
Namespace
]
=
None
):
compare
=
TBenchmark
.
Compare
(
timers
)
compare
.
print
()
if
args
and
args
.
cuda_graph_nops
:
print
(
f
"Note : The timings reported above is for
{
args
.
cuda_graph_nops
}
"
"consecutive invocations of the benchmarking functions. "
f
"Please divide by
{
args
.
cuda_graph_nops
}
for single invocation "
"timings."
)
print
(
"Note on Comparison with torch.mm : The torch.mm numbers are "
"benchmark numbers of a simple matmul emulating the single lora "
"case. It is provided as a roofline for comparing our LoRA Kernel "
"implementations. It is expected that the LoRA kernels will be "
"slower than torch.mm in cases where num_loras is big. But for "
"small num_loras the goal should be to match the torch.mm numbers."
)
def
run
(
args
:
argparse
.
Namespace
,
bench_ctxs
:
List
[
BenchmarkContext
]):
if
args
.
cuda_graph_nops
is
not
None
:
assert
args
.
cuda_graph_nops
>
0
print
(
f
"Benchmarking
{
args
.
cuda_graph_nops
}
invocations inside a CUDA "
"Graph"
)
else
:
print
(
f
"CUDA Graphs not enabled.
\n
{
use_cuda_graph_recommendation
()
}
"
)
timers
=
[]
for
bench_ctx
in
bench_ctxs
:
for
seq_len
in
args
.
seq_lengths
:
bench_ops
:
List
[
OpType
]
=
[]
if
seq_len
==
1
:
# bench all decode ops
bench_ops
=
[
op
for
op
in
args
.
op_types
if
op
.
is_decode_op
()]
else
:
# bench all prefill ops
bench_ops
=
[
op
for
op
in
args
.
op_types
if
op
.
is_prefill_op
()]
seq_len_timers
=
[]
for
bench_op
in
bench_ops
:
for
num_slices
in
bench_op
.
num_slices
():
_ctx
=
bench_ctx
.
with_seq_length
(
seq_len
).
with_num_slices
(
num_slices
)
# Benchmark torch.mm as a roofline
seq_len_timers
.
append
(
bench_torch_mm
(
_ctx
,
args
.
arg_pool_size
,
bench_op
,
args
.
cuda_graph_nops
))
# Benchmark bench_op
expand_fn_add_inputs
=
[
None
]
if
bench_op
.
is_shrink_fn
()
else
args
.
expand_fn_add_inputs
for
add_input_arg
in
expand_fn_add_inputs
:
seq_len_timers
.
append
(
bench_optype
(
_ctx
,
args
.
arg_pool_size
,
bench_op
,
args
.
cuda_graph_nops
,
add_input_arg
,
args
.
test_correctness
))
print_timers
(
seq_len_timers
)
timers
.
extend
(
seq_len_timers
)
# Result stdout dump
print
(
"== All Results ===="
)
print_timers
(
timers
,
args
)
if
args
.
output_directory
:
# Result file dump
od
=
Path
(
args
.
output_directory
)
if
not
od
.
exists
():
od
.
mkdir
()
timestamp
=
int
(
time
.
time
())
pkl_file
=
od
/
f
"lora_bench-
{
timestamp
}
.pkl"
print
(
f
"Writing benchmarks to
{
pkl_file
}
"
)
with
open
(
pkl_file
,
"wb"
)
as
f
:
pickle
.
dump
(
timers
,
f
)
def
as_benchmark_contexts
(
hidden_sizes
:
List
[
int
],
lora_ranks
:
List
[
int
],
args
:
argparse
.
Namespace
)
->
List
[
BenchmarkContext
]:
ctxs
:
List
[
BenchmarkContext
]
=
[]
for
batch_size
,
hidden_size
,
lora_rank
,
num_loras
,
sort_by_lora_id
in
product
(
# noqa
args
.
batch_sizes
,
list
(
hidden_sizes
),
lora_ranks
,
args
.
num_loras
,
args
.
sort_by_lora_id
):
ctxs
.
append
(
BenchmarkContext
(
batch_size
=
batch_size
,
hidden_size
=
hidden_size
,
lora_rank
=
lora_rank
,
num_loras
=
num_loras
,
num_active_loras
=
args
.
num_active_loras
if
args
.
num_active_loras
else
num_loras
,
# To be filled based on the OpType to benchmark
seq_length
=
None
,
sort_by_lora_id
=
sort_by_lora_id
,
dtype
=
args
.
dtype
,
# To be filled based on the OpType to benchmark
num_slices
=
None
))
return
ctxs
def
run_list_bench
(
args
:
argparse
.
Namespace
):
print
(
args
)
print
(
"List bench :
\n
"
f
" Hidden Sizes
{
args
.
hidden_sizes
}
"
f
" LoRA Ranks
{
args
.
lora_ranks
}
"
)
# Get all benchmarking contexts
bench_contexts
:
List
[
BenchmarkContext
]
=
as_benchmark_contexts
(
hidden_sizes
=
args
.
hidden_sizes
,
lora_ranks
=
args
.
lora_ranks
,
args
=
args
)
run
(
args
,
bench_contexts
)
def
run_range_bench
(
args
:
argparse
.
Namespace
):
print
(
args
)
hidden_sizes
=
list
(
range
(
args
.
hidden_sizes_start
,
args
.
hidden_sizes_end
+
1
,
args
.
hidden_sizes_increment
))
lora_ranks
=
list
(
range
(
args
.
lora_ranks_start
,
args
.
lora_ranks_end
+
1
,
args
.
lora_ranks_increment
))
print
(
"Range bench :
\n
"
f
" Hidden Sizes
{
hidden_sizes
}
"
f
" LoRA Ranks
{
lora_ranks
}
"
)
# Get all benchmarking contexts
bench_contexts
:
List
[
BenchmarkContext
]
=
as_benchmark_contexts
(
hidden_sizes
=
hidden_sizes
,
lora_ranks
=
lora_ranks
,
args
=
args
)
run
(
args
,
bench_contexts
)
def
run_model_bench
(
args
:
argparse
.
Namespace
):
print
(
args
)
def
hidden_sizes_from_model
(
model
:
str
,
tp_size
:
int
)
->
set
[
int
]:
hidden_sizes
=
set
()
for
KN
,
tp_split_dim
in
WEIGHT_SHAPES
[
model
]:
KN
[
tp_split_dim
]
=
KN
[
tp_split_dim
]
//
tp_size
hidden_sizes
.
add
(
KN
[
1
])
return
hidden_sizes
# Get all hidden sizes
hidden_sizes
:
set
[
int
]
=
set
()
for
model_name
,
tp_size
in
product
(
args
.
models
,
args
.
tp_sizes
):
hidden_sizes
=
hidden_sizes
.
union
(
hidden_sizes_from_model
(
model_name
,
tp_size
))
print
(
"Model bench :
\n
"
f
" Hidden Sizes
{
hidden_sizes
}
"
f
" LoRA Ranks
{
args
.
lora_ranks
}
"
)
# Get all benchmarking contexts
bench_contexts
:
List
[
BenchmarkContext
]
=
as_benchmark_contexts
(
hidden_sizes
=
hidden_sizes
,
lora_ranks
=
args
.
lora_ranks
,
args
=
args
)
run
(
args
,
bench_contexts
)
if
__name__
==
'__main__'
:
def
to_torch_dtype
(
dt
):
if
dt
==
"torch.float16"
:
return
torch
.
float16
if
dt
==
"torch.bfloat16"
:
return
torch
.
bfloat16
raise
ValueError
(
"unsupported dtype"
)
def
get_bool
(
s
:
str
)
->
bool
:
return
s
.
lower
()
in
[
'true'
,
'1'
]
def
add_common_command_args
(
p
:
argparse
.
ArgumentParser
):
p
.
add_argument
(
"--dtype"
,
type
=
to_torch_dtype
,
required
=
True
,
help
=
"Available options are ['torch.float16', 'torch.bfloat16']"
)
p
.
add_argument
(
"--arg-pool-size"
,
type
=
int
,
default
=
32
,
help
=
"Run profiles with a pool of input/output/meta tensors instead"
"of simply reusing the same tensors for all runs. A bigger arg-pool"
"mitigates hardware caching effects during benchmarking."
)
p
.
add_argument
(
"--cuda-graph-nops"
,
type
=
int
,
help
=
(
"when set profiling is done using cudagraph, "
"with the given number of operations in a graph."
"Note that the measurement returned is the time "
"taken for N consecutive executions of the benchmarking "
"functions, where N is the value of this argument."
))
p
.
add_argument
(
"--num-loras"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_NUM_LORAS
)
p
.
add_argument
(
"--num-active-loras"
,
type
=
int
,
default
=
None
,
help
=
"Active LoRAs. When None, all LoRAs are active"
)
p
.
add_argument
(
"--sort-by-lora-id"
,
nargs
=
"+"
,
type
=
get_bool
,
default
=
DEFAULT_SORT_BY_LORA_IDS
)
p
.
add_argument
(
"--op-types"
,
nargs
=
"+"
,
type
=
OpType
.
from_str
,
default
=
list
(
OpType
))
p
.
add_argument
(
'--seq-lengths'
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_SEQ_LENGTHS
)
p
.
add_argument
(
"--batch-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_BATCH_SIZES
)
p
.
add_argument
(
"--expand-fn-add-inputs"
,
nargs
=
"+"
,
type
=
get_bool
,
default
=
DEFAULT_EXPAND_FN_ADD_INPUTS
)
p
.
add_argument
(
'-o'
,
'--output-directory'
,
type
=
str
,
help
=
(
"Output directory to store a the list of benchmarking"
"TMeasurement objects as a pickle file"
))
p
.
add_argument
(
"--test-correctness"
,
action
=
'store_true'
,
help
=
(
"When enabled, the benchmarking functions are tested"
"for correctness before the actual benchmarking"
))
parser
=
FlexibleArgumentParser
(
description
=
f
"""
Benchmark LoRA kernels:
{
use_cuda_graph_recommendation
()
}
list_bench example:
python3 benchmarks/kernels/benchmark_lora.py list_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --hidden-sizes 2048 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
model_bench example:
python3 benchmarks/kernels/benchmark_lora.py model_bench --models meta-llama/Llama-3-8b --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --lora-ranks 16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32
range_bench example:
python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types bgmv_shrink bgmv_expand sgmv_shrink sgmv_expand bgmv_expand_slice --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8
"""
,
# noqa: E501
formatter_class
=
argparse
.
RawTextHelpFormatter
)
subparsers
=
parser
.
add_subparsers
(
dest
=
"cmd"
,
required
=
True
)
list_parser
=
subparsers
.
add_parser
(
"list_bench"
)
list_parser
.
add_argument
(
"--hidden-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_HIDDEN_SIZES
)
list_parser
.
add_argument
(
"--lora-ranks"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_LORA_RANKS
)
add_common_command_args
(
list_parser
)
list_parser
.
set_defaults
(
func
=
run_list_bench
)
range_parser
=
subparsers
.
add_parser
(
"range_bench"
)
range_parser
.
add_argument
(
"--hidden-sizes-start"
,
type
=
int
,
required
=
True
)
range_parser
.
add_argument
(
"--hidden-sizes-end"
,
type
=
int
,
required
=
True
)
range_parser
.
add_argument
(
"--hidden-sizes-increment"
,
type
=
int
,
required
=
True
)
range_parser
.
add_argument
(
"--lora-ranks-start"
,
type
=
int
,
required
=
True
)
range_parser
.
add_argument
(
"--lora-ranks-end"
,
type
=
int
,
required
=
True
)
range_parser
.
add_argument
(
"--lora-ranks-increment"
,
type
=
int
,
required
=
True
)
add_common_command_args
(
range_parser
)
range_parser
.
set_defaults
(
func
=
run_range_bench
)
model_parser
=
subparsers
.
add_parser
(
"model_bench"
)
model_parser
.
add_argument
(
"--models"
,
nargs
=
"+"
,
type
=
str
,
default
=
DEFAULT_MODELS
,
choices
=
WEIGHT_SHAPES
.
keys
())
model_parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_TP_SIZES
)
model_parser
.
add_argument
(
"--lora-ranks"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_LORA_RANKS
)
add_common_command_args
(
model_parser
)
model_parser
.
set_defaults
(
func
=
run_model_bench
)
args
=
parser
.
parse_args
()
args
.
func
(
args
)
benchmarks/kernels/utils.py
0 → 100644
View file @
5fd24ec0
import
dataclasses
from
typing
import
Any
,
Callable
,
Iterable
,
Optional
import
torch
import
torch.utils.benchmark
as
TBenchmark
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
@
dataclasses
.
dataclass
class
CudaGraphBenchParams
:
num_ops_in_cuda_graph
:
int
@
dataclasses
.
dataclass
class
ArgPool
:
"""
When some argument of the benchmarking function is annotated with this type,
the benchmarking class (BenchMM) will collapse the argument to a pick a
single value from the given list of values, during function invocation.
For every invocation during a benchmarking run, it will choose a
different value from the list.
"""
values
:
Iterable
[
Any
]
def
__getitem__
(
self
,
index
):
return
self
.
values
[
index
]
class
Bench
:
class
ArgsIterator
:
def
__init__
(
self
,
args_list
,
kwargs_list
):
assert
len
(
args_list
)
==
len
(
kwargs_list
)
self
.
args_list
=
args_list
self
.
kwargs_list
=
kwargs_list
self
.
n
=
len
(
self
.
args_list
)
self
.
idx
=
0
def
__next__
(
self
):
while
True
:
yield
(
self
.
args_list
[
self
.
idx
],
self
.
kwargs_list
[
self
.
idx
])
self
.
idx
+=
1
self
.
idx
=
self
.
idx
%
self
.
n
def
reset
(
self
):
self
.
idx
=
0
@
property
def
n_args
(
self
):
return
self
.
n
def
__init__
(
self
,
cuda_graph_params
:
Optional
[
CudaGraphBenchParams
],
label
:
str
,
sub_label
:
str
,
description
:
str
,
fn
:
Callable
,
*
args
,
**
kwargs
):
self
.
cuda_graph_params
=
cuda_graph_params
self
.
use_cuda_graph
=
self
.
cuda_graph_params
is
not
None
self
.
label
=
label
self
.
sub_label
=
sub_label
self
.
description
=
description
self
.
fn
=
fn
# Process args
self
.
_args
=
args
self
.
_kwargs
=
kwargs
self
.
args_list
,
self
.
kwargs_list
=
self
.
collapse_argpool
(
*
args
,
**
kwargs
)
self
.
args_iterator
=
self
.
ArgsIterator
(
self
.
args_list
,
self
.
kwargs_list
)
# Cudagraph runner
self
.
g
=
None
if
self
.
use_cuda_graph
:
self
.
g
=
self
.
get_cuda_graph_runner
()
# benchmark run params
self
.
min_run_time
=
1
def
collapse_argpool
(
self
,
*
args
,
**
kwargs
):
argpool_args
=
[
arg
for
arg
in
args
if
isinstance
(
arg
,
ArgPool
)]
+
[
arg
for
arg
in
kwargs
.
values
()
if
isinstance
(
arg
,
ArgPool
)
]
if
len
(
argpool_args
)
==
0
:
return
[
args
],
[
kwargs
]
# Make sure all argpools are of the same size
argpool_size
=
len
(
argpool_args
[
0
].
values
)
assert
all
([
argpool_size
==
len
(
arg
.
values
)
for
arg
in
argpool_args
])
# create copies of the args
args_list
=
[]
kwargs_list
=
[]
for
_
in
range
(
argpool_size
):
args_list
.
append
(
args
)
kwargs_list
.
append
(
kwargs
.
copy
())
for
i
in
range
(
argpool_size
):
# collapse args; Just pick the ith value
args_list
[
i
]
=
tuple
([
arg
[
i
]
if
isinstance
(
arg
,
ArgPool
)
else
arg
for
arg
in
args_list
[
i
]
])
# collapse kwargs
kwargs_i
=
kwargs_list
[
i
]
arg_pool_keys
=
[
k
for
k
,
v
in
kwargs_i
.
items
()
if
isinstance
(
v
,
ArgPool
)
]
for
k
in
arg_pool_keys
:
# again just pick the ith value
kwargs_i
[
k
]
=
kwargs_i
[
k
][
i
]
kwargs_list
[
i
]
=
kwargs_i
return
args_list
,
kwargs_list
def
get_cuda_graph_runner
(
self
):
assert
self
.
use_cuda_graph
assert
self
.
args_iterator
is
not
None
num_graph_ops
=
self
.
cuda_graph_params
.
num_ops_in_cuda_graph
# warmup
args_it
=
self
.
args_iterator
.
__next__
()
for
_
in
range
(
2
):
args
,
kwargs
=
next
(
args_it
)
self
.
fn
(
*
args
,
**
kwargs
)
self
.
args_iterator
.
reset
()
args_it
=
self
.
args_iterator
.
__next__
()
stream
=
torch
.
cuda
.
Stream
()
with
torch
.
cuda
.
stream
(
stream
):
g
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
g
):
for
_
in
range
(
num_graph_ops
):
args
,
kwargs
=
next
(
args_it
)
self
.
fn
(
*
args
,
**
kwargs
)
return
g
def
run_cudagrah
(
self
)
->
TMeasurement
:
assert
self
.
use_cuda_graph
globals
=
{
'g'
:
self
.
g
}
return
TBenchmark
.
Timer
(
stmt
=
"g.replay()"
,
globals
=
globals
,
label
=
(
f
"
{
self
.
label
}
"
f
" | cugraph
{
self
.
cuda_graph_params
.
num_ops_in_cuda_graph
}
ops"
),
sub_label
=
self
.
sub_label
,
description
=
self
.
description
,
).
blocked_autorange
(
min_run_time
=
self
.
min_run_time
)
def
run_eager
(
self
)
->
TMeasurement
:
setup
=
None
stmt
=
None
globals
=
None
has_arg_pool
=
self
.
args_iterator
.
n_args
>
1
if
has_arg_pool
:
setup
=
'''
args_iterator.reset()
args_it = args_iterator.__next__()
'''
stmt
=
'''
args, kwargs = next(args_it)
fn(*args, **kwargs)
'''
globals
=
{
'fn'
:
self
.
fn
,
'args_iterator'
:
self
.
args_iterator
}
else
:
# no arg pool. Just use the args and kwargs directly
self
.
args_iterator
.
reset
()
args_it
=
self
.
args_iterator
.
__next__
()
args
,
kwargs
=
next
(
args_it
)
setup
=
""
stmt
=
'''
fn(*args, **kwargs)
'''
globals
=
{
'fn'
:
self
.
fn
,
'args'
:
args
,
'kwargs'
:
kwargs
}
return
TBenchmark
.
Timer
(
stmt
=
stmt
,
setup
=
setup
,
globals
=
globals
,
label
=
self
.
label
,
sub_label
=
self
.
sub_label
,
description
=
self
.
description
,
).
blocked_autorange
(
min_run_time
=
self
.
min_run_time
)
def
run
(
self
)
->
TMeasurement
:
timer
=
None
if
self
.
use_cuda_graph
:
# noqa SIM108
timer
=
self
.
run_cudagrah
()
else
:
timer
=
self
.
run_eager
()
if
not
timer
.
meets_confidence
()
or
timer
.
has_warnings
:
print
(
"Doesn't meet confidence - re-running bench ..."
)
return
self
.
run
()
return
timer
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
exc_type
,
exc_value
,
traceback
):
if
exc_type
:
print
(
f
"exc type
{
exc_type
}
"
)
print
(
f
"exc value
{
exc_value
}
"
)
print
(
f
"exc traceback
{
traceback
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment