SAC_ac1ua3v7iw / liger-kernel
Commit 9b0e3a30, authored Mar 25, 2026 by cmx
Commit message: first commit
Parent: fe5cd1fc
Pipeline #3450 failed with stages in 0 seconds
Changes: 261 | Pipelines: 1
Showing 20 changed files with 4183 additions and 0 deletions (+4183, -0)
benchmark/scripts/benchmark_geglu.py  +115 -0
benchmark/scripts/benchmark_group_norm.py  +137 -0
benchmark/scripts/benchmark_grpo_loss.py  +234 -0
benchmark/scripts/benchmark_jsd.py  +157 -0
benchmark/scripts/benchmark_kl_div.py  +117 -0
benchmark/scripts/benchmark_kto_loss.py  +314 -0
benchmark/scripts/benchmark_layer_norm.py  +125 -0
benchmark/scripts/benchmark_llama4_rope.py  +245 -0
benchmark/scripts/benchmark_mhc.py  +255 -0
benchmark/scripts/benchmark_mhc_lm.py  +455 -0
benchmark/scripts/benchmark_model_configs.py  +258 -0
benchmark/scripts/benchmark_multi_token_attention.py  +218 -0
benchmark/scripts/benchmark_orpo_loss.py  +169 -0
benchmark/scripts/benchmark_poly_norm.py  +197 -0
benchmark/scripts/benchmark_qwen2vl_mrope.py  +241 -0
benchmark/scripts/benchmark_rms_norm.py  +162 -0
benchmark/scripts/benchmark_rope.py  +223 -0
benchmark/scripts/benchmark_simpo_loss.py  +167 -0
benchmark/scripts/benchmark_softmax.py  +140 -0
benchmark/scripts/benchmark_sparse_multi_token_attention.py  +254 -0
benchmark/scripts/benchmark_geglu.py (new file, mode 100755)

import math

import torch
from benchmark_model_configs import compute_seq_len_sweep_config
from benchmark_model_configs import estimate_kernel_peak_memory
from benchmark_model_configs import get_benchmark_model_config
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaMLP
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import parse_benchmark_script_args
from utils import run_benchmarks
from utils import run_memory_benchmark
from utils import run_speed_benchmark

from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.utils import infer_device

device = infer_device()


def _setup_geglu(input: SingleBenchmarkRunInput):
    """Create input tensor and GEGLU layer from benchmark config."""
    cfg = input.extra_benchmark_config
    llama_config = LlamaConfig(
        hidden_size=cfg["hidden_size"],
        intermediate_size=cfg["intermediate_size"],
        hidden_act=cfg["hidden_act"],
    )
    x = torch.randn(
        cfg["bsz"],
        input.x,
        cfg["hidden_size"],
        device=device,
        dtype=cfg["dtype"],
        requires_grad=True,
    )
    if input.kernel_provider == "liger":
        layer = LigerGEGLUMLP(config=llama_config).to(device).to(cfg["dtype"])
    elif input.kernel_provider == "huggingface":
        layer = LlamaMLP(config=llama_config).to(device).to(cfg["dtype"])
    else:
        raise ValueError(f"Invalid provider: {input.kernel_provider} for GEGLU")
    return x, layer


def bench_speed_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    x, layer = _setup_geglu(input)
    return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x])


def bench_memory_geglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    x, layer = _setup_geglu(input)
    return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode)


if __name__ == "__main__":
    args = parse_benchmark_script_args()
    model = get_benchmark_model_config(args.model)

    probe_seq_len = 1024

    def _probe():
        probe_input = SingleBenchmarkRunInput(
            x=probe_seq_len,
            kernel_provider="huggingface",
            extra_benchmark_config={
                "bsz": 1,
                "hidden_size": model.hidden_size,
                "intermediate_size": model.intermediate_size,
                "hidden_act": "gelu_pytorch_tanh",
                "dtype": model.dtype,
            },
        )
        x, layer = _setup_geglu(probe_input)
        return layer(x)

    peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
    kernel_bpt = peak_bytes // probe_seq_len
    config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)

    common_configs = {
        "kernel_name": "geglu",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "bsz": config.batch_size,
                "hidden_size": model.hidden_size,
                "intermediate_size": model.intermediate_size,
                "hidden_act": "gelu_pytorch_tanh",
                "dtype": model.dtype,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_geglu,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_geglu,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
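Note: estimate_kernel_peak_memory and compute_seq_len_sweep_config live in benchmark_model_configs.py, which is not expanded on this page, so the exact sizing logic is not visible here. The sketch below illustrates the general idea the probe enables, under the assumption that peak memory grows roughly linearly with sequence length; max_power_of_two_seq_len and its budget argument are hypothetical names, not part of the repo.

# Hypothetical sketch (names not from the repo): sizing a power-of-two
# sequence-length sweep from a measured bytes-per-token figure, assuming
# peak memory scales roughly linearly with sequence length.
def max_power_of_two_seq_len(budget_bytes: int, bytes_per_token: int, floor: int = 1024) -> int:
    """Largest power-of-two sequence length whose estimated peak fits the budget."""
    seq_len = floor
    while 2 * seq_len * bytes_per_token <= budget_bytes:
        seq_len *= 2
    return seq_len

# e.g. a 16 GiB budget with ~2 MiB/token measured by the probe caps the sweep at 8192:
print(max_power_of_two_seq_len(16 * 2**30, 2 * 2**20))  # 8192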
benchmark/scripts/benchmark_group_norm.py (new file, mode 100755)

import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.group_norm import LigerGroupNorm
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    C = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    H = extra_benchmark_config["H"]
    channels_per_group = extra_benchmark_config["channels_per_group"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, C, H)
    triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device)
    torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        if provider == "liger":
            return triton_ln(x)
        if provider == "huggingface":
            return torch_ln(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    elif mode == "backward":
        y = y_fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_group_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    C = input.x
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    H = extra_benchmark_config["H"]
    channels_per_group = extra_benchmark_config["channels_per_group"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, C, H)
    triton_ln = LigerGroupNorm(num_channels=C, num_groups=C // channels_per_group, eps=eps).to(device)
    torch_ln = torch.nn.GroupNorm(num_groups=C // channels_per_group, num_channels=C, eps=eps).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        if provider == "liger":
            return triton_ln(x)
        if provider == "huggingface":
            return torch_ln(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "group_norm",
        "x_name": "C",
        "x_label": "num_channels",
        "x_values": [2**i for i in range(5, 12)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "M": 128,
                "H": 512,
                "channels_per_group": 4,
                "dtype": torch.float32,
                "eps": 1e-6,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_group_norm,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_group_norm,
        kernel_operation_modes=["full", "forward", "backward"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
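Note: every speed benchmark in this commit unpacks triton.testing.do_bench(..., quantiles=QUANTILES) as ms_50, ms_20, ms_80, which implies QUANTILES == [0.5, 0.2, 0.8] in benchmark/scripts/utils.py (that file is not expanded on this page, so this is an inference). A minimal standalone illustration of the convention:

import torch
import triton

x = torch.randn(4096, 4096, device="cuda")
# do_bench returns one latency (in ms) per requested quantile, in order:
# here the median first, then the 20th and 80th percentiles.
ms_50, ms_20, ms_80 = triton.testing.do_bench(lambda: x @ x, quantiles=[0.5, 0.2, 0.8], rep=500)
print(f"p20={ms_20:.3f} ms, p50={ms_50:.3f} ms, p80={ms_80:.3f} ms")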
benchmark/scripts/benchmark_grpo_loss.py (new file, mode 100755)

import os
import sys

import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.utils import infer_device

device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


#############################################################################
# Test the memory consumption of the linear fused GRPO loss
#############################################################################


def bench_memory_fused_linear_grpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    from test.chunked_loss.test_grpo_loss import LigerLMHeadGRPO
    from test.chunked_loss.test_grpo_loss import TorchLMHeadGRPO

    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"]
    provider = input.kernel_provider

    # Instantiate once and retrieve the first output only
    torch_lm_head_grpo = TorchLMHeadGRPO(
        H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level
    ).to(device)
    liger_lm_head_grpo = LigerLMHeadGRPO(
        H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level
    ).to(device)

    # Create inputs
    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    selected_token_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device)
    attention_mask = torch.ones(B, T, device=device)
    advantages = torch.randn(B, dtype=dtype, device=device)
    ref_input = torch.randn(B, T, H, dtype=dtype, device=device)

    torch_fwd = lambda: torch_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[0]
    liger_fwd = lambda: liger_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[0]

    def fwd():
        if provider == "liger":
            return liger_fwd()
        elif provider == "torch":
            return torch_fwd()

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


#############################################################################
# Test the speed of the fused linear GRPO loss
#############################################################################


def bench_speed_fused_linear_grpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    from test.chunked_loss.test_grpo_loss import LigerLMHeadGRPO
    from test.chunked_loss.test_grpo_loss import TorchLMHeadGRPO

    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    importance_sampling_level = input.extra_benchmark_config["importance_sampling_level"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    # Instantiate once and retrieve the first output only
    torch_lm_head_grpo = TorchLMHeadGRPO(
        H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level
    ).to(device)
    liger_lm_head_grpo = LigerLMHeadGRPO(
        H=H, V=V, dtype=dtype, importance_sampling_level=importance_sampling_level
    ).to(device)

    # Create inputs
    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    selected_token_ids = torch.randint(0, V, (B, T), dtype=torch.long, device=device)
    attention_mask = torch.ones(B, T, device=device)
    advantages = torch.randn(B, dtype=dtype, device=device)
    ref_input = torch.randn(B, T, H, dtype=dtype, device=device)

    torch_fwd = lambda: torch_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[0]
    liger_fwd = lambda: liger_lm_head_grpo(_input, selected_token_ids, attention_mask, advantages, ref_input=ref_input)[0]

    def fwd():
        if provider == "liger":
            return liger_fwd()
        elif provider == "torch":
            return torch_fwd()

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    # Benchmark token-level importance sampling (original GRPO)
    token_configs = {
        "kernel_name": "fused_linear_grpo_loss_token",
        "x_name": "B",
        "x_label": "B",
        "x_values": [2**i for i in range(1, 5)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [
            {
                "T": 1024,
                "H": 4096,
                "V": 128256,
                "importance_sampling_level": "token",
                "dtype": torch.bfloat16,
            }
        ],
        "overwrite": args.overwrite,
    }

    # Benchmark sequence-level importance sampling (GSPO)
    sequence_configs = {
        "kernel_name": "fused_linear_grpo_loss_sequence",
        "x_name": "B",
        "x_label": "B",
        "x_values": [2**i for i in range(1, 5)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [
            {
                "T": 1024,
                "H": 4096,
                "V": 128256,
                "importance_sampling_level": "sequence",
                "dtype": torch.bfloat16,
            }
        ],
        "overwrite": args.overwrite,
    }

    # Run benchmarks for token-level (GRPO)
    print("Benchmarking GRPO (token-level importance sampling)...")
    run_benchmarks(
        bench_test_fn=bench_speed_fused_linear_grpo_loss,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **token_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_fused_linear_grpo_loss,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **token_configs,
    )

    # Run benchmarks for sequence-level (GSPO)
    print("Benchmarking GSPO (sequence-level importance sampling)...")
    run_benchmarks(
        bench_test_fn=bench_speed_fused_linear_grpo_loss,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **sequence_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_fused_linear_grpo_loss,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **sequence_configs,
    )
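Note: the two configs above differ only in importance_sampling_level. The actual loss lives in test/chunked_loss/test_grpo_loss.py and the fused kernel, neither shown on this page, but the distinction being benchmarked can be sketched as follows (a hypothetical illustration, not the repo's code): token-level GRPO keeps one policy/reference ratio per token, while sequence-level GSPO collapses them into a single length-normalized ratio per sequence.

import torch

logp = torch.randn(2, 5)      # policy log-probs of sampled tokens, shape [B, T]
ref_logp = torch.randn(2, 5)  # reference-model log-probs, shape [B, T]

token_ratio = (logp - ref_logp).exp()                           # [B, T]: one ratio per token
seq_ratio = (logp - ref_logp).mean(dim=-1, keepdim=True).exp()  # [B, 1]: one ratio per sequence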
benchmark/scripts/benchmark_jsd.py (new file, mode 100755)

import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.jsd import LigerJSD
from liger_kernel.utils import get_total_gpu_memory
from liger_kernel.utils import infer_device

device = infer_device()


class TorchJSD(torch.nn.Module):
    def __init__(
        self,
        beta: float = 0.5,
        ignore_index: int = -100,
        dtype: torch.dtype = torch.float,
    ):
        super(TorchJSD, self).__init__()
        self.kl = torch.nn.KLDivLoss(reduction="none", log_target=True)
        self.beta = beta
        self.ignore_index = ignore_index
        self.dtype = dtype

    def forward(
        self,
        log_q: torch.Tensor,  # input
        log_p: torch.Tensor,  # target
        label=None,
    ):
        log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
        log_p, log_q = log_p.view(-1, log_p.size(-1)), log_q.view(-1, log_q.size(-1))
        m = torch.lerp(torch.exp(log_q), torch.exp(log_p), self.beta)
        loss = self.beta * self.kl(torch.log(m), log_p).sum(dim=-1) + (1 - self.beta) * self.kl(
            torch.log(m), log_q
        ).sum(dim=-1)

        if label is not None:
            loss = torch.where(label != self.ignore_index, loss, 0.0)
            n_non_ignore = (label != self.ignore_index).sum().item()
            if n_non_ignore == 0:
                loss = 0.0
            else:
                loss = (loss / n_non_ignore).sum()
        else:
            loss = (loss / log_q.shape[0]).sum()
        return loss.to(self.dtype)


def bench_speed_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    V = input.x
    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
    torch_jsd = TorchJSD()
    liger_jsd = LigerJSD()

    _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
    target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)

    def fwd():
        if input.kernel_provider == "liger":
            return liger_jsd(_input, target)
        else:
            return torch_jsd(_input, target)

    if input.kernel_operation_mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
    elif input.kernel_operation_mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[_input],
            rep=100,
        )
    elif input.kernel_operation_mode == "full":

        def full():
            y = fwd()
            y.backward(retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_jsd(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    torch_jsd = TorchJSD()
    liger_jsd = LigerJSD()

    V = input.x
    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]

    _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
    target = torch.randn(B * T, V, device=device).log_softmax(dim=-1)

    def fwd():
        if input.kernel_provider == "liger":
            return liger_jsd(_input, target)
        else:
            return torch_jsd(_input, target)

    def full():
        y = fwd()
        y.backward(retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    gpu_memory_gbs = get_total_gpu_memory()
    # We know that the full test will require 54GBs for vocab size 2^17 on torch
    if gpu_memory_gbs >= 54:
        x_max = 17
    else:
        x_max = 16

    common_args = {
        "kernel_name": "jsd",
        "x_name": "V",
        "x_label": "vocab size",
        "x_values": [2**i for i in range(12, x_max + 1)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [{"B": 4, "T": 2048}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_memory_jsd,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_args,
    )
    run_benchmarks(
        bench_test_fn=bench_speed_jsd,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_args,
    )
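Note: with the default beta=0.5 and no labels, the TorchJSD reference above reduces to the textbook Jensen-Shannon divergence averaged over rows. A quick numeric self-check (a sketch that assumes TorchJSD from this file is in scope):

import torch

p = torch.randn(8, 16).softmax(dim=-1)
q = torch.randn(8, 16).softmax(dim=-1)
m = 0.5 * (p + q)  # the beta=0.5 mixture distribution
kl = lambda a, b: (a * (a.log() - b.log())).sum(dim=-1)
manual = (0.5 * kl(p, m) + 0.5 * kl(q, m)).mean()  # textbook JSD, averaged over rows
# TorchJSD takes log-probs, input (log_q) first and target (log_p) second
assert torch.allclose(TorchJSD()(q.log(), p.log()), manual, atol=1e-5)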
benchmark/scripts/benchmark_kl_div.py (new file, mode 100755)

import torch
import torch.nn as nn
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.kl_div import LigerKLDIVLoss
from liger_kernel.utils import infer_device

device = infer_device()

S, E = 12, 18


def bench_speed_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    reduction = "batchmean"
    V = input.x
    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]
    torch_kl_div = nn.KLDivLoss(reduction=reduction)
    liger_kl_div = LigerKLDIVLoss(reduction=reduction)

    _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
    target = torch.randn(B * T, V, device=device).softmax(dim=-1)

    def fwd():
        if input.kernel_provider == "liger":
            return liger_kl_div(_input, target)
        else:
            return torch_kl_div(_input, target)

    if input.kernel_operation_mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, rep=100)
    elif input.kernel_operation_mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[_input],
            rep=100,
        )
    elif input.kernel_operation_mode == "full":

        def full():
            y = fwd()
            y.backward(retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, rep=100)
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_kldiv(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    reduction = "batchmean"
    torch_kl_div = nn.KLDivLoss(reduction=reduction)
    liger_kl_div = LigerKLDIVLoss(reduction=reduction)
    V = input.x
    B, T = input.extra_benchmark_config["B"], input.extra_benchmark_config["T"]

    _input = torch.randn(B * T, V, requires_grad=True, device=device).log_softmax(dim=-1)
    target = torch.randn(B * T, V, device=device).softmax(dim=-1)

    def fwd():
        if input.kernel_provider == "liger":
            return liger_kl_div(_input, target)
        else:
            return torch_kl_div(_input, target)

    def full():
        y = fwd()
        y.backward(retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_args = {
        "kernel_name": "kl_div",
        "x_name": "V",
        "x_label": "vocab size",
        "x_values": [2**i for i in range(12, 18)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [{"B": 8, "T": 512}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_memory_kldiv,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_args,
    )
    run_benchmarks(
        bench_test_fn=bench_speed_kldiv,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_args,
    )
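Note: reduction="batchmean" divides the summed pointwise KL by the number of rows (here B*T), which is the statistically correct per-sample KL estimate in PyTorch. A standalone check against the manual formula:

import torch
import torch.nn as nn

log_q = torch.randn(32, 128).log_softmax(dim=-1)  # input: log-probabilities
p = torch.randn(32, 128).softmax(dim=-1)          # target: probabilities
# pointwise KL term is p * (log p - log q); batchmean sums then divides by rows
manual = (p * (p.log() - log_q)).sum() / log_q.shape[0]
assert torch.allclose(nn.KLDivLoss(reduction="batchmean")(log_q, p), manual, atol=1e-6)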
benchmark/scripts/benchmark_kto_loss.py (new file, mode 100755)

import os
import sys

import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
from liger_kernel.utils import infer_device

device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


class TorchLMHeadKTO(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        use_bias: bool = False,
        use_ref_bias: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        from test.chunked_loss.test_kto_loss import HFKTOLoss

        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype)
        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype)
        self.KTO_loss = HFKTOLoss(
            ignore_index=ignore_index,
            beta=beta,
            use_ref_model=True,
        ).get_batch_loss_metrics

    def forward(self, x, ref_x, y, preference_labels, kl=None):
        return self.KTO_loss(
            weight=self.lin.weight,
            _input=x,
            target=y,
            bias=self.lin.bias,
            ref_input=ref_x,
            ref_weight=self.ref_lin.weight,
            ref_bias=self.ref_lin.bias,
            preference_labels=preference_labels,
            kl=kl,
        )


class LigerLMHeadKTO(torch.nn.Module):
    def __init__(
        self,
        H: int,
        V: int,
        dtype: torch.dtype,
        use_bias: bool = False,
        use_ref_bias: bool = False,
        ignore_index: int = -100,
        beta: float = 0.1,
    ):
        super().__init__()
        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_bias, dtype=dtype)
        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=use_ref_bias, dtype=dtype)
        self.KTO_loss = LigerFusedLinearKTOLoss(
            ignore_index=ignore_index,
            beta=beta,
            use_ref_model=True,
        )

    def forward(self, x, ref_x, y, preference_labels, kl=None):
        return self.KTO_loss(
            _input=x,
            lin_weight=self.lin.weight,
            target=y,
            preference_labels=preference_labels,
            bias=self.lin.bias,
            ref_input=ref_x,
            ref_weight=self.ref_lin.weight,
            ref_bias=self.ref_lin.bias,
            kl=kl,
        )


def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    beta = input.extra_benchmark_config["beta"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider

    torch_kto_loss = TorchLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        use_bias=bias,
        use_ref_bias=bias,
        ignore_index=ignore_index,
        beta=beta,
    ).to(device)
    liger_kto_loss = LigerLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        use_bias=bias,
        use_ref_bias=bias,
        ignore_index=ignore_index,
        beta=beta,
    ).to(device)

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)
    # Target shape: [B, T]
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    # Preference labels shape: [B]
    # Create binary preference labels (0 or 1) for each sequence in the batch
    # Used to indicate preferred sequences (1) vs non-preferred sequences (0)
    preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device)

    # Precomputed KL divergence between policy and reference distributions
    kl = torch.randn(1, device=device, dtype=dtype)

    # Add ignore_index tokens to simulate padding
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index

    # Add ref_x with the same shape as _input
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

    def fwd():
        if provider == "liger":
            return liger_kto_loss(
                x=_input,
                ref_x=ref_input,
                y=target,
                preference_labels=preference_labels,
                kl=kl,
            )[0]
        elif provider == "huggingface":
            return torch_kto_loss(
                x=_input,
                ref_x=ref_input,
                y=target,
                preference_labels=preference_labels,
                kl=kl,
            )[0]

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    bias = input.extra_benchmark_config["bias"]
    beta = input.extra_benchmark_config["beta"]
    ignore_index = input.extra_benchmark_config["ignore_index"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    torch_kto_loss = TorchLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        beta=beta,
        ignore_index=ignore_index,
        use_bias=bias,
    ).to(device)
    liger_kto_loss = LigerLMHeadKTO(
        H=H,
        V=V,
        dtype=dtype,
        beta=beta,
        ignore_index=ignore_index,
        use_bias=bias,
    ).to(device)

    # Input shape: [B, T, H]
    _input = torch.randn(B, T, H, device=device, dtype=dtype)
    # Target shape: [B, T]
    target = torch.randint(V, (B, T), device=device, dtype=torch.long)

    # Preference labels shape: [B]
    # Create binary preference labels (0 or 1) for each sequence in the batch
    # Used to indicate preferred sequences (1) vs non-preferred sequences (0)
    preference_labels = torch.randint(2, (B,), dtype=torch.bool, device=device)

    # Precomputed KL divergence between policy and reference distributions
    kl = torch.randn(1, device=device, dtype=dtype)

    # Add ignore_index tokens
    num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
    indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
    target.view(-1)[indices_to_assign] = ignore_index

    # Add ref_x with the same shape as _input
    ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

    def fwd():
        if provider == "liger":
            return liger_kto_loss(
                x=_input,
                ref_x=ref_input,
                y=target,
                preference_labels=preference_labels,
                kl=kl,
            )[0]
        elif provider == "huggingface":
            return torch_kto_loss(
                x=_input,
                ref_x=ref_input,
                y=target,
                preference_labels=preference_labels,
                kl=kl,
            )[0]

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "kto_loss",
        "x_name": "B",
        "x_label": "Batch Size (B)",
        "x_values": [2**i for i in range(1, 6)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "T": 512,
                "H": 1024,
                "V": 128256,
                "mode": "forward",
                "dtype": torch.bfloat16,
                "bias": True,
                "beta": 0.1,
                "ignore_index": 42,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_kto_loss,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_kto_loss,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
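Note: both KTO benchmarks mask a random subset of targets with ignore_index to simulate padding. The in-place scatter works because view(-1) on a contiguous tensor shares storage with the original; a small standalone illustration of the same trick:

import torch

B, T, ignore_index = 2, 6, 42
target = torch.randint(100, (B, T))
n = torch.randint(1, B * T // 2, (1,)).item()  # how many positions to mask
idx = torch.randperm(B * T)[:n]                # random flat positions
target.view(-1)[idx] = ignore_index            # writes through the shared flat view
print((target == ignore_index).sum().item(), "masked positions")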
benchmark/scripts/benchmark_layer_norm.py (new file, mode 100755)

import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.layer_norm import LigerLayerNorm
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    N = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, N)
    triton_ln = LigerLayerNorm(hidden_size=N).to(device)
    torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        if provider == "liger":
            return triton_ln(x)
        if provider == "huggingface":
            return torch_ln(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    elif mode == "backward":
        y = y_fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    N = input.x
    provider = input.kernel_provider
    dtype = input.extra_benchmark_config["dtype"]
    M = input.extra_benchmark_config["M"]
    eps = input.extra_benchmark_config["eps"]

    x_shape = (M, N)
    triton_ln = LigerLayerNorm(hidden_size=N).to(device)
    torch_ln = torch.nn.LayerNorm(N, eps=eps).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        if provider == "liger":
            return triton_ln(x)
        if provider == "huggingface":
            return torch_ln(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "layer_norm",
        "x_name": "N",
        "x_label": "hidden size",
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [{"M": 4096, "dtype": torch.float32, "eps": 1e-6}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_layer_norm,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_layer_norm,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
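Note: the grad_to_none=[x] argument used throughout these scripts matters for timing fidelity: do_bench resets .grad to None for the listed tensors between reps, so each repeated backward call allocates a fresh gradient instead of timing an extra in-place accumulation. A standalone illustration of the accumulation it avoids:

import torch

x = torch.randn(4, requires_grad=True)
(x * 2).sum().backward()
(x * 2).sum().backward()
print(x.grad)  # gradients accumulate: tensor([4., 4., 4., 4.])
x.grad = None  # what do_bench does between reps for tensors in grad_to_none
(x * 2).sum().backward()
print(x.grad)  # fresh gradient: tensor([2., 2., 2., 2.])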
benchmark/scripts/benchmark_llama4_rope.py (new file, mode 100755)

import torch
import triton
from transformers.models.llama4.configuration_llama4 import Llama4TextConfig
from transformers.models.llama4.modeling_llama4 import Llama4TextRotaryEmbedding
from transformers.models.llama4.modeling_llama4 import apply_rotary_emb
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.llama4_rope import liger_llama4_text_rotary_pos_emb
from liger_kernel.utils import infer_device
from liger_kernel.utils import transformers_version_dispatch

device = infer_device()


def bench_speed_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

    head_dim = hidden_size // num_q_heads

    # Create Llama4TextConfig for the rotary embedding
    config = Llama4TextConfig(
        hidden_size=hidden_size,
        num_attention_heads=num_q_heads,
        num_key_value_heads=num_kv_heads,
        head_dim=head_dim,
        max_position_embeddings=seq_len,
    )

    rotary_emb = transformers_version_dispatch(
        "4.48.0",
        Llama4TextRotaryEmbedding,
        Llama4TextRotaryEmbedding,
        before_kwargs={"config": config, "device": device},
        after_kwargs={"config": config, "device": device},
    )

    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    dq, dk = (
        torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device),
    )
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    freqs_cis = rotary_emb(q, pos_ids)

    def fwd():
        if provider == "liger":
            return liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
        elif provider == "huggingface":
            return apply_rotary_emb(q, k, freqs_cis)
        else:
            raise ValueError(f"Invalid provider: {provider} for Llama4 RoPE embedding")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        q_out, k_out = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            q_out, k_out = fwd()
            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_llama4_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

    head_dim = hidden_size // num_q_heads

    # Create Llama4TextConfig for the rotary embedding
    config = Llama4TextConfig(
        hidden_size=hidden_size,
        num_attention_heads=num_q_heads,
        num_key_value_heads=num_kv_heads,
        head_dim=head_dim,
        max_position_embeddings=seq_len,
    )

    rotary_emb = transformers_version_dispatch(
        "4.48.0",
        Llama4TextRotaryEmbedding,
        Llama4TextRotaryEmbedding,
        before_kwargs={"config": config, "device": device},
        after_kwargs={"config": config, "device": device},
    )

    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    )
    dq, dk = (
        torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device),
    )
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    freqs_cis = rotary_emb(q, pos_ids)

    def full():
        if provider == "liger":
            q_out, k_out = liger_llama4_text_rotary_pos_emb(q, k, freqs_cis)
        else:
            q_out, k_out = apply_rotary_emb(q, k, freqs_cis)
        torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(
        full,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs_varying_hidden_size = {
        "kernel_name": "llama4_rope",
        "x_name": "H",
        "x_label": "hidden size",
        "x_values": [32 * (2**i) for i in range(4, 10, 2)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "seq_len": 2048,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_llama4_rope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_hidden_size,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_llama4_rope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_hidden_size,
    )

    common_configs_varying_seq_len = {
        "kernel_name": "llama4_rope",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "hidden_size": 8192,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_llama4_rope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_seq_len,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_llama4_rope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_seq_len,
    )
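Note: both functions above select the swept variable with ternaries of the form cfg["hidden_size"] if "hidden_size" in cfg else input.x: whichever of hidden_size/seq_len is absent from extra_benchmark_config is taken from the swept x value. An equivalent, more compact form would be:

# dict.get with a default reads the same dispatch in one call each
hidden_size = extra_benchmark_config.get("hidden_size", input.x)
seq_len = extra_benchmark_config.get("seq_len", input.x)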
benchmark/scripts/benchmark_mhc.py (new file, mode 100755)

import os
import sys

import torch
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.functional import liger_mhc_coeffs
from liger_kernel.transformers.functional import liger_mhc_post_res
from liger_kernel.transformers.functional import liger_mhc_pre
from liger_kernel.utils import infer_device

device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


def bench_speed_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    from test.transformers.test_mhc import mhc_coeffs_ref

    T = input.x
    B = input.extra_benchmark_config["B"]
    HC = input.extra_benchmark_config["HC"]
    C = input.extra_benchmark_config["C"]
    sub_kernel = input.extra_benchmark_config["sub_kernel"]
    tmax = input.extra_benchmark_config["tmax"]
    rms_eps = input.extra_benchmark_config["rms_eps"]
    pre_eps = input.extra_benchmark_config["pre_eps"]
    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
    post_mult = input.extra_benchmark_config["post_mult"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)

    need_grad = mode in ("backward", "full")
    x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad)
    K, M = HC * C, HC * HC + 2 * HC
    phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(need_grad)
    b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)
    alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=need_grad)

    grad_to_none = [x, phi, b_param, alpha_pre, alpha_post, alpha_res] if need_grad else None

    if sub_kernel == "coeffs":

        def fwd():
            if provider == "liger":
                return liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            return mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)

        def fwd_loss():
            h_pre, h_post, h_res = fwd()
            return h_pre.square().mean() + h_post.square().mean() + h_res.square().mean()

    elif sub_kernel == "pre":
        with torch.no_grad():
            h_pre_c, _, _ = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_pre_c.requires_grad_(need_grad)
        grad_to_none = [x, h_pre_c] if need_grad else None

        def fwd():
            if provider == "liger":
                return liger_mhc_pre(x, h_pre_c)
            return (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)

        def fwd_loss():
            return fwd().square().mean()

    elif sub_kernel == "post_res":
        with torch.no_grad():
            _, h_post_c, h_res_c = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_post_c.requires_grad_(need_grad)
        h_res_c.requires_grad_(need_grad)
        f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=need_grad)
        grad_to_none = [x, f_out, h_post_c, h_res_c] if need_grad else None

        def fwd():
            if provider == "liger":
                return liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
            return torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(-1) * f_out.float().unsqueeze(-2)

        def fwd_loss():
            return fwd().square().mean()

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, rep=100, quantiles=QUANTILES)
    elif mode == "backward":
        y = fwd_loss()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=grad_to_none,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd_loss()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=grad_to_none, rep=100, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


def bench_memory_mhc(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    from test.transformers.test_mhc import mhc_coeffs_ref

    T = input.x
    B = input.extra_benchmark_config["B"]
    HC = input.extra_benchmark_config["HC"]
    C = input.extra_benchmark_config["C"]
    sub_kernel = input.extra_benchmark_config["sub_kernel"]
    tmax = input.extra_benchmark_config["tmax"]
    rms_eps = input.extra_benchmark_config["rms_eps"]
    pre_eps = input.extra_benchmark_config["pre_eps"]
    sinkhorn_eps = input.extra_benchmark_config["sinkhorn_eps"]
    post_mult = input.extra_benchmark_config["post_mult"]
    provider = input.kernel_provider

    coeffs_cfg = dict(tmax=tmax, rms_eps=rms_eps, pre_eps=pre_eps, sinkhorn_eps=sinkhorn_eps, post_mult=post_mult)

    x = torch.randn(B, T, HC, C, device=device, dtype=torch.bfloat16, requires_grad=True)
    K, M = HC * C, HC * HC + 2 * HC
    phi = (torch.randn(K, M, device=device, dtype=torch.bfloat16) * 0.02).requires_grad_(True)
    b_param = torch.zeros(M, device=device, dtype=torch.float32, requires_grad=True)
    alpha_pre = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
    alpha_post = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)
    alpha_res = torch.tensor(1.0, device=device, dtype=torch.float32, requires_grad=True)

    if sub_kernel == "coeffs":

        def full():
            if provider == "liger":
                hp, hpo, hr = liger_mhc_coeffs(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            else:
                hp, hpo, hr = mhc_coeffs_ref(x, phi, b_param, alpha_pre, alpha_post, alpha_res, **coeffs_cfg)
            (hp.square().mean() + hpo.square().mean() + hr.square().mean()).backward()

    elif sub_kernel == "pre":
        with torch.no_grad():
            h_pre_c, _, _ = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_pre_c.requires_grad_(True)

        def full():
            if provider == "liger":
                out = liger_mhc_pre(x, h_pre_c)
            else:
                out = (x.float() * h_pre_c.unsqueeze(-1)).sum(dim=-2)
            out.square().mean().backward()

    elif sub_kernel == "post_res":
        with torch.no_grad():
            _, h_post_c, h_res_c = liger_mhc_coeffs(
                x.detach(),
                phi.detach(),
                b_param.detach(),
                alpha_pre.detach(),
                alpha_post.detach(),
                alpha_res.detach(),
                **coeffs_cfg,
            )
        h_post_c.requires_grad_(True)
        h_res_c.requires_grad_(True)
        f_out = torch.randn(B, T, C, device=device, dtype=torch.bfloat16, requires_grad=True)

        def full():
            if provider == "liger":
                out = liger_mhc_post_res(x, f_out, h_post_c, h_res_c)
            else:
                out = torch.einsum("...oi,...ic->...oc", h_res_c, x.float()) + h_post_c.unsqueeze(-1) * f_out.float().unsqueeze(-2)
            out.square().mean().backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80)


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    for sub_kernel in ["coeffs", "pre", "post_res"]:
        common_configs = {
            "kernel_name": f"mhc_{sub_kernel}",
            "x_name": "T",
            "x_label": "Sequence Length (T)",
            "x_values": [2**i for i in range(7, 12)],
            "kernel_providers": ["liger", "torch"],
            "extra_benchmark_configs": [
                {
                    "B": 4,
                    "HC": 4,
                    "C": 4096,
                    "tmax": 20,
                    "rms_eps": 1e-6,
                    "pre_eps": 0.0,
                    "sinkhorn_eps": 1e-6,
                    "post_mult": 2.0,
                    "sub_kernel": sub_kernel,
                }
            ],
            "overwrite": args.overwrite,
        }
        run_benchmarks(
            bench_test_fn=bench_speed_mhc,
            kernel_operation_modes=["forward", "backward", "full"],
            metric_name="speed",
            metric_unit="ms",
            **common_configs,
        )
        run_benchmarks(
            bench_test_fn=bench_memory_mhc,
            kernel_operation_modes=["full"],
            metric_name="memory",
            metric_unit="MB",
            **common_configs,
        )
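Note: the coefficient shapes in the "post_res" torch reference path are easiest to see with small dims. The sketch below assumes h_post has shape [B, T, HC] and h_res has shape [B, T, HC, HC], inferred from the broadcasts above rather than stated anywhere on this page: h_res mixes the HC residual streams of x, and h_post scales the layer output broadcast across streams.

import torch

B, T, HC, C = 2, 3, 4, 8
x = torch.randn(B, T, HC, C)       # HC residual streams of width C
f_out = torch.randn(B, T, C)       # wrapped layer output
h_post = torch.randn(B, T, HC)     # per-stream scale for f_out
h_res = torch.randn(B, T, HC, HC)  # stream-mixing matrix

# "...oi,...ic->...oc": for each (b, t), out[o] = sum_i h_res[o, i] * x[i]
out = torch.einsum("...oi,...ic->...oc", h_res, x) + h_post.unsqueeze(-1) * f_out.unsqueeze(-2)
print(out.shape)  # torch.Size([2, 3, 4, 8]), same layout as x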
benchmark/scripts/benchmark_mhc_lm.py (new file, mode 100755)

import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.mhc import LigerMHC
from liger_kernel.utils import infer_device

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))

device = infer_device()


class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, *, eps: float, dtype: torch.dtype, device: str):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype, device=device))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return x * self.weight


def _build_rope_cache(seq_len: int, head_dim: int, *, device: torch.device, dtype: torch.dtype):
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim))
    positions = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.einsum("i,j->ij", positions, inv_freq)
    cos = freqs.cos().to(dtype)
    sin = freqs.sin().to(dtype)
    return cos, sin


def _apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


class MiniLlamaAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, *, dtype: torch.dtype, device: str):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        assert self.head_dim % 2 == 0, "head_dim must be even for RoPE"
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)
        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)
        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        cos, sin = _build_rope_cache(seq_len, self.head_dim, device=x.device, dtype=q.dtype)
        q = _apply_rope(q, cos, sin)
        k = _apply_rope(k, cos, sin)
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        attn = attn.transpose(1, 2).contiguous().view(bsz, seq_len, self.hidden_size)
        return self.o_proj(attn)


class MiniLlamaMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_mult: int, *, dtype: torch.dtype, device: str):
        super().__init__()
        intermediate_size = hidden_size * intermediate_mult
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class AttentionBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, *, dtype: torch.dtype, device: str):
        super().__init__()
        self.norm = RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.attn = MiniLlamaAttention(hidden_size, num_heads, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.attn(self.norm(x))


class MLPBlock(nn.Module):
    def __init__(self, hidden_size: int, intermediate_mult: int, *, dtype: torch.dtype, device: str):
        super().__init__()
        self.norm = RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.mlp = MiniLlamaMLP(hidden_size, intermediate_mult, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(self.norm(x))


class TorchMHC(nn.Module):
    def __init__(
        self,
        layer: nn.Module,
        *,
        hc: int,
        c: int,
        tmax: int,
        rms_eps: float,
        pre_eps: float,
        sinkhorn_eps: float,
        post_mult: float,
        phi_dtype: torch.dtype,
    ):
        super().__init__()
        self.layer = layer
        self.hc = int(hc)
        self.c = int(c)
        self.tmax = int(tmax)
        self.rms_eps = float(rms_eps)
        self.pre_eps = float(pre_eps)
        self.sinkhorn_eps = float(sinkhorn_eps)
        self.post_mult = float(post_mult)
        layer_param = next(layer.parameters())
        device = layer_param.device
        m = hc * hc + 2 * hc
        k = hc * c
        self.phi = nn.Parameter(torch.randn(k, m, dtype=phi_dtype, device=device) * 0.02)
        self.b = nn.Parameter(torch.zeros(m, dtype=torch.float32, device=device))
        self.alpha_pre = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
        self.alpha_post = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
        self.alpha_res = nn.Parameter(torch.tensor(1.0, dtype=torch.float32, device=device))
        self.layer_dtype = layer_param.dtype

    def _coeffs(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        from test.transformers.test_mhc import mhc_coeffs_ref

        return mhc_coeffs_ref(
            x,
            self.phi,
            self.b,
            self.alpha_pre,
            self.alpha_post,
            self.alpha_res,
            tmax=self.tmax,
            rms_eps=self.rms_eps,
            pre_eps=self.pre_eps,
            sinkhorn_eps=self.sinkhorn_eps,
            post_mult=self.post_mult,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h_pre, h_post, h_res = self._coeffs(x)
        x_in = (x.float() * h_pre.unsqueeze(-1)).sum(dim=-2)
        if x_in.dtype != self.layer_dtype:
            x_in = x_in.to(self.layer_dtype)
        f_out = self.layer(x_in)
        x_out = torch.einsum("...oi,...ic->...oc", h_res, x.float()) + h_post.unsqueeze(-1) * f_out.float().unsqueeze(-2)
        return x_out.to(x.dtype)


class MHCDecoderLayer(nn.Module):
    def __init__(
        self,
        mhc_cls: type[nn.Module],
        *,
        hidden_size: int,
        hc: int,
        num_heads: int,
        intermediate_mult: int,
        tmax: int,
        dtype: torch.dtype,
        device: str,
    ):
        super().__init__()
        attn = AttentionBlock(hidden_size, num_heads, dtype=dtype, device=device)
        mlp = MLPBlock(hidden_size, intermediate_mult, dtype=dtype, device=device)
        self.attn = mhc_cls(
            attn,
            hc=hc,
            c=hidden_size,
            tmax=tmax,
            rms_eps=1e-6,
            pre_eps=1e-4,
            sinkhorn_eps=1e-6,
            post_mult=2.0,
            phi_dtype=dtype,
        )
        self.mlp = mhc_cls(
            mlp,
            hc=hc,
            c=hidden_size,
            tmax=tmax,
            rms_eps=1e-6,
            pre_eps=1e-4,
            sinkhorn_eps=1e-6,
            post_mult=2.0,
            phi_dtype=dtype,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.attn(x)
        x = self.mlp(x)
        return x


class BenchMiniMHCLM(nn.Module):
    def __init__(self, mhc_cls: type
[
nn
.
Module
],
*
,
vocab_size
:
int
,
hidden_size
:
int
,
hc
:
int
,
num_layers
:
int
,
num_heads
:
int
,
intermediate_mult
:
int
,
tmax
:
int
,
dtype
:
torch
.
dtype
,
device
:
str
,
):
super
().
__init__
()
self
.
hc
=
hc
self
.
hidden_size
=
hidden_size
self
.
embed
=
nn
.
Embedding
(
vocab_size
,
hc
*
hidden_size
,
dtype
=
dtype
,
device
=
device
)
self
.
layers
=
nn
.
ModuleList
(
[
MHCDecoderLayer
(
mhc_cls
,
hidden_size
=
hidden_size
,
hc
=
hc
,
num_heads
=
num_heads
,
intermediate_mult
=
intermediate_mult
,
tmax
=
tmax
,
dtype
=
dtype
,
device
=
device
,
)
for
_
in
range
(
num_layers
)
]
)
self
.
final_norm
=
RMSNorm
(
hidden_size
,
eps
=
1e-6
,
dtype
=
dtype
,
device
=
device
)
self
.
lm_head
=
nn
.
Linear
(
hidden_size
,
vocab_size
,
bias
=
False
,
dtype
=
dtype
,
device
=
device
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
x
=
self
.
embed
(
input_ids
)
bsz
,
seq_len
,
_
=
x
.
shape
x
=
x
.
view
(
bsz
,
seq_len
,
self
.
hc
,
self
.
hidden_size
)
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x
=
x
.
mean
(
dim
=-
2
)
x
=
self
.
final_norm
(
x
)
return
self
.
lm_head
(
x
)
def
_build_model
(
provider
:
str
,
*
,
hidden_size
:
int
,
hc
:
int
,
num_layers
:
int
,
num_heads
:
int
,
intermediate_mult
:
int
,
vocab_size
:
int
,
tmax
:
int
,
dtype
:
torch
.
dtype
,
):
mhc_cls
=
LigerMHC
if
provider
==
"liger"
else
TorchMHC
return
BenchMiniMHCLM
(
mhc_cls
,
vocab_size
=
vocab_size
,
hidden_size
=
hidden_size
,
hc
=
hc
,
num_layers
=
num_layers
,
num_heads
=
num_heads
,
intermediate_mult
=
intermediate_mult
,
tmax
=
tmax
,
dtype
=
dtype
,
device
=
device
,
)
def
bench_speed_mhc_lm
(
input
:
SingleBenchmarkRunInput
)
->
SingleBenchmarkRunOutput
:
hidden_size
=
int
(
input
.
x
)
provider
=
input
.
kernel_provider
mode
=
input
.
kernel_operation_mode
extra
=
input
.
extra_benchmark_config
bsz
=
extra
[
"B"
]
seq_len
=
extra
[
"T"
]
hc
=
extra
[
"HC"
]
num_layers
=
extra
[
"layers"
]
num_heads
=
extra
[
"heads"
]
vocab_size
=
extra
[
"vocab"
]
dtype
=
extra
[
"dtype"
]
tmax
=
extra
[
"tmax"
]
intermediate_mult
=
extra
[
"intermediate_mult"
]
if
hidden_size
%
num_heads
!=
0
:
raise
ValueError
(
"hidden_size must be divisible by num_heads"
)
model
=
_build_model
(
provider
,
hidden_size
=
hidden_size
,
hc
=
hc
,
num_layers
=
num_layers
,
num_heads
=
num_heads
,
intermediate_mult
=
intermediate_mult
,
vocab_size
=
vocab_size
,
tmax
=
tmax
,
dtype
=
dtype
,
)
input_ids
=
torch
.
randint
(
0
,
vocab_size
,
(
bsz
,
seq_len
),
device
=
device
)
def
fwd
():
return
model
(
input_ids
)
def
fwd_loss
():
return
fwd
().
float
().
mean
()
grad_to_none
=
list
(
model
.
parameters
())
if
mode
==
"forward"
:
ms_50
,
ms_20
,
ms_80
=
triton
.
testing
.
do_bench
(
fwd
,
quantiles
=
QUANTILES
,
grad_to_none
=
grad_to_none
,
rep
=
100
)
elif
mode
==
"backward"
:
loss
=
fwd_loss
()
ms_50
,
ms_20
,
ms_80
=
triton
.
testing
.
do_bench
(
lambda
:
loss
.
backward
(
retain_graph
=
True
),
quantiles
=
QUANTILES
,
grad_to_none
=
grad_to_none
,
rep
=
100
,
)
elif
mode
==
"full"
:
def
full
():
loss
=
fwd_loss
()
loss
.
backward
()
ms_50
,
ms_20
,
ms_80
=
triton
.
testing
.
do_bench
(
full
,
quantiles
=
QUANTILES
,
grad_to_none
=
grad_to_none
,
rep
=
100
)
else
:
raise
ValueError
(
f
"Unknown mode:
{
mode
}
"
)
return
SingleBenchmarkRunOutput
(
y_20
=
ms_20
,
y_50
=
ms_50
,
y_80
=
ms_80
,
)
def
bench_memory_mhc_lm
(
input
:
SingleBenchmarkRunInput
)
->
SingleBenchmarkRunOutput
:
hidden_size
=
int
(
input
.
x
)
provider
=
input
.
kernel_provider
extra
=
input
.
extra_benchmark_config
bsz
=
extra
[
"B"
]
seq_len
=
extra
[
"T"
]
hc
=
extra
[
"HC"
]
num_layers
=
extra
[
"layers"
]
num_heads
=
extra
[
"heads"
]
vocab_size
=
extra
[
"vocab"
]
dtype
=
extra
[
"dtype"
]
tmax
=
extra
[
"tmax"
]
intermediate_mult
=
extra
[
"intermediate_mult"
]
if
hidden_size
%
num_heads
!=
0
:
raise
ValueError
(
"hidden_size must be divisible by num_heads"
)
model
=
_build_model
(
provider
,
hidden_size
=
hidden_size
,
hc
=
hc
,
num_layers
=
num_layers
,
num_heads
=
num_heads
,
intermediate_mult
=
intermediate_mult
,
vocab_size
=
vocab_size
,
tmax
=
tmax
,
dtype
=
dtype
,
)
input_ids
=
torch
.
randint
(
0
,
vocab_size
,
(
bsz
,
seq_len
),
device
=
device
)
def
fwd
():
return
model
(
input_ids
)
def
full
():
loss
=
fwd
().
float
().
mean
()
loss
.
backward
()
mem_50
,
mem_20
,
mem_80
=
_test_memory
(
full
,
quantiles
=
QUANTILES
)
return
SingleBenchmarkRunOutput
(
y_20
=
mem_20
,
y_50
=
mem_50
,
y_80
=
mem_80
,
)
if
__name__
==
"__main__"
:
args
=
parse_benchmark_script_args
()
common_configs
=
{
"kernel_name"
:
"mhc_llama_like_lm"
,
"x_name"
:
"hidden_size"
,
"x_label"
:
"hidden_size"
,
"x_values"
:
[
256
,
512
,
1024
],
"kernel_providers"
:
[
"liger"
,
"torch"
],
"extra_benchmark_configs"
:
[
{
"B"
:
2
,
"T"
:
256
,
"HC"
:
4
,
"layers"
:
2
,
"heads"
:
8
,
"vocab"
:
4096
,
"dtype"
:
torch
.
bfloat16
,
"tmax"
:
8
,
"intermediate_mult"
:
4
,
}
],
"overwrite"
:
args
.
overwrite
,
}
run_benchmarks
(
bench_test_fn
=
bench_speed_mhc_lm
,
kernel_operation_modes
=
[
"forward"
,
"backward"
,
"full"
],
metric_name
=
"speed"
,
metric_unit
=
"ms"
,
**
common_configs
,
)
run_benchmarks
(
bench_test_fn
=
bench_memory_mhc_lm
,
kernel_operation_modes
=
[
"full"
],
metric_name
=
"memory"
,
metric_unit
=
"MB"
,
**
common_configs
,
)
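Aside: the stream-mixing algebra in TorchMHC.forward is easiest to sanity-check in isolation. Below is a minimal sketch with made-up coefficient tensors standing in for the output of mhc_coeffs_ref (which lives in the test suite); only the shapes and the einsum are taken from the code above.

import torch

# Made-up sizes: batch=2, seq=3, HC=4 residual streams, channel width C=8.
B, T, HC, C = 2, 3, 4, 8
x = torch.randn(B, T, HC, C)

# Stand-ins for (h_pre, h_post, h_res); the real values come from
# mhc_coeffs_ref in test/transformers/test_mhc.py. Shapes are inferred
# from how TorchMHC.forward consumes them.
h_pre = torch.softmax(torch.randn(B, T, HC), dim=-1)  # collapses HC streams into one layer input
h_post = torch.randn(B, T, HC)                        # broadcasts the layer output back per stream
h_res = torch.randn(B, T, HC, HC)                     # stream-to-stream residual mixing matrix

x_in = (x * h_pre.unsqueeze(-1)).sum(dim=-2)  # (B, T, C)
f_out = 2.0 * x_in                            # dummy wrapped layer
x_out = torch.einsum("...oi,...ic->...oc", h_res, x) + h_post.unsqueeze(-1) * f_out.unsqueeze(-2)
assert x_out.shape == (B, T, HC, C)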
benchmark/scripts/benchmark_model_configs.py (new file, mode 100755)

"""
Standardized benchmark model configurations.

Provides canonical model architecture profiles and device-specific benchmark
parameters. All benchmark scripts should derive their tensor shapes from these
shared configs rather than defining ad-hoc per-script constants.

Usage::

    from benchmark_model_configs import (
        get_benchmark_model_config,
        compute_seq_len_sweep_config,
        estimate_kernel_peak_memory,
    )

    args = parse_benchmark_script_args()
    model = get_benchmark_model_config(args.model)

    # Measure actual memory via a small probe, then compute sweep config
    peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
    bpt = peak_bytes // probe_num_tokens
    config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=bpt)
"""

import gc
import math

from dataclasses import dataclass
from typing import Callable
from typing import Dict
from typing import Optional

import torch

from liger_kernel.utils import get_total_gpu_memory
from liger_kernel.utils import infer_device


@dataclass(frozen=True)
class ModelConfig:
    """Canonical model architecture profile.

    Each field corresponds to a standard LLM hyperparameter. Benchmark scripts
    pick the fields they need (e.g. hidden_size for RMSNorm, vocab_size for
    CrossEntropy) while kernel-specific overrides (e.g. hidden_act for GEGLU)
    are applied locally in the benchmark script.
    """

    name: str
    hidden_size: int
    intermediate_size: int
    vocab_size: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int
    hidden_act: str
    max_position_embeddings: int = 8192
    rms_norm_eps: float = 1e-5
    dtype: torch.dtype = torch.bfloat16


@dataclass(frozen=True)
class SeqLenSweepConfig:
    """Config for benchmarks that sweep sequence length (e.g. GEGLU, SwiGLU).

    Attributes:
        batch_size: Safe batch size for the sweep.
        seq_len: Max sequence length (upper bound for x_values).
    """

    batch_size: int
    seq_len: int


@dataclass(frozen=True)
class HiddenSizeSweepConfig:
    """Config for benchmarks that sweep hidden_size with fixed BT (e.g. DyT).

    Attributes:
        bt: Fixed batch * seq dimension.
        max_hidden_size: Upper bound for hidden_size sweep.
    """

    bt: int
    max_hidden_size: int


# ── Model Profiles ──────────────────────────────────────────────────────────

LLAMA_2_7B = ModelConfig(
    name="llama_2_7b",
    hidden_size=4096,
    intermediate_size=11008,
    vocab_size=32000,
    num_attention_heads=32,
    num_key_value_heads=32,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=4096,
)

LLAMA_3_8B = ModelConfig(
    name="llama_3_8b",
    hidden_size=4096,
    intermediate_size=14336,
    vocab_size=128256,
    num_attention_heads=32,
    num_key_value_heads=8,
    head_dim=128,
    hidden_act="silu",
    max_position_embeddings=8192,
)

MODEL_REGISTRY: Dict[str, ModelConfig] = {
    "llama_2_7b": LLAMA_2_7B,
    "llama_3_8b": LLAMA_3_8B,
}

DEFAULT_MODEL_CONFIG = LLAMA_3_8B


def get_benchmark_model_config(model_name: Optional[str] = None) -> ModelConfig:
    """Resolve benchmark model config from name.

    Returns the canonical model architecture profile (hidden_size, vocab_size,
    dtype, etc.) for benchmark runs. Use this to obtain model attributes
    when building benchmark tensors and shapes.

    Args:
        model_name: Registry key (e.g. ``llama_2_7b``, ``llama_3_8b``).
            If None, returns ``DEFAULT_MODEL_CONFIG``.
    """
    return MODEL_REGISTRY[model_name] if model_name else DEFAULT_MODEL_CONFIG


def estimate_kernel_peak_memory(probe_fn: Callable[[], torch.Tensor]) -> int:
    """Run a forward + backward probe to measure peak memory (bytes).

    Call this with the *pure PyTorch* (e.g. huggingface) implementation --
    that typically has the highest memory footprint and therefore gives a
    safe upper-bound estimate. Returns the total peak bytes; divide by
    num_tokens if you need bytes-per-token for :func:`compute_seq_len_sweep_config`.

    The probe_fn performs setup and forward pass internally; cleanup is
    automatic, so callers do not need to manage tensor/layer lifecycle.

    Example::

        peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
        kernel_bpt = peak_bytes // num_tokens  # if needed

    Args:
        probe_fn: Callable that performs setup, runs a forward pass, and
            returns an output tensor suitable for ``.backward()``.
    """
    device_str = infer_device()
    torch_device_mod = getattr(torch, device_str)
    gc.collect()
    torch_device_mod.empty_cache()
    torch_device_mod.memory.reset_peak_memory_stats()
    y = probe_fn()
    y.backward(torch.randn_like(y))
    peak_bytes = torch_device_mod.max_memory_allocated()
    del y
    gc.collect()
    torch_device_mod.empty_cache()
    return max(1, peak_bytes)


def compute_seq_len_sweep_config(
    model_cfg: ModelConfig,
    kernel_bytes_per_token: Optional[int] = None,
    memory_utilization: float = 0.4,
    max_seq_len: Optional[int] = None,
    max_batch_size: int = 32,
) -> SeqLenSweepConfig:
    """Compute safe batch_size and seq_len for sequence-length sweep (e.g. GEGLU).

    Peak memory is estimated as
    ``batch_size * seq_len * kernel_bytes_per_token`` and is capped at
    device memory * memory_utilization. Device memory is obtained
    internally via :func:`~liger_kernel.utils.get_total_gpu_memory`.

    Prefer obtaining *kernel_bytes_per_token* via
    :func:`estimate_kernel_peak_memory` (divide by num_tokens) rather
    than hardcoding an analytical estimate.

    Args:
        model_cfg: Model architecture config.
        kernel_bytes_per_token: Peak memory **per token** (``batch * seq_len``
            axis). Best obtained from :func:`estimate_kernel_peak_memory` / num_tokens.
            Falls back to a conservative heuristic
            (``hidden_size * dtype_bytes * 16``) when *None*.
        memory_utilization: Fraction of total device memory to target (0 to 1).
            Lower values are safer. Default ``0.4`` leaves headroom for
            framework overhead and CUDA/NPU context.
        max_seq_len: Hard upper bound for sequence length. Defaults to
            ``model_cfg.max_position_embeddings`` so the sweep never exceeds
            the model's native context window.
        max_batch_size: Hard upper bound for batch size.
    """
    total_memory_gb = get_total_gpu_memory()
    dtype_bytes = 2 if model_cfg.dtype in (torch.bfloat16, torch.float16) else 4
    if kernel_bytes_per_token is None:
        kernel_bytes_per_token = model_cfg.hidden_size * dtype_bytes * 16
    if max_seq_len is None:
        max_seq_len = model_cfg.max_position_embeddings
    usable_bytes = total_memory_gb * (1024**3) * memory_utilization
    max_tokens = max(1, int(usable_bytes / kernel_bytes_per_token))
    seq_len = min(max_seq_len, max_tokens)
    seq_len = 2 ** int(math.log2(seq_len)) if seq_len >= 1024 else 1024
    batch_size = max(1, min(max_tokens // seq_len, max_batch_size))
    return SeqLenSweepConfig(batch_size=batch_size, seq_len=seq_len)


def compute_hidden_size_sweep_config(
    model_cfg: ModelConfig,
    kernel_peak_bytes: int,
    bt: int = 4096,
    memory_utilization: float = 0.4,
    max_hidden_size_multiplier: int = 4,
) -> HiddenSizeSweepConfig:
    """Compute safe max_hidden_size for hidden_size sweep (e.g. DyT).

    For kernels with shape (BT, hidden_size) where BT is fixed and we sweep
    hidden_size. Uses probe peak memory to derive max_hidden_size.
    Device memory is obtained internally via :func:`~liger_kernel.utils.get_total_gpu_memory`.

    Args:
        model_cfg: Model config.
        kernel_peak_bytes: Peak memory from probe (BT, model.hidden_size).
        bt: Fixed BT dimension; must match the probe.
        memory_utilization: Fraction of device memory to use.
        max_hidden_size_multiplier: Cap max_hidden_size at model.hidden_size * this.
    """
    total_memory_gb = get_total_gpu_memory()
    usable_bytes = total_memory_gb * (1024**3) * memory_utilization
    kernel_bpt = max(1, kernel_peak_bytes // bt)
    max_hidden_size = min(
        model_cfg.hidden_size * max_hidden_size_multiplier,
        max(
            model_cfg.hidden_size,
            int(usable_bytes * model_cfg.hidden_size / (bt * kernel_bpt)),
        ),
    )
    max_hidden_size = max(1024, 2 ** int(math.log2(max_hidden_size)))
    return HiddenSizeSweepConfig(bt=bt, max_hidden_size=max_hidden_size)
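Aside: the `_probe` referenced in the module docstring is left to each benchmark script. A minimal sketch of one follows; the Linear layer and the shapes are hypothetical stand-ins, and the only contract is: do setup, run a forward pass, return a tensor that supports .backward().

import torch

from benchmark_model_configs import compute_seq_len_sweep_config
from benchmark_model_configs import estimate_kernel_peak_memory
from benchmark_model_configs import get_benchmark_model_config
from liger_kernel.utils import infer_device

device = infer_device()
model_cfg = get_benchmark_model_config("llama_3_8b")
probe_batch, probe_seq = 2, 1024


def _probe() -> torch.Tensor:
    # Hypothetical stand-in for the pure-PyTorch layer being benchmarked.
    layer = torch.nn.Linear(model_cfg.hidden_size, model_cfg.hidden_size, dtype=model_cfg.dtype, device=device)
    x = torch.randn(probe_batch, probe_seq, model_cfg.hidden_size, dtype=model_cfg.dtype, device=device, requires_grad=True)
    return layer(x)


peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
bytes_per_token = peak_bytes // (probe_batch * probe_seq)
sweep = compute_seq_len_sweep_config(model_cfg, kernel_bytes_per_token=bytes_per_token)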
benchmark/scripts/benchmark_multi_token_attention.py (new file, mode 100755)

import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention
from liger_kernel.utils import infer_device

device = infer_device()


class TorchMultiTokenAttention(torch.nn.Module):
    def __init__(self, C_in, C_out, K, groups, bias, dtype, device):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(C_out, C_in // groups, K, K, dtype=dtype, device=device))
        self.bias = torch.nn.Parameter(torch.empty(C_out, dtype=dtype, device=device)) if bias else None
        self.K = K
        self.groups = groups

    def forward(self, scores):
        B, C_in, L, _ = scores.shape
        mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device=scores.device)).view(1, 1, L, L)
        inf = torch.tensor(-1e9, device=scores.device, dtype=scores.dtype)
        zero = torch.tensor(0.0, device=scores.device, dtype=scores.dtype)
        s_inf = scores.masked_fill(~mask, inf)
        probs = torch.nn.functional.softmax(s_inf, dim=-1)
        out_c = torch.nn.functional.conv2d(probs, self.weight, self.bias, stride=1, padding=self.K // 2, groups=self.groups)
        return out_c.masked_fill(~mask, zero)


def bench_speed_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    L = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    B = extra_benchmark_config["B"]
    C_in = extra_benchmark_config["C_in"]
    C_out = extra_benchmark_config["C_out"]
    K = extra_benchmark_config["K"]
    groups = extra_benchmark_config["groups"]
    bias = extra_benchmark_config["bias"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (B, C_in, L, L)

    triton_attn = (
        LigerMultiTokenAttention(
            in_channels=C_in,
            out_channels=C_out,
            kernel_size=K,
            stride=1,
            padding=K // 2,
            dilation=1,
            groups=groups,
            bias=bias,
        )
        .to(device)
        .to(dtype)
    )
    torch_attn = TorchMultiTokenAttention(C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device)
    with torch.no_grad():
        torch_attn.weight.copy_(triton_attn.weight)
        if bias:
            torch_attn.bias.copy_(triton_attn.bias)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        if provider == "liger":
            return triton_attn(x)
        elif provider == "torch":
            return torch_attn(x)

    print(f"Starting Warmup for input size: {x_shape}")
    _ = fwd()
    if mode in ("backward", "full"):
        y = _
        y.backward(dy, retain_graph=True)
    print("Done Warmup")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES)
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            grad_to_none=[x],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    L = input.x
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    B = extra_benchmark_config["B"]
    C_in = extra_benchmark_config["C_in"]
    C_out = extra_benchmark_config["C_out"]
    K = extra_benchmark_config["K"]
    groups = extra_benchmark_config["groups"]
    bias = extra_benchmark_config["bias"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (B, C_in, L, L)

    triton_attn = (
        LigerMultiTokenAttention(
            in_channels=C_in,
            out_channels=C_out,
            kernel_size=K,
            stride=1,
            padding=K // 2,
            dilation=1,
            groups=groups,
            bias=bias,
        )
        .to(device)
        .to(dtype)
    )
    torch_attn = TorchMultiTokenAttention(C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device)
    with torch.no_grad():
        torch_attn.weight.copy_(triton_attn.weight)
        if bias:
            torch_attn.bias.copy_(triton_attn.bias)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        if provider == "liger":
            return triton_attn(x)
        elif provider == "torch":
            return torch_attn(x)

    def full():
        y = fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "multi_token_attention",
        "x_name": "L",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(5, 10)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [
            {
                "B": 2,
                "C_in": 4,
                "C_out": 4,
                "K": 3,
                "groups": 1,
                "bias": True,
                "dtype": torch.bfloat16,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_multi_token_attention,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_multi_token_attention,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
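Aside: a tiny worked example of the causal masking the reference module applies around the softmax (plain PyTorch, independent of the benchmark). For uniform scores, row i of the probabilities spreads 1/(i+1) over positions 0..i:

import torch

L = 3
scores = torch.zeros(1, 1, L, L)
mask = torch.tril(torch.ones(L, L, dtype=torch.bool)).view(1, 1, L, L)
probs = torch.softmax(scores.masked_fill(~mask, -1e9), dim=-1)
# probs[0, 0] is (approximately):
# [[1.00, 0.00, 0.00],
#  [0.50, 0.50, 0.00],
#  [0.33, 0.33, 0.33]]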
benchmark/scripts/benchmark_orpo_loss.py (new file, mode 100755)

import os
import sys

import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.utils import infer_device

device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


#############################################################################
# Test the memory consumption of the fused linear ORPO loss
#############################################################################
def bench_memory_fused_linear_orpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
    from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO

    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider

    # Instantiate once and retrieve the first output only
    torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
    liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)

    torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0]
    liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0]

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
    nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    def fwd():
        if provider == "liger":
            return liger_fwd(_input, target, nll_target)
        elif provider == "huggingface":
            return torch_fwd(_input, target, nll_target)

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


# #############################################################################
# # Test the speed of the fused linear ORPO loss
# #############################################################################
def bench_speed_fused_linear_orpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
    from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO

    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    # Instantiate once and retrieve the first output only
    torch_lm_head_orpo = TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)
    liger_lm_head_orpo = LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)

    torch_fwd = lambda x, target, nll_target: torch_lm_head_orpo(x, target, nll_target)[0]
    liger_fwd = lambda x, target, nll_target: liger_lm_head_orpo(x, target, nll_target)[0]

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)
    nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    def fwd():
        if provider == "liger":
            return liger_fwd(_input, target, nll_target)
        elif provider == "huggingface":
            return torch_fwd(_input, target, nll_target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "fused_linear_orpo_loss",
        "x_name": "B",
        "x_label": "B",
        "x_values": [2**i for i in range(1, 5)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "T": 1024,
                "H": 4096,
                "V": 128256,
                "mode": "forward",
                "dtype": torch.bfloat16,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_fused_linear_orpo_loss,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_fused_linear_orpo_loss,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
benchmark/scripts/benchmark_poly_norm.py (new file, mode 100755)

import torch
import torch.nn as nn
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.poly_norm import LigerPolyNorm
from liger_kernel.utils import infer_device

device = infer_device()


class NaivePolyNorm(nn.Module):
    """
    Naive PyTorch implementation of PolyNorm.

    Reference:
    https://github.com/BryceZhuo/PolyCom/

    PolyNorm formula:
    y = w₀·norm(x³) + w₁·norm(x²) + w₂·norm(x) + b
    where norm(u) = u / sqrt(mean(u²) + ε)
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        # Align with PolyCom reference: (1/3, 1/3, 1/3) and bias=1.0
        self.weight = nn.Parameter(torch.full((3,), 1.0 / 3.0))
        self.bias = nn.Parameter(torch.tensor(1.0))
        self.variance_epsilon = eps

    def _norm(self, x):
        """RMSNorm operation"""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)

    def forward(self, hidden_states):
        """
        Forward pass of PolyNorm

        Args:
            hidden_states: input tensor of shape (..., H)

        Returns:
            output tensor of same shape as input
        """
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)

        # Compute powers
        x_pow3 = hidden_states**3
        x_pow2 = hidden_states**2
        x_pow1 = hidden_states**1

        # Normalize each power
        norm_x3 = self._norm(x_pow3)
        norm_x2 = self._norm(x_pow2)
        norm_x1 = self._norm(x_pow1)

        # Weighted sum with bias
        output = self.weight[0] * norm_x3 + self.weight[1] * norm_x2 + self.weight[2] * norm_x1 + self.bias
        return output.to(input_dtype)


def bench_speed_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    N = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, N)

    triton_poly = LigerPolyNorm(eps=eps).to(device)
    naive_poly = NaivePolyNorm(eps=eps).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    # utility functions
    def y_fwd():
        if provider == "liger":
            return triton_poly(x)
        if provider == "huggingface":
            return naive_poly(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            y_fwd,
            grad_to_none=[x],
            rep=500,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = y_fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            grad_to_none=[x],
            rep=500,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[x],
            rep=500,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_poly_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    N = input.x
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, N)

    triton_poly = LigerPolyNorm(eps=eps).to(device)
    naive_poly = NaivePolyNorm(eps=eps).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    # utility functions
    def y_fwd():
        if provider == "liger":
            return triton_poly(x)
        if provider == "huggingface":
            return naive_poly(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "poly_norm",
        "x_name": "H",
        "x_label": "hidden size",
        "x_values": [2**i for i in range(10, 16)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_poly_norm,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_poly_norm,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
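Aside: before trusting the timings it can be worth a parity check between the two providers. A sketch follows, assuming LigerPolyNorm exposes the same weight (shape (3,)) and scalar bias parameters as the naive module above; if the defaults differ, the copy below aligns them.

import torch

from benchmark_poly_norm import NaivePolyNorm
from liger_kernel.transformers.poly_norm import LigerPolyNorm
from liger_kernel.utils import infer_device

device = infer_device()
x = torch.randn(64, 4096, device=device, dtype=torch.float32)

liger = LigerPolyNorm(eps=1e-6).to(device)
naive = NaivePolyNorm(eps=1e-6).to(device)
with torch.no_grad():  # align parameters in case the defaults differ
    naive.weight.copy_(liger.weight)
    naive.bias.copy_(liger.bias)

torch.testing.assert_close(liger(x), naive(x), rtol=1e-4, atol=1e-4)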
benchmark/scripts/benchmark_qwen2vl_mrope.py (new file, mode 100755)

import torch
import triton

from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLTextConfig
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLRotaryEmbedding
from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_qwen2vl_mrope(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

    head_dim = hidden_size // num_q_heads
    mrope_section_hw = head_dim * 3 // 16
    mrope_section = [
        head_dim // 2 - 2 * mrope_section_hw,
        mrope_section_hw,
        mrope_section_hw,
    ]
    config = Qwen2VLTextConfig(
        hidden_size=hidden_size,
        num_attention_heads=num_q_heads,
        num_key_value_heads=num_kv_heads,
        rope_theta=1000000.0,
        mrope_section=mrope_section,
    )
    rotary_emb = Qwen2VLRotaryEmbedding(config, device=device)

    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    dq, dk = (
        torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device, dtype=dtype),
    )
    pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
    cos, sin = rotary_emb(k, pos_ids)

    def fwd():
        if provider == "liger":
            return liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
        elif provider == "huggingface":
            return apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
        else:
            raise ValueError(f"Invalid provider: {provider} for M-RoPE embedding")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        q_out, k_out = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            q_out, k_out = fwd()
            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_qwen2vl_mrope(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

    head_dim = hidden_size // num_q_heads
    mrope_section_hw = head_dim * 3 // 16
    mrope_section = [
        head_dim // 2 - 2 * mrope_section_hw,
        mrope_section_hw,
        mrope_section_hw,
    ]
    config = Qwen2VLTextConfig(
        hidden_size=hidden_size,
        num_attention_heads=num_q_heads,
        num_key_value_heads=num_kv_heads,
        rope_theta=1000000.0,
        mrope_section=mrope_section,
    )
    rotary_emb = Qwen2VLRotaryEmbedding(config, device=device)

    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    dq, dk = (
        torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device, dtype=dtype),
    )
    pos_ids = torch.arange(seq_len * 3, device=device, dtype=torch.long).view(3, 1, -1)
    cos, sin = rotary_emb(k, pos_ids)

    def full():
        if provider == "liger":
            q_out, k_out = liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
        else:
            q_out, k_out = apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section)
        torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(
        full,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs_varying_hidden_size = {
        "kernel_name": "qwen2vl_mrope",
        "x_name": "H",
        "x_label": "hidden size",
        "x_values": [32 * (2**i) for i in range(4, 10, 2)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "seq_len": 2048,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_qwen2vl_mrope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_hidden_size,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_qwen2vl_mrope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_hidden_size,
    )

    common_configs_varying_seq_len = {
        "kernel_name": "qwen2vl_mrope",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "hidden_size": 8192,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_qwen2vl_mrope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_seq_len,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_qwen2vl_mrope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_seq_len,
    )
benchmark/scripts/benchmark_rms_norm.py (new file, mode 100755)

import torch
import torch.nn as nn
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.utils import infer_device

device = infer_device()


class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    N = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, N)

    triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device)
    llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    # utility functions
    def y_fwd():
        if provider == "liger":
            return triton_rms(x)
        if provider == "huggingface":
            return llama_rms(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            y_fwd,
            grad_to_none=[x],
            rep=500,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = y_fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            grad_to_none=[x],
            rep=500,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[x],
            rep=500,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    N = input.x
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    eps = extra_benchmark_config["eps"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, N)

    triton_rms = LigerRMSNorm(hidden_size=N, eps=eps).to(device)
    llama_rms = LlamaRMSNorm(hidden_size=N, eps=eps).to(device)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    # utility functions
    def y_fwd():
        if provider == "liger":
            return triton_rms(x)
        if provider == "huggingface":
            return llama_rms(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "rms_norm",
        "x_name": "H",
        "x_label": "hidden size",
        "x_values": [2**i for i in range(10, 16)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_rms_norm,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_rms_norm,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
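Aside: the reference computation above reduces to y = w · x / sqrt(mean(x²) + eps), with the normalization done in fp32 and the result cast back to the input dtype. A one-liner restatement for a quick numerical check (standalone sketch, shapes made up):

import torch

x = torch.randn(4, 4096, dtype=torch.bfloat16)
w = torch.ones(4096)
eps = 1e-6
x32 = x.to(torch.float32)
y = w * (x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)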
benchmark/scripts/benchmark_rope.py (new file, mode 100755)

import torch
import triton

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.utils import infer_device
from liger_kernel.utils import transformers_version_dispatch

device = infer_device()


def bench_speed_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

    head_dim = hidden_size // num_q_heads
    rotary_emb = transformers_version_dispatch(
        "4.48.0",
        LlamaRotaryEmbedding,
        LlamaRotaryEmbedding,
        before_kwargs={"dim": head_dim, "device": device},
        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
    )
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    dq, dk = (
        torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device),
    )
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    cos, sin = rotary_emb(k, pos_ids)

    def fwd():
        if provider == "liger":
            return liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
        elif provider == "huggingface":
            return apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
        else:
            raise ValueError(f"Invalid provider: {provider} for RoPE embedding")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        q_out, k_out = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True),
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            q_out, k_out = fwd()
            torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=[q, k],
            rep=400,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_rope(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    num_q_heads = extra_benchmark_config["num_q_heads"]
    num_kv_heads = extra_benchmark_config["num_kv_heads"]
    dtype = extra_benchmark_config["dtype"]

    # x can be either hidden_size or seq_len
    hidden_size = extra_benchmark_config["hidden_size"] if "hidden_size" in extra_benchmark_config else input.x
    seq_len = extra_benchmark_config["seq_len"] if "seq_len" in extra_benchmark_config else input.x

    head_dim = hidden_size // num_q_heads
    rotary_emb = transformers_version_dispatch(
        "4.48.0",
        LlamaRotaryEmbedding,
        LlamaRotaryEmbedding,
        before_kwargs={"dim": head_dim, "device": device},
        after_kwargs={"config": LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), "device": device},
    )
    q = torch.randn(
        (1, seq_len, num_q_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    k = torch.randn(
        (1, seq_len, num_kv_heads, head_dim),
        device=device,
        requires_grad=True,
        dtype=dtype,
    ).transpose(1, 2)
    dq, dk = (
        torch.randn_like(q, device=device, dtype=dtype),
        torch.randn_like(k, device=device),
    )
    pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
    cos, sin = rotary_emb(k, pos_ids)

    def full():
        if provider == "liger":
            q_out, k_out = liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
        else:
            q_out, k_out = apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
        torch.autograd.grad((q_out, k_out), (q, k), (dq, dk), allow_unused=True, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(
        full,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs_varying_hidden_size = {
        "kernel_name": "rope",
        "x_name": "H",
        "x_label": "hidden size",
        "x_values": [32 * (2**i) for i in range(4, 10, 2)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "seq_len": 2048,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_rope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_hidden_size,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_rope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_hidden_size,
    )

    common_configs_varying_seq_len = {
        "kernel_name": "rope",
        "x_name": "T",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(10, 15)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "dtype": torch.bfloat16,
                "hidden_size": 8192,
                "num_q_heads": 32,
                "num_kv_heads": 8,
            }
        ],
        "overwrite": args.overwrite,
    }
    run_benchmarks(
        bench_test_fn=bench_speed_rope,
        kernel_operation_modes=["forward", "backward", "full"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs_varying_seq_len,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_rope,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs_varying_seq_len,
    )
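Aside: the two RoPE providers are drop-in comparable, so a forward-only parity check mirrors the benchmark setup. A sketch, assuming transformers >= 4.48 so the config-driven LlamaRotaryEmbedding constructor applies (matching the dispatch above); tolerances are deliberately loose:

import torch

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.utils import infer_device

device = infer_device()
seq_len, num_q_heads, num_kv_heads, head_dim = 128, 32, 8, 64

rotary_emb = LlamaRotaryEmbedding(config=LlamaConfig(num_kv_heads=num_kv_heads, head_dim=head_dim), device=device)
q = torch.randn(1, num_q_heads, seq_len, head_dim, device=device)
k = torch.randn(1, num_kv_heads, seq_len, head_dim, device=device)
pos_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0)
cos, sin = rotary_emb(k, pos_ids)

q_ref, k_ref = apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
q_liger, k_liger = liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
torch.testing.assert_close(q_ref, q_liger, rtol=1e-4, atol=1e-4)
torch.testing.assert_close(k_ref, k_liger, rtol=1e-4, atol=1e-4)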
benchmark/scripts/benchmark_simpo_loss.py (new file, mode 100755)

import os
import sys

import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.utils import infer_device

device = infer_device()

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


#############################################################################
# Test the memory consumption of the fused linear SimPO loss
#############################################################################
def bench_memory_fused_linear_simpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO
    from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO

    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider

    # Instantiate once and retrieve the first output only
    torch_lm_head_simpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
    liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)

    torch_fwd = lambda x, target: torch_lm_head_simpo(x, target)[0]
    liger_fwd = lambda x, target: liger_lm_head_simpo(x, target)[0]

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    def fwd():
        if provider == "liger":
            return liger_fwd(_input, target)
        elif provider == "huggingface":
            return torch_fwd(_input, target)

    def full():
        y = fwd()
        y.backward()

    mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


# #############################################################################
# # Test the speed of the fused linear SimPO loss
# #############################################################################
def bench_speed_fused_linear_simpo_loss(
    input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
    from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO
    from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO

    B = input.x
    T = input.extra_benchmark_config["T"]
    H = input.extra_benchmark_config["H"]
    V = input.extra_benchmark_config["V"]
    dtype = input.extra_benchmark_config["dtype"]
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    # Instantiate once and retrieve the first output only
    torch_lm_head_simpo = TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)
    liger_lm_head_simpo = LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)

    torch_fwd = lambda x, target: torch_lm_head_simpo(x, target)[0]
    liger_fwd = lambda x, target: liger_lm_head_simpo(x, target)[0]

    _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
    target = torch.randint(V, (B, T), dtype=torch.long, device=device)

    def fwd():
        if provider == "liger":
            return liger_fwd(_input, target)
        elif provider == "huggingface":
            return torch_fwd(_input, target)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            fwd,
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        y = fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(retain_graph=True),
            grad_to_none=[_input],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward()

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            rep=100,
            quantiles=QUANTILES,
        )

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "fused_linear_simpo_loss",
        "x_name": "B",
        "x_label": "B",
        "x_values": [2**i for i in range(1, 5)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "T": 1024,
                "H": 4096,
                "V": 128256,
                "mode": "forward",
                "dtype": torch.bfloat16,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_fused_linear_simpo_loss,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_fused_linear_simpo_loss,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
benchmark/scripts/benchmark_softmax.py
0 → 100755
View file @
9b0e3a30
import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.softmax import LigerSoftmax
from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    N = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    M = extra_benchmark_config["M"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (M, N)

    liger_softmax = LigerSoftmax().to(device).to(dtype)
    torch_softmax = torch.nn.Softmax(dim=-1).to(device).to(dtype)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        if provider == "liger":
            return liger_softmax(x)
        if provider == "torch":
            return torch_softmax(x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(y_fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    elif mode == "backward":
        y = y_fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            quantiles=QUANTILES,
            grad_to_none=[x],
            rep=500,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)

    if any(val is None for val in (ms_20, ms_50, ms_80)):
        raise RuntimeError(f"Benchmark speed result is None: ms_20={ms_20}, ms_50={ms_50}, ms_80={ms_80}")

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )
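The backward mode above times autograd through the softmax. For readers sanity-checking the numbers, the analytic gradient is dx = y * (dy - sum(dy * y, dim=-1, keepdim=True)); a small self-contained check of that identity against autograd (illustrative, not part of the script):

import torch

def softmax_backward_reference(y, dy):
    # Jacobian-vector product of softmax: dx_i = y_i * (dy_i - sum_j dy_j * y_j)
    return y * (dy - (dy * y).sum(dim=-1, keepdim=True))

x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
y = torch.softmax(x, dim=-1)
dy = torch.randn_like(y)
(dx,) = torch.autograd.grad(y, x, dy)
assert torch.allclose(dx, softmax_backward_reference(y.detach(), dy))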
def bench_memory_softmax(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    shape = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    dtype = extra_benchmark_config.get("dtype", torch.float32)

    torch_softmax = torch.nn.Softmax(dim=-1)
    liger_softmax = LigerSoftmax().to(device).to(dtype)

    x = torch.randn(shape, device=device, dtype=dtype, requires_grad=True)

    def fwd():
        if provider == "liger":
            return liger_softmax(x)
        elif provider == "torch":
            return torch_softmax(x)
        else:
            raise ValueError(f"Invalid provider: {provider} for softmax")

    def full():
        y = fwd()
        y.backward(torch.ones_like(y), retain_graph=True)

    if mode == "forward":
        mem_50, mem_20, mem_80 = _test_memory(fwd, quantiles=QUANTILES)
    elif mode == "backward":
        do = torch.ones_like(x)
        y = fwd()
        mem_50, mem_20, mem_80 = _test_memory(lambda: y.backward(do, retain_graph=True), quantiles=QUANTILES)
    else:
        mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

    if any(val is None for val in (mem_20, mem_50, mem_80)):
        raise RuntimeError(f"Benchmark memory result is None: mem_20={mem_20}, mem_50={mem_50}, mem_80={mem_80}")

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = dict(
        kernel_name="softmax",
        x_name="N",
        x_label="hidden size",
        x_values=[128, 256, 512, 1024, 2048, 4096],
        kernel_providers=["liger", "torch"],
        extra_benchmark_configs=[
            {"M": 2048, "dtype": torch.float32},
            {"M": 2048, "dtype": torch.bfloat16},
        ],
    )

    run_benchmarks(
        bench_test_fn=bench_speed_softmax,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        overwrite=args.overwrite,
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_softmax,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        overwrite=args.overwrite,
        **common_configs,
    )
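The script only measures timing and memory, never output agreement, so it can be worth confirming the two providers match numerically before trusting a comparison. A minimal sketch, assuming LigerSoftmax follows torch.nn.Softmax(dim=-1) semantics as the provider switch above implies:

import torch
from liger_kernel.transformers.softmax import LigerSoftmax
from liger_kernel.utils import infer_device

device = infer_device()
x = torch.randn(2048, 1024, device=device, dtype=torch.float32)
y_liger = LigerSoftmax().to(device)(x)
y_torch = torch.nn.Softmax(dim=-1)(x)
assert torch.allclose(y_liger, y_torch, atol=1e-6)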
benchmark/scripts/benchmark_sparse_multi_token_attention.py (new file, mode 100755)
import torch
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.multi_token_attention import LigerMultiTokenAttention
from liger_kernel.utils import infer_device

device = infer_device()


class TorchSparseMultiTokenAttention(torch.nn.Module):
    """Eager reference: causal mask -> sparsemax over the key dim -> conv2d -> re-mask."""

    def __init__(self, C_in, C_out, K, groups, bias, dtype, device):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(C_out, C_in // groups, K, K, dtype=dtype, device=device))
        self.bias = torch.nn.Parameter(torch.empty(C_out, dtype=dtype, device=device)) if bias else None
        self.K = K
        self.groups = groups
        self.dtype = dtype
        self.compute_dtype = torch.float32

    def forward(self, scores):
        B, C_in, L, _ = scores.shape
        mask = torch.tril(torch.ones(L, L, dtype=torch.bool, device=scores.device)).view(1, 1, L, L)
        inf = torch.tensor(-1e9, device=scores.device, dtype=self.compute_dtype)
        zero = torch.tensor(0.0, device=scores.device, dtype=self.compute_dtype)

        # Mask out future positions with a large negative value before sparsemax.
        s_compute = scores.to(self.compute_dtype)
        s_inf = s_compute.masked_fill(~mask, inf)

        # Sparsemax along the key dimension: sort descending, find the support
        # size k, derive the threshold tau, then clamp the shifted scores at zero.
        dim = -1
        z = s_inf
        z_sorted, _ = torch.sort(z, dim=dim, descending=True)
        cum_sum = torch.cumsum(z_sorted, dim=dim)
        k_indices = torch.arange(1, L + 1, device=z.device, dtype=z.dtype).view(1, 1, 1, L)
        is_positive = z_sorted > -1e8  # exclude masked (-1e9) entries from the support
        condition = (1 + k_indices * z_sorted > cum_sum) & is_positive
        k_sparsemax = torch.sum(condition, dim=dim, keepdim=True)
        k_sparsemax_safe = torch.max(k_sparsemax, torch.ones_like(k_sparsemax))
        cum_sum_k = torch.gather(cum_sum, dim=dim, index=k_sparsemax_safe.long() - 1)
        tau = (cum_sum_k - 1) / k_sparsemax_safe.to(z.dtype)
        tau = torch.where(k_sparsemax == 0, torch.full_like(tau, float("inf")), tau)
        probs = torch.clamp(z - tau, min=0)

        # Mix the sparse attention maps across channels with a small conv2d.
        weight_compute = self.weight.to(self.compute_dtype)
        bias_compute = self.bias.to(self.compute_dtype) if self.bias is not None else None
        out_c = torch.nn.functional.conv2d(
            probs,
            weight_compute,
            bias_compute,
            stride=1,
            padding=self.K // 2,
            groups=self.groups,
        )
        # Re-apply the causal mask, since the convolution can leak across it.
        return out_c.masked_fill(~mask, zero).to(scores.dtype)
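The sparsemax in the forward pass above (Martins & Astudillo, 2016) is a Euclidean projection of the scores onto the probability simplex, so low-scoring keys receive exactly zero weight instead of a small softmax tail. A toy, self-contained version over the last dimension, with a worked value (illustrative helper, not part of the repo):

import torch

def sparsemax_1d(z):
    # Sort descending and find the largest k with 1 + k * z_(k) > cumsum_k.
    z_sorted, _ = torch.sort(z, dim=-1, descending=True)
    cumsum = torch.cumsum(z_sorted, dim=-1)
    k = torch.arange(1, z.shape[-1] + 1, device=z.device, dtype=z.dtype)
    support = (1 + k * z_sorted) > cumsum
    k_support = support.sum(dim=-1, keepdim=True)
    # The threshold tau shifts the support so the clamped result sums to 1.
    tau = (torch.gather(cumsum, -1, k_support - 1) - 1) / k_support.to(z.dtype)
    return torch.clamp(z - tau, min=0)

p = sparsemax_1d(torch.tensor([1.0, 0.5, -1.0]))
# p == tensor([0.75, 0.25, 0.0]): the lowest score is pushed to exactly zero.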
def bench_speed_sparse_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    L = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode
    extra_benchmark_config = input.extra_benchmark_config
    B = extra_benchmark_config["B"]
    C_in = extra_benchmark_config["C_in"]
    C_out = extra_benchmark_config["C_out"]
    K = extra_benchmark_config["K"]
    groups = extra_benchmark_config["groups"]
    bias = extra_benchmark_config["bias"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (B, C_in, L, L)

    liger_attn = (
        LigerMultiTokenAttention(
            in_channels=C_in,
            out_channels=C_out,
            kernel_size=K,
            stride=1,
            padding=K // 2,
            dilation=1,
            groups=groups,
            bias=bias,
            sparse=True,
        )
        .to(device)
        .to(dtype)
    )
    torch_attn = TorchSparseMultiTokenAttention(
        C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device
    )

    with torch.no_grad():
        torch.nn.init.kaiming_uniform_(liger_attn.weight, a=5**0.5)
        if bias:
            torch.nn.init.zeros_(liger_attn.bias)
        torch_attn.weight.copy_(liger_attn.weight)
        if bias:
            torch_attn.bias.copy_(liger_attn.bias)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        if provider == "liger":
            return liger_attn(x)
        elif provider == "torch":
            return torch_attn(x)

    print(f"Starting Warmup for input size: {x_shape}")
    _ = fwd()
    if mode in ("backward", "full"):
        y = _
        y.backward(dy, retain_graph=True)
    print("Done Warmup")

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, grad_to_none=[x], rep=100, quantiles=QUANTILES)
    elif mode == "backward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            grad_to_none=[x],
            rep=100,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, grad_to_none=[x], rep=100, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_sparse_multi_token_attention(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    L = input.x
    provider = input.kernel_provider
    extra_benchmark_config = input.extra_benchmark_config
    B = extra_benchmark_config["B"]
    C_in = extra_benchmark_config["C_in"]
    C_out = extra_benchmark_config["C_out"]
    K = extra_benchmark_config["K"]
    groups = extra_benchmark_config["groups"]
    bias = extra_benchmark_config["bias"]
    dtype = extra_benchmark_config["dtype"]

    x_shape = (B, C_in, L, L)

    liger_attn = (
        LigerMultiTokenAttention(
            in_channels=C_in,
            out_channels=C_out,
            kernel_size=K,
            stride=1,
            padding=K // 2,
            dilation=1,
            groups=groups,
            bias=bias,
            sparse=True,
        )
        .to(device)
        .to(dtype)
    )
    torch_attn = TorchSparseMultiTokenAttention(
        C_in=C_in, C_out=C_out, K=K, groups=groups, bias=bias, dtype=dtype, device=device
    )

    with torch.no_grad():
        torch.nn.init.kaiming_uniform_(liger_attn.weight, a=5**0.5)
        if bias:
            torch.nn.init.zeros_(liger_attn.bias)
        torch_attn.weight.copy_(liger_attn.weight)
        if bias:
            torch_attn.bias.copy_(liger_attn.bias)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def fwd():
        if provider == "liger":
            return liger_attn(x)
        elif provider == "torch":
            return torch_attn(x)

    def full():
        y = fwd()
        y.backward(dy, retain_graph=True)

    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "sparse_multi_token_attention",
        "x_name": "L",
        "x_label": "sequence length",
        "x_values": [2**i for i in range(5, 10)],
        "kernel_providers": ["liger", "torch"],
        "extra_benchmark_configs": [
            {
                "B": 2,
                "C_in": 4,
                "C_out": 4,
                "K": 3,
                "groups": 1,
                "bias": True,
                "dtype": torch.float32,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_sparse_multi_token_attention,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_sparse_multi_token_attention,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
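A note on the three-way unpacking used throughout these scripts: triton.testing.do_bench returns one measurement per requested quantile, in the order given, so the pattern ms_50, ms_20, ms_80 = ... implies QUANTILES is ordered (0.5, 0.2, 0.8), i.e. the median plus a 20th/80th percentile band. That ordering is inferred from the unpacking; utils.py is not shown in this commit view. A standalone sketch of the same measurement pattern:

import torch
import triton
from liger_kernel.utils import infer_device

device = infer_device()

def run():
    a = torch.randn(1024, 1024, device=device)
    b = torch.randn(1024, 1024, device=device)
    return a @ b

# One time (ms) per quantile, in the order requested.
ms_50, ms_20, ms_80 = triton.testing.do_bench(run, rep=100, quantiles=[0.5, 0.2, 0.8])
print(f"median={ms_50:.3f} ms, p20={ms_20:.3f} ms, p80={ms_80:.3f} ms")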