Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
711aa9d5
Commit
711aa9d5
authored
Jul 30, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.10.0' into v0.10.0-dev
parents
751c492c
6d8d0a24
Changes
519
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
4072 additions
and
168 deletions
+4072
-168
benchmarks/benchmark_dataset.py
benchmarks/benchmark_dataset.py
+6
-1
benchmarks/benchmark_serving.py
benchmarks/benchmark_serving.py
+8
-96
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+6
-0
benchmarks/kernels/bench_nvfp4_gemm.py
benchmarks/kernels/bench_nvfp4_gemm.py
+141
-0
benchmarks/kernels/bench_per_token_quant_fp8.py
benchmarks/kernels/bench_per_token_quant_fp8.py
+98
-0
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+81
-25
benchmarks/kernels/benchmark_moe_align_block_size.py
benchmarks/kernels/benchmark_moe_align_block_size.py
+2
-5
benchmarks/kernels/benchmark_trtllm_attention.py
benchmarks/kernels/benchmark_trtllm_attention.py
+240
-0
benchmarks/kv_cache/benchmark_block_pool.py
benchmarks/kv_cache/benchmark_block_pool.py
+108
-0
cmake/cpu_extension.cmake
cmake/cpu_extension.cmake
+24
-4
csrc/attention/attention_kernels.cuh
csrc/attention/attention_kernels.cuh
+1
-7
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
+372
-0
csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp
...mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp
+203
-0
csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp
...s_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp
+2023
-0
csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp
...mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp
+165
-0
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
+283
-0
csrc/attention/paged_attention_v1.cu
csrc/attention/paged_attention_v1.cu
+1
-7
csrc/attention/paged_attention_v2.cu
csrc/attention/paged_attention_v2.cu
+1
-7
csrc/cpu/cpu_types_arm.hpp
csrc/cpu/cpu_types_arm.hpp
+264
-3
csrc/cpu/dnnl_helper.hpp
csrc/cpu/dnnl_helper.hpp
+45
-13
No files found.
Too many changes to show.
To preserve performance only
519 of 519+
files are displayed.
Plain diff
Email patch
benchmarks/benchmark_dataset.py
View file @
711aa9d5
...
...
@@ -324,6 +324,9 @@ class RandomDataset(BenchmarkDataset):
input_low
=
int
(
real_input_len
*
(
1
-
range_ratio
))
input_high
=
int
(
real_input_len
*
(
1
+
range_ratio
))
output_low
=
int
(
output_len
*
(
1
-
range_ratio
))
# Ensure the lower bound for output length is at least 1 to prevent
# sampling 0 tokens, which can cause request failures.
output_low
=
max
(
output_low
,
1
)
output_high
=
int
(
output_len
*
(
1
+
range_ratio
))
# Add logging for debugging
...
...
@@ -701,6 +704,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self
,
dataset_path
:
str
,
dataset_split
:
str
,
no_stream
:
bool
=
False
,
dataset_subset
:
Optional
[
str
]
=
None
,
**
kwargs
,
)
->
None
:
...
...
@@ -708,6 +712,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self
.
dataset_split
=
dataset_split
self
.
dataset_subset
=
dataset_subset
self
.
load_stream
=
not
no_stream
self
.
load_data
()
def
load_data
(
self
)
->
None
:
...
...
@@ -716,7 +721,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self
.
dataset_path
,
name
=
self
.
dataset_subset
,
split
=
self
.
dataset_split
,
streaming
=
True
,
streaming
=
self
.
load_stream
,
)
self
.
data
=
self
.
data
.
shuffle
(
seed
=
self
.
random_seed
)
...
...
benchmarks/benchmark_serving.py
View file @
711aa9d5
...
...
@@ -30,7 +30,7 @@ import os
import
random
import
time
import
warnings
from
collections.abc
import
AsyncGenerator
,
Iterable
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
datetime
import
datetime
from
typing
import
Any
,
Literal
,
Optional
...
...
@@ -73,6 +73,7 @@ from benchmark_dataset import (
VisionArenaDataset
,
)
from
benchmark_utils
import
convert_to_pytorch_benchmark_format
,
write_to_json
from
vllm.benchmarks.serve
import
get_request
MILLISECONDS_TO_SECONDS_CONVERSION
=
1000
...
...
@@ -107,101 +108,6 @@ class BenchmarkMetrics:
percentiles_e2el_ms
:
list
[
tuple
[
float
,
float
]]
def
_get_current_request_rate
(
ramp_up_strategy
:
Optional
[
Literal
[
"linear"
,
"exponential"
]],
ramp_up_start_rps
:
Optional
[
int
],
ramp_up_end_rps
:
Optional
[
int
],
request_index
:
int
,
total_requests
:
int
,
request_rate
:
float
,
)
->
float
:
if
(
ramp_up_strategy
and
ramp_up_start_rps
is
not
None
and
ramp_up_end_rps
is
not
None
):
progress
=
request_index
/
max
(
total_requests
-
1
,
1
)
if
ramp_up_strategy
==
"linear"
:
increase
=
(
ramp_up_end_rps
-
ramp_up_start_rps
)
*
progress
return
ramp_up_start_rps
+
increase
elif
ramp_up_strategy
==
"exponential"
:
ratio
=
ramp_up_end_rps
/
ramp_up_start_rps
return
ramp_up_start_rps
*
(
ratio
**
progress
)
else
:
raise
ValueError
(
f
"Unknown ramp-up strategy:
{
ramp_up_strategy
}
"
)
return
request_rate
async def get_request(
    input_requests: list[SampleRequest],
    request_rate: float,
    burstiness: float = 1.0,
    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
    ramp_up_start_rps: Optional[int] = None,
    ramp_up_end_rps: Optional[int] = None,
) -> AsyncGenerator[tuple[SampleRequest, float], None]:
    """
    Yield each request together with the request rate in effect for it,
    sleeping between requests to pace them.

    Args:
        input_requests:
            The requests to emit. A non-list iterable is materialized into a
            list first so that its length is known.
        request_rate:
            Constant rate (requests/s) used when no ramp-up is configured.
            An infinite rate disables the sleep between requests.
        burstiness (optional):
            Shape parameter of the gamma distribution that the sleep
            intervals are drawn from; must be positive. With the default of
            1 the intervals are exponentially distributed.
        ramp_up_strategy (optional):
            "linear" or "exponential"; None keeps the rate constant.
        ramp_up_start_rps (optional):
            Request rate at the start of the ramp-up.
        ramp_up_end_rps (optional):
            Request rate at the end of the ramp-up.
    """
    assert burstiness > 0, (
        f"A positive burstiness factor is expected, but given {burstiness}."
    )

    # The ramp-up calculation needs len(), so materialize generic iterables.
    if not isinstance(input_requests, list) and isinstance(input_requests, Iterable):
        input_requests = list(input_requests)
    total_requests = len(input_requests)

    for request_index, request in enumerate(input_requests):
        rate_now = _get_current_request_rate(
            ramp_up_strategy,
            ramp_up_start_rps,
            ramp_up_end_rps,
            request_index,
            total_requests,
            request_rate,
        )

        yield request, rate_now

        # Only finite rates need a pause before the next request.
        if rate_now != float("inf"):
            theta = 1.0 / (rate_now * burstiness)
            # Gamma(shape=burstiness, scale=theta) has mean 1 / rate_now.
            pause = np.random.gamma(shape=burstiness, scale=theta)
            await asyncio.sleep(pause)
def
calculate_metrics
(
input_requests
:
list
[
SampleRequest
],
outputs
:
list
[
RequestFuncOutput
],
...
...
@@ -825,6 +731,7 @@ def main(args: argparse.Namespace):
dataset_subset
=
args
.
hf_subset
,
dataset_split
=
args
.
hf_split
,
random_seed
=
args
.
seed
,
no_stream
=
args
.
no_stream
,
).
sample
(
num_requests
=
args
.
num_prompts
,
tokenizer
=
tokenizer
,
...
...
@@ -1033,6 +940,11 @@ def create_argument_parser():
help
=
"Path to the sharegpt/sonnet dataset. "
"Or the huggingface dataset ID if using HF dataset."
,
)
parser
.
add_argument
(
"--no-stream"
,
action
=
"store_true"
,
help
=
"Do not load the dataset in streaming mode."
,
)
parser
.
add_argument
(
"--max-concurrency"
,
type
=
int
,
...
...
benchmarks/benchmark_throughput.py
View file @
711aa9d5
...
...
@@ -410,6 +410,7 @@ def get_requests(args, tokenizer):
elif
args
.
dataset_name
==
"burstgpt"
:
dataset_cls
=
BurstGPTDataset
elif
args
.
dataset_name
==
"hf"
:
common_kwargs
[
"no_stream"
]
=
args
.
no_stream
if
args
.
dataset_path
in
VisionArenaDataset
.
SUPPORTED_DATASET_PATHS
:
dataset_cls
=
VisionArenaDataset
common_kwargs
[
"dataset_subset"
]
=
None
...
...
@@ -666,6 +667,11 @@ def create_argument_parser():
help
=
"Name of the dataset to benchmark on."
,
default
=
"sharegpt"
,
)
parser
.
add_argument
(
"--no-stream"
,
action
=
"store_true"
,
help
=
"Do not load the dataset in streaming mode."
,
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
...
...
benchmarks/kernels/bench_nvfp4_gemm.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
itertools
import
torch
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.triton_utils
import
triton
# Refuse to run unless the platform reports device capability 100.
if not current_platform.has_device_capability(100):
    raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")

# Largest representable magnitudes of the two formats used for scaling.
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# One entry per benchmark line. "no_a_quant" selects whether the activation
# is quantized once up front (True) or inside every timed call (False).
PROVIDER_CFGS = {
    "torch-bf16": {"enabled": True},
    "nvfp4": {"no_a_quant": False, "enabled": True},
    "nvfp4-noquant": {"no_a_quant": True, "enabled": True},
}

# Names of the providers that are switched on, in table order.
_enabled = [name for name, options in PROVIDER_CFGS.items() if options["enabled"]]
def _quant_weight_nvfp4(b: torch.Tensor, device: str):
    """Quantize weight tensor `b` with `ops.scaled_fp4_quant`.

    The global scale is FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX divided by the
    largest absolute value in `b`. `device` is accepted for signature
    compatibility but is not used here.

    Returns:
        (quantized weight, its scale tensor, the global scale).
    """
    weight_amax = b.abs().max().to(torch.float32)
    b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / weight_amax
    quantized = ops.scaled_fp4_quant(b, b_global_scale)
    b_fp4, scale_b_fp4 = quantized
    return b_fp4, scale_b_fp4, b_global_scale
def build_nvfp4_runner(cfg, a, b, dtype, device):
    """Build a zero-argument callable that performs one NVFP4 GEMM of a and b.

    The weight `b` is always quantized once here. When cfg["no_a_quant"] is
    true the activation `a` is also quantized once here, so the returned
    callable only runs the matmul; otherwise the callable quantizes `a` on
    every invocation before the matmul.
    """
    weight_fp4, weight_scale, weight_global_scale = _quant_weight_nvfp4(b, device)

    # Global scale for the activation.
    # NOTE: This is generally provided ahead-of-time by the model checkpoint.
    act_amax = a.abs().max().to(torch.float32)
    act_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / act_amax

    # Alpha passed to the GEMM: reciprocal of the product of both global scales.
    alpha = 1.0 / (act_global_scale * weight_global_scale)

    if not cfg["no_a_quant"]:
        # Activation is quantized inside the timed callable.
        def run():
            act_fp4, act_scale = ops.scaled_fp4_quant(a, act_global_scale)
            return ops.cutlass_scaled_fp4_mm(
                act_fp4, weight_fp4, act_scale, weight_scale, alpha, dtype
            )

        return run

    # Activation is quantized once, outside the timed callable.
    act_fp4, act_scale = ops.scaled_fp4_quant(a, act_global_scale)

    def run():
        return ops.cutlass_scaled_fp4_mm(
            act_fp4, weight_fp4, act_scale, weight_scale, alpha, dtype
        )

    return run
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one (batch_size x K) @ (N x K)^T GEMM for `provider`.

    Returns three TFLOP/s figures derived from the timings requested at
    quantiles 0.5, 0.2 and 0.8, in the order (0.5, 0.8, 0.2). A longer time
    gives a lower TFLOP/s, so the 0.8-quantile timing yields the smaller
    throughput value and is returned second.
    """
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    if provider == "torch-bf16":
        # Plain bf16 matmul used as the reference line.
        def timed_op():
            return torch.nn.functional.linear(a, b)
    else:
        timed_op = build_nvfp4_runner(PROVIDER_CFGS[provider], a, b, dtype, device)

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        timed_op, quantiles=[0.5, 0.2, 0.8]
    )

    def to_tflops(t_ms):
        # 2*M*N*K floating point operations, time converted from ms to s.
        return (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)

    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
    """Expand the selected models and TP sizes into [dim0, dim1, model] entries.

    For every (model, tp_size) pair taken from `args.models` and
    `args.tp_sizes`, each shape registered for the model in WEIGHT_SHAPES is
    copied, the dimension named by its `tp_dim` index is floor-divided by
    `tp_size`, and the model name is appended. WEIGHT_SHAPES itself is never
    modified because the work is done on a deep copy.
    """
    shapes = []
    for model_name, tp in itertools.product(args.models, args.tp_sizes):
        model_shapes = copy.deepcopy(WEIGHT_SHAPES[model_name])
        for dims, split_axis in model_shapes:
            dims[split_axis] = dims[split_axis] // tp
            dims.append(model_name)
            shapes.append(dims)
    return shapes
if __name__ == "__main__":
    # Command line: which models' weight shapes to sweep and at which
    # tensor-parallel sizes.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES.keys()),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    # One benchmark sweep (over batch sizes) per prepared weight shape.
    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
        result_dir = f"bench_nvfp4_res_n{N}_k{K}"
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path=result_dir,
            N=N,
            K=K,
        )

    print("Benchmark finished!")
benchmarks/kernels/bench_per_token_quant_fp8.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
itertools
from
typing
import
Callable
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
CompilationConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.triton_utils
import
triton
# TODO(luka): use standalone_compile utility
def
with_dyn_arg
(
fn
:
Callable
,
arg_index
:
int
,
dim_index
:
int
):
def
inner
(
*
args
):
torch
.
_dynamo
.
mark_dynamic
(
args
[
arg_index
],
dim_index
)
return
fn
(
*
args
)
return
inner
# Raise the recompile limit so the shape sweep below is not cut short.
torch._dynamo.config.recompile_limit = 8888

compilation_config = CompilationConfig(custom_ops=["none"])
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
    # Compile the per-token QuantFP8 module, then wrap it so that the first
    # dim of the first argument is explicitly dynamic to simulate vLLM usage.
    torch_per_token_quant_fp8 = with_dyn_arg(
        torch.compile(
            QuantFP8(False, GroupShape.PER_TOKEN),
            fullgraph=True,
            dynamic=False,  # recompile for different shapes
        ),
        0,
        0,
    )
def cuda_per_token_quant_fp8(
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run `ops.scaled_fp8_quant` on `input` with its default arguments.

    Returns whatever the op returns, annotated here as a pair of tensors.
    """
    quantized = ops.scaled_fp8_quant(input)
    return quantized
def calculate_diff(batch_size: int, seq_len: int):
    """Compare the torch.compile path against the CUDA op on random input.

    Builds a float16 tensor of shape (batch_size * seq_len, 4096) on the
    CUDA device, runs both quantization paths on it, and prints whether the
    quantized outputs and the scales agree within rtol=1e-3 / atol=1e-5.
    """
    device = torch.device("cuda")
    num_rows = batch_size * seq_len
    x = torch.rand((num_rows, 4096), dtype=torch.float16, device=device)

    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)

    tolerance = {"rtol": 1e-3, "atol": 1e-5}
    # Outputs are compared in float32; scales are only checked if outputs match.
    implementations_agree = torch.allclose(
        cuda_out.to(torch.float32), torch_out.to(torch.float32), **tolerance
    ) and torch.allclose(cuda_scale, torch_scale, **tolerance)

    if implementations_agree:
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")
# Sweep axes for the benchmark below.
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]

# Every (batch_size, seq_len) pair, batch size varying slowest.
configs = [
    (batch_size, seq_len)
    for batch_size in batch_size_range
    for seq_len in seq_len_range
]
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "cuda"],
        line_names=["Torch", "CUDA"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, provider):
    """Time per-token FP8 quantization of a (batch_size * seq_len, 4096) tensor.

    `provider` selects the implementation: "torch" for the compiled module,
    "cuda" for the custom op. Each timed call quantizes a fresh clone of the
    input. The timings requested at quantiles 0.5, 0.2 and 0.8 are multiplied
    by 1000 and returned in the order (0.5, 0.8, 0.2).
    """
    device = torch.device("cuda")
    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=torch.float16)

    if provider == "torch":

        def fn():
            return torch_per_token_quant_fp8(x.clone())

    elif provider == "cuda":

        def fn():
            return cuda_per_token_quant_fp8(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        fn, quantiles=[0.5, 0.2, 0.8]
    )

    # NOTE(review): the 0.8-quantile time is returned before the 0.2-quantile
    # time; confirm this is the order the report consumer expects for a
    # latency metric.
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
    # Check once that both implementations agree, then run the timing sweep.
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark_quantization.run(print_data=True)
benchmarks/kernels/benchmark_moe.py
View file @
711aa9d5
...
...
@@ -7,19 +7,19 @@ import time
from
contextlib
import
nullcontext
from
datetime
import
datetime
from
itertools
import
product
from
typing
import
Any
,
TypedDict
from
typing
import
Any
,
TypedDict
,
Optional
import
ray
import
torch
from
ray.experimental.tqdm_ray
import
tqdm
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.config
import
get_config
from
vllm.triton_utils
import
triton
from
vllm.utils
import
FlexibleArgumentParser
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
# Remove the global current_platform import; import it locally where needed instead
# FP8_DTYPE = current_platform.fp8_dtype()
class
BenchmarkConfig
(
TypedDict
):
...
...
@@ -47,8 +47,11 @@ def benchmark_config(
use_deep_gemm
:
bool
=
False
,
nn_moe
:
Optional
[
bool
]
=
False
)
->
float
:
from
vllm.platforms
import
current_platform
device
=
torch
.
cuda
.
current_device
()
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
device
)
if
use_int8_w8a16
:
if
not
nn_moe
:
w1
=
torch
.
randint
(
...
...
@@ -60,6 +63,7 @@ def benchmark_config(
hidden_size
,
),
dtype
=
torch
.
int8
,
device
=
device
,
)
w2
=
torch
.
randint
(
-
127
,
...
...
@@ -70,6 +74,7 @@ def benchmark_config(
shard_intermediate_size
//
2
,
),
dtype
=
torch
.
int8
,
device
=
device
,
)
else
:
w1
=
torch
.
randint
(
...
...
@@ -81,6 +86,7 @@ def benchmark_config(
shard_intermediate_size
,
),
dtype
=
torch
.
int8
,
device
=
device
,
)
w2
=
torch
.
randint
(
-
127
,
...
...
@@ -91,23 +97,24 @@ def benchmark_config(
hidden_size
,
),
dtype
=
torch
.
int8
,
device
=
device
,
)
else
:
if
not
nn_moe
:
w1
=
torch
.
randn
(
num_experts
,
shard_intermediate_size
,
hidden_size
,
dtype
=
init_dtype
num_experts
,
shard_intermediate_size
,
hidden_size
,
dtype
=
init_dtype
,
device
=
device
)
w2
=
torch
.
randn
(
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
dtype
=
init_dtype
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
dtype
=
init_dtype
,
device
=
device
)
else
:
w1
=
torch
.
randn
(
num_experts
,
hidden_size
,
shard_intermediate_size
,
dtype
=
init_dtype
num_experts
,
hidden_size
,
shard_intermediate_size
,
dtype
=
init_dtype
,
device
=
device
)
w2
=
torch
.
randn
(
num_experts
,
shard_intermediate_size
//
2
,
hidden_size
,
dtype
=
init_dtype
num_experts
,
shard_intermediate_size
//
2
,
hidden_size
,
dtype
=
init_dtype
,
device
=
device
)
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
w1_scale
=
None
w2_scale
=
None
...
...
@@ -115,9 +122,12 @@ def benchmark_config(
a2_scale
=
None
if
use_int8_w8a16
:
w1_scale
=
torch
.
randn
(
(
num_experts
,
2
*
shard_intermediate_size
),
dtype
=
torch
.
float32
(
num_experts
,
2
*
shard_intermediate_size
),
dtype
=
torch
.
float32
,
device
=
device
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
device
)
if
use_deep_gemm
:
# we use the default block shape for deepgemm
block_quant_shape
=
[
128
,
128
]
if
use_fp8_w8a8
:
if
block_quant_shape
:
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
...
...
@@ -130,24 +140,26 @@ def benchmark_config(
k_tiles_w1
=
(
K
+
block_k
-
1
)
//
block_k
k_tiles_w2
=
(
N
+
block_k
-
1
)
//
block_k
w1_scale
=
(
torch
.
rand
((
E
,
n_tiles_w1
,
k_tiles_w1
),
dtype
=
torch
.
float32
)
torch
.
rand
((
E
,
n_tiles_w1
,
k_tiles_w1
),
dtype
=
torch
.
float32
,
device
=
device
)
*
factor_for_scale
)
w2_scale
=
(
torch
.
rand
((
E
,
n_tiles_w2
,
k_tiles_w2
),
dtype
=
torch
.
float32
)
torch
.
rand
((
E
,
n_tiles_w2
,
k_tiles_w2
),
dtype
=
torch
.
float32
,
device
=
device
)
*
factor_for_scale
)
else
:
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
a2_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
a2_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
# Get FP8_DTYPE
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
w1
=
w1
.
to
(
FP8_DTYPE
)
w2
=
w2
.
to
(
FP8_DTYPE
)
input_gating
=
torch
.
empty
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
input_gating
=
torch
.
empty
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
def
prepare
(
i
:
int
):
input_gating
.
copy_
(
gating_output
[
i
])
...
...
@@ -266,6 +278,9 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
def
get_configs_compute_bound
(
use_fp16
,
block_quant_shape
,
nn_moe
:
Optional
[
bool
]
=
False
)
->
list
[
dict
[
str
,
int
]]:
configs
:
list
[
BenchmarkConfig
]
=
[]
# Import current_platform locally
from
vllm.platforms
import
current_platform
if
current_platform
.
is_rocm
():
param_ranges
=
get_rocm_tuning_space
(
use_fp16
,
nn_moe
)
...
...
@@ -426,12 +441,18 @@ def merge_unique_dicts(list1, list2):
@
ray
.
remote
(
num_gpus
=
1
)
class
BenchmarkWorker
:
def
__init__
(
self
,
seed
:
int
,
device_id
:
int
)
->
None
:
torch
.
set_default_device
(
"cuda:"
+
str
(
device_id
))
from
vllm.platforms
import
current_platform
import
os
if
current_platform
.
is_rocm
():
# In ROCm environment with Ray, let Ray handle device assignment
# Don't manually set default device as it may conflict with Ray's device mapping
pass
else
:
torch
.
set_default_device
(
"cuda:"
+
str
(
device_id
))
current_platform
.
seed_everything
(
seed
)
self
.
seed
=
seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
# Store the logical device ID for Ray
self
.
device_id
=
device_id
def
benchmark
(
...
...
@@ -448,7 +469,13 @@ class BenchmarkWorker:
use_deep_gemm
:
bool
=
False
,
nn_moe
:
Optional
[
bool
]
=
False
,
)
->
tuple
[
dict
[
str
,
int
],
float
]:
# Import current_platform locally
from
vllm.platforms
import
current_platform
current_platform
.
seed_everything
(
self
.
seed
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
get_config_dtype_str
,
get_moe_configs
,
get_default_config
)
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
...
...
@@ -502,6 +529,9 @@ class BenchmarkWorker:
use_deep_gemm
:
bool
,
nn_moe
:
Optional
[
bool
]
=
False
,
)
->
dict
[
str
,
int
]:
from
vllm.platforms
import
current_platform
import
os
best_config
=
None
best_time
=
float
(
"inf"
)
if
current_platform
.
is_rocm
():
...
...
@@ -515,10 +545,16 @@ class BenchmarkWorker:
topk
,
)
# In ROCm environments with Ray, device context is already handled by Ray
# Using torch.cuda.device() may cause device ordinal conflicts
need_device_guard
=
False
if
current_platform
.
is_rocm
():
visible_device
=
os
.
environ
.
get
(
"ROCR_VISIBLE_DEVICES"
,
None
)
if
visible_device
!=
f
"
{
self
.
device_id
}
"
:
# For ROCm with Ray, skip additional device context management
need_device_guard
=
False
else
:
# For other platforms, use device guard if needed
visible_devices
=
os
.
environ
.
get
(
"CUDA_VISIBLE_DEVICES"
,
None
)
if
visible_devices
is
not
None
and
len
(
visible_devices
.
split
(
','
))
>
1
:
need_device_guard
=
True
with
torch
.
cuda
.
device
(
self
.
device_id
)
if
need_device_guard
else
nullcontext
():
...
...
@@ -587,6 +623,10 @@ def save_configs(
block_quant_shape
:
list
[
int
],
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
None
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
get_config_dtype_str
,
get_config_file_name
)
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
...
...
@@ -611,6 +651,13 @@ def get_weight_block_size_safety(config, default_value=None):
def
main
(
args
:
argparse
.
Namespace
):
import
os
import
logging
from
vllm.platforms
import
current_platform
logger
=
logging
.
getLogger
(
__name__
)
print
(
args
)
tp_size
=
args
.
tp_size
...
...
@@ -628,7 +675,11 @@ def main(args: argparse.Namespace):
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
in
(
"DeepseekV3ForCausalLM"
,
"DeepseekV2ForCausalLM"
,
"Glm4MoeForCausalLM"
):
elif
config
.
architectures
[
0
]
in
(
"DeepseekV3ForCausalLM"
,
"DeepseekV2ForCausalLM"
,
"Glm4MoeForCausalLM"
,
):
E
=
config
.
n_routed_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
...
...
@@ -638,6 +689,11 @@ def main(args: argparse.Namespace):
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
in
(
"HunYuanMoEV1ForCausalLM"
):
E
=
config
.
num_experts
topk
=
config
.
moe_topk
[
0
]
intermediate_size
=
config
.
moe_intermediate_size
[
0
]
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
else
:
# Support for llama4
config
=
config
.
get_text_config
()
...
...
benchmarks/kernels/benchmark_moe_align_block_size.py
View file @
711aa9d5
...
...
@@ -33,15 +33,13 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
sorted_ids_triton
=
torch
.
empty
(
(
max_num_tokens_padded
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
sorted_ids_triton
.
fill_
(
topk_ids
.
numel
())
# fill with sentinel value
expert_ids_triton
=
torch
.
zeros
(
expert_ids_triton
=
torch
.
empty
(
(
max_num_tokens_padded
//
block_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
num_tokens_post_pad_triton
=
torch
.
empty
((
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
sorted_ids_vllm
=
torch
.
empty_like
(
sorted_ids_triton
)
sorted_ids_vllm
.
fill_
(
topk_ids
.
numel
())
expert_ids_vllm
=
torch
.
zeros_like
(
expert_ids_triton
)
expert_ids_vllm
=
torch
.
empty_like
(
expert_ids_triton
)
num_tokens_post_pad_vllm
=
torch
.
empty_like
(
num_tokens_post_pad_triton
)
# 2. run implementations
...
...
@@ -102,7 +100,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
sorted_ids
.
fill_
(
topk_ids
.
numel
())
max_num_m_blocks
=
max_num_tokens_padded
//
block_size
expert_ids
=
torch
.
empty
((
max_num_m_blocks
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
num_tokens_post_pad
=
torch
.
empty
((
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
...
...
benchmarks/kernels/benchmark_trtllm_attention.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
csv
import
os
import
random
from
datetime
import
datetime
import
flashinfer
import
torch
# Size of one float32 element in bytes (bit width divided by 8).
FLOAT32_BYTES = torch.finfo(torch.float32).bits // 8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
def
to_float8
(
x
,
dtype
=
torch
.
float8_e4m3fn
):
finfo
=
torch
.
finfo
(
dtype
)
min_val
,
max_val
=
x
.
aminmax
()
amax
=
torch
.
maximum
(
min_val
.
abs
(),
max_val
.
abs
()).
clamp
(
min
=
1e-12
)
scale
=
finfo
.
max
/
amax
*
0.1
x_scl_sat
=
(
x
*
scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
return
x_scl_sat
.
to
(
dtype
),
scale
.
float
().
reciprocal
()
@torch.no_grad()
def benchmark_decode(
    num_seqs,
    max_seq_len,
    page_size=16,
    dtype=torch.bfloat16,
    kv_layout="HND",
    num_kv_heads=8,
    kv_cache_dtype="auto",
    head_dim=128,
    warmup=10,
    trials=20,
):
    """Time the TRT-LLM decode kernel against the FlashInfer paged-KV wrapper.

    Random queries, a random paged KV cache and random per-sequence KV
    lengths in [1, max_seq_len] are generated on the CUDA device. Both decode
    paths are timed with CUDA events, one tab-separated result line is
    printed, and the same figures are returned as a dict for CSV output.

    Args:
        num_seqs: number of decode sequences (one query token each).
        max_seq_len: upper bound for the randomly drawn KV lengths.
        page_size: tokens per KV-cache page.
        dtype: dtype of the query, and of the KV cache unless fp8 is chosen.
        kv_layout: layout string passed to the FlashInfer wrapper.
        num_kv_heads: KV heads; query heads are 8x this number.
        kv_cache_dtype: strings starting with "fp8" quantize the KV cache
            with `to_float8`; anything else keeps `dtype`.
        head_dim: per-head dimension.
        warmup: untimed calls made before each measurement.
        trials: timed calls per measurement; must be at least 1.

    Returns:
        dict with mean/std times in ms for both paths, `speedup_percent`
        (the fraction (baseline - trt) / baseline, not multiplied by 100)
        and the configuration used.

    Raises:
        ValueError: if `trials` is smaller than 1.
    """
    if trials < 1:
        raise ValueError(f"trials must be at least 1, but given {trials}.")

    torch.set_default_device("cuda")
    device = "cuda"
    torch.manual_seed(0)

    # Currently only HEAD_GRP_SIZE == 8 is supported
    HEAD_GRP_SIZE = 8
    MAX_SEQ_LEN = max_seq_len

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / page_size)

    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)

    # For decode, batch_size is num_decode_token
    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
    sm_scale = float(1.0 / (head_dim**0.5))
    q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
    kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]

    max_kv_len = max(kv_lens)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
    max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size

    # No device argument: relies on the default device set above.
    block_tables = torch.randint(
        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
    kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
    k_scale = v_scale = 1.0

    if kv_cache_dtype.startswith("fp8"):
        # The returned inverse scale is discarded; k_scale/v_scale stay 1.0.
        kv_cache, _ = to_float8(kv_cache)

    # Benchmark TRT decode
    def trt_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            q,
            kv_cache,
            workspace_buffer,
            num_qo_heads,
            num_kv_heads,
            sm_scale,
            block_tables,
            kv_lens_tensor,
            page_size,
            max_kv_len,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    def time_fn(fn, warmup=10, trials=20):
        """Return (mean, std) of `trials` CUDA-event timings of fn, in ms."""
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for _ in range(warmup):
            fn()
        for _ in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    # TRT Decode. Forward the caller's warmup/trials instead of silently
    # falling back to time_fn's own defaults.
    trt_mean, trt_std = time_fn(trt_decode, warmup=warmup, trials=trials)

    # Build the page-table inputs (indptr / indices / last-page lengths) that
    # are handed to wrapper.plan() below.
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + page_size - 1) // page_size
        # tolist() yields plain ints instead of one device tensor per index.
        kv_indices.extend(block_tables[i, :num_blocks].tolist())
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % page_size
        if kv_last_page_len == 0:
            # A sequence that exactly fills its pages has a full last page.
            kv_last_page_len = page_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
    )

    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        "NONE",
        q_data_type=dtype,
        kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
    )

    def baseline_decode():
        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)

    baseline_mean, baseline_std = time_fn(
        baseline_decode, warmup=warmup, trials=trials
    )

    # Relative speedup as a fraction (positive means TRT is faster)
    speedup_percent = (baseline_mean - trt_mean) / baseline_mean

    print(
        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "num_seqs": num_seqs,
        "trt_mean": trt_mean,
        "trt_std": trt_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(dtype),
        "kv_cache_dtype": kv_cache_dtype,
        "page_size": page_size,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "max_seq_len": max_seq_len,
    }
def write_results_to_csv(results, filename=None):
    """Append benchmark results to a CSV file.

    Args:
        results: Iterable of dicts, one per benchmark run, keyed by the column
            names listed in ``fieldnames`` below.
        filename: Destination path. When ``None``, a name of the form
            ``flashinfer_trtllm_benchmark_<YYYYmmdd_HHMMSS>.csv`` is generated
            from the current time.

    The file is opened in append mode, so repeated calls with the same
    ``filename`` accumulate rows. The header row is written when the file is
    missing or empty.
    """
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "num_seqs",
        "trt_mean",
        "trt_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "page_size",
        "num_kv_heads",
        "head_dim",
        "max_seq_len",
    ]

    # A file that exists but is empty (pre-created, or left behind by an
    # aborted run) still needs a header, so check its size and not only its
    # existence.
    needs_header = not os.path.exists(filename) or os.path.getsize(filename) == 0

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerows(results)

    print(f"Results written to {filename}")
if __name__ == "__main__":
    # Sweep grid: every batch size is benchmarked at every maximum KV length.
    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    # Column header printed above each sweep's tab-separated result rows.
    column_header = (
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
        "baseline_std\tspeedup_percent"
    )

    # (banner, kv_cache_dtype) pairs, executed in this order. Queries are
    # bfloat16 in both sweeps; only the KV-cache dtype differs.
    sweeps = [
        ("Running benchmark for kv_cache_dtype: bfloat16", "auto"),
        ("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8", "fp8"),
    ]

    for banner, cache_dtype in sweeps:
        print(banner)
        print(column_header)
        for max_seq_len in max_seq_lens:
            for bs in num_seqs:
                all_results.append(
                    benchmark_decode(
                        bs,
                        max_seq_len,
                        dtype=torch.bfloat16,
                        kv_cache_dtype=cache_dtype,
                    )
                )

    # Write all results to CSV
    write_results_to_csv(all_results)
benchmarks/kv_cache/benchmark_block_pool.py
0 → 100644
View file @
711aa9d5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
gc
import
time
from
typing
import
Optional
from
tabulate
import
tabulate
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.v1.core.block_pool
import
BlockPool
class Metric:
    """Running count, sum and maximum over a stream of integer samples."""

    def __init__(self) -> None:
        # Number of samples recorded so far.
        self.cnt: int = 0
        # Sum of all recorded samples.
        self.sum_v: int = 0
        # Largest recorded sample; None until the first update().
        self.max_v: Optional[int] = None

    def update(self, v: int) -> None:
        """Record one sample ``v``."""
        self.cnt += 1
        self.sum_v += v
        self.max_v = v if self.max_v is None else max(self.max_v, v)

    def avg_v(self) -> float:
        """Return the mean of the recorded samples, or 0.0 if there are none."""
        if self.cnt == 0:
            # No samples yet: avoid ZeroDivisionError.
            return 0.0
        return self.sum_v * 1.0 / self.cnt
def main(args):
    """Time BlockPool allocation and release, then print a summary table.

    For each entry of ``args.allocate_blocks`` a fresh ``BlockPool`` with
    ``args.num_gpu_blocks`` blocks is created, and ``get_new_blocks`` /
    ``free_blocks`` are timed ``args.num_iteration`` times with
    ``time.monotonic_ns``. Average and maximum latencies are reported in
    milliseconds.
    """
    table_headers = [
        "Iterations",
        "Total\nBlocks",
        "Allocated\nBlocks",
        "Get Blocks\nAvg (ms)",
        "Get Blocks\nMax (ms)",
        "Free Blocks\nAvg (ms)",
        "Free Blocks\nMax (ms)",
    ]

    rows = []
    for allocate_block in args.allocate_blocks:
        # Enforce a GC collect ahead to minimize the impact among runs
        gc.collect()
        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)

        # NOTE: these two names appear verbatim in the diagnostic below via
        # the f-string "=" specifier, so renaming them changes program output.
        get_blocks_metric: Metric = Metric()
        free_blocks_metric: Metric = Metric()
        for _ in range(args.num_iteration):
            alloc_begin = time.monotonic_ns()
            blocks = block_pool.get_new_blocks(allocate_block)
            alloc_end = time.monotonic_ns()
            block_pool.free_blocks(blocks)
            free_end = time.monotonic_ns()
            get_blocks_metric.update(alloc_end - alloc_begin)
            free_blocks_metric.update(free_end - alloc_end)

        # max_v stays None when no sample was recorded (e.g. zero iterations).
        if get_blocks_metric.max_v is None or free_blocks_metric.max_v is None:
            print(
                "No valid metrics found."
                f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
            )
            continue

        # Samples are nanoseconds; divide by 1e6 to report milliseconds.
        rows.append(
            [
                get_blocks_metric.cnt,
                args.num_gpu_blocks,
                allocate_block,
                get_blocks_metric.avg_v() / 1000000,
                get_blocks_metric.max_v / 1000000.0,
                free_blocks_metric.avg_v() / 1000000,
                free_blocks_metric.max_v / 1000000.0,
            ]
        )

    print(
        tabulate(
            rows,
            headers=table_headers,
            tablefmt="grid",
            floatfmt=".6f",
        )
    )
def invoke_main() -> None:
    """Parse the command-line options and run the BlockPool benchmark."""
    parser = FlexibleArgumentParser(
        description="Benchmark the performance of BlockPool for KV Cache."
    )

    # (flag, add_argument keyword options), registered in this order.
    cli_options = (
        ("--num-gpu-blocks", dict(type=int, default=100000)),
        (
            "--num-iteration",
            dict(
                type=int,
                default=1000,
                help="Number of iterations to run to stablize final data readings",
            ),
        ),
        (
            "--allocate-blocks",
            dict(
                type=int,
                nargs="*",
                default=[10, 50, 100, 500, 1000],
                help="Number of blocks to allocate",
            ),
        ),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    main(parser.parse_args())
# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    invoke_main()  # pragma: no cover
cmake/cpu_extension.cmake
View file @
711aa9d5
...
...
@@ -165,17 +165,32 @@ else()
endif
()
#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
#
if
(
AVX512_FOUND AND NOT AVX512_DISABLED
)
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
# Flag to enable ACL kernels for AARCH64 platforms
if
(
VLLM_BUILD_ACL STREQUAL
"ON"
)
set
(
USE_ACL ON
)
else
()
set
(
USE_ACL OFF
)
endif
()
if
((
AVX512_FOUND AND NOT AVX512_DISABLED
)
OR ASIMD_FOUND
)
FetchContent_Declare
(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.
7
.1
GIT_TAG v3.
8
.1
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
if
(
USE_ACL
)
find_library
(
ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/
)
if
(
NOT ARM_COMPUTE_LIBRARY
)
message
(
FATAL_ERROR
"Could not find ARM Compute Library: please set ACL_ROOT_DIR"
)
endif
()
set
(
ONEDNN_AARCH64_USE_ACL
"ON"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/"
)
endif
()
set
(
ONEDNN_LIBRARY_TYPE
"STATIC"
)
set
(
ONEDNN_BUILD_DOC
"OFF"
)
set
(
ONEDNN_BUILD_EXAMPLES
"OFF"
)
...
...
@@ -264,6 +279,11 @@ elseif(POWER10_FOUND)
"csrc/cpu/quant.cpp"
${
VLLM_EXT_SRC
}
)
endif
()
if
(
ASIMD_FOUND
)
set
(
VLLM_EXT_SRC
"csrc/cpu/quant.cpp"
${
VLLM_EXT_SRC
}
)
endif
()
message
(
STATUS
"CPU extension source files:
${
VLLM_EXT_SRC
}
"
)
...
...
csrc/attention/attention_kernels.cuh
View file @
711aa9d5
...
...
@@ -24,6 +24,7 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "cuda_compat.h"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
...
...
@@ -35,12 +36,6 @@ typedef __hip_bfloat16 __nv_bfloat16;
#include "../quantization/int8_kvcache/quant_utils.cuh"
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
...
...
@@ -684,7 +679,6 @@ __global__ void paged_attention_v2_reduce_kernel(
}
// namespace vllm
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
0 → 100644
View file @
711aa9d5
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
/*!
\file
\brief A universal device layer for cutlass 3.x-style kernels.
*/
// clang-format off
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_mla_reduction.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace
cutlass
::
fmha
::
device
{
using
namespace
cute
;
using
namespace
cutlass
::
fmha
::
kernel
;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template
<
class
Kernel_
>
class
MLA
{
public:
/// The FMHA-MLA kernel type wrapped by this device-level handle.
using Kernel = Kernel_;

/// Second-stage kernel that run() launches when split_kv > 1.
/// It is instantiated with the main kernel's output and accumulator element
/// types, its TileShapeH / TileShapeL extents, and a fixed maximum of 256
/// splits.
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
  typename Kernel::ElementOut,
  typename Kernel::ElementAcc,
  typename Kernel::ElementAcc,
  Kernel::TileShapeH::value,
  Kernel::TileShapeL::value,
  256 /*Max split*/
>;

/// Argument structure: User API
using KernelArguments = typename Kernel::Arguments;
using ReductionArguments = typename ReductionKernel::Arguments;

/// Users pass the main kernel's arguments; the reduction arguments are
/// derived from them by to_reduction_args().
using Arguments = KernelArguments;

/// Argument structure: Kernel API
using KernelParams = typename Kernel::Params;
using ReductionParams = typename ReductionKernel::Params;

/// Parameters for both launches, bundled so one object describes a full run.
struct Params {
  KernelParams fmha_params;          // passed to device_kernel<Kernel>
  ReductionParams reduction_params;  // passed to device_kernel<ReductionKernel>
};
private:
/// Kernel API parameters object
Params
params_
;
/// Reads, and optionally latches, a one-way "already initialized" flag.
///
/// The flag is a function-local static, so it is shared by every caller of
/// this function and is not per-object state. Once it has been set to true it
/// is never reset.
///
/// \param set  when true, latch the flag before returning it
/// \return     the flag's value after the optional latch
bool is_initialized(bool set = false) {
  static bool initialized = false;
  initialized = initialized || set;
  return initialized;
}
/// Derives the reduction kernel's arguments from the main kernel's arguments.
///
/// The two accumulator pointers are left null here; initialize() and update()
/// fill them in from the main kernel's params when split_kv > 1.
static ReductionArguments to_reduction_args(Arguments const& args) {
  // problem_shape unpacks as (H, K, D, B); only K and B are needed here.
  auto [tile_h, dim_k, tile_d, num_batches] = args.problem_shape;

  ReductionArguments reduction_args{};
  reduction_args.ptr_oaccum   = nullptr;
  reduction_args.ptr_o        = args.epilogue.ptr_o;
  reduction_args.ptr_lseaccum = nullptr;
  reduction_args.ptr_lse      = args.epilogue.ptr_lse;
  reduction_args.scale        = args.mainloop.softmax_scale;
  reduction_args.num_batches  = num_batches;
  reduction_args.split_kv     = args.split_kv;
  reduction_args.dim_k        = dim_k;
  reduction_args.ptr_seq      = args.mainloop.ptr_seq;
  reduction_args.ptr_split_kv = args.ptr_split_kv;
  reduction_args.tile_shape_s = Kernel::TileShapeS::value;
  return reduction_args;
}
public:
/// Access the Params structure.
/// Returns a read-only reference to the params most recently stored by
/// initialize() or update().
Params const& params() const {
  return params_;
}
/// Chooses a KV split count heuristically when the caller has not set one.
///
/// If `args.split_kv >= 1` the value is left untouched. Otherwise it is
/// derived from the problem shape (H, K, D, B) and `args.hw_info.sm_count`:
///   1. allow at most one split per 128 positions of K, capped at 16;
///   2. limit that by the number of SMs available per batch entry;
///   3. rebalance so the splits are spread evenly over the resulting number
///      of passes (k_waves).
/// The result is always >= 1, including for degenerate inputs (K <= 0,
/// B <= 0, sm_count == 0).
static void set_split_kv(KernelArguments& args) {
  if (args.split_kv >= 1) return;

  auto [H, K, D, B] = args.problem_shape;
  int const sm_count = args.hw_info.sm_count;

  // Clamp to >= 1 so the divisions below stay well-defined when K <= 0.
  int const max_splits = max(1, min(16, ceil_div(K, 128)));

  // Guard the division: B may be 0 for an empty problem.
  int const sms_per_batch = (B > 0) ? max(1, sm_count / B) : 1;

  int const split_heur = min(max_splits, sms_per_batch);
  int const k_waves = ceil_div(max_splits, split_heur);
  args.split_kv = ceil_div(max_splits, k_waves);
}
/// Determines whether this MLA operation can execute the given problem.
/// Both the main kernel and the reduction kernel must accept the arguments;
/// the reduction kernel is consulted only if the main kernel accepts.
/// \return Status::kSuccess if supported, Status::kInvalid otherwise.
static Status can_implement(Arguments const& args) {
  bool const supported =
      Kernel::can_implement(args) &&
      ReductionKernel::can_implement(to_reduction_args(args));
  return supported ? Status::kSuccess : Status::kInvalid;
}
/// Gets the workspace size in bytes: the main kernel's requirement plus the
/// reduction kernel's requirement.
static size_t get_workspace_size(Arguments const& args) {
  size_t const fmha_bytes = Kernel::get_workspace_size(args);
  size_t const reduction_bytes =
      ReductionKernel::get_workspace_size(to_reduction_args(args));
  return fmha_bytes + reduction_bytes;
}
/// Computes the maximum number of active blocks per multiprocessor
static
int
maximum_active_blocks
(
int
/* smem_capacity */
=
-
1
)
{
CUTLASS_TRACE_HOST
(
"MLA::maximum_active_blocks()"
);
int
max_active_blocks
=
-
1
;
int
smem_size
=
Kernel
::
SharedStorageSize
;
// first, account for dynamic smem capacity if needed
cudaError_t
result
;
if
(
smem_size
>=
(
48
<<
10
))
{
CUTLASS_TRACE_HOST
(
" Setting smem size to "
<<
smem_size
);
result
=
cudaFuncSetAttribute
(
device_kernel
<
Kernel
>
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
smem_size
);
if
(
cudaSuccess
!=
result
)
{
result
=
cudaGetLastError
();
// to clear the error bit
CUTLASS_TRACE_HOST
(
" cudaFuncSetAttribute() returned error: "
<<
cudaGetErrorString
(
result
));
return
-
1
;
}
}
// query occupancy after setting smem size
result
=
cudaOccupancyMaxActiveBlocksPerMultiprocessor
(
&
max_active_blocks
,
device_kernel
<
Kernel
>
,
Kernel
::
MaxThreadsPerBlock
,
smem_size
);
if
(
cudaSuccess
!=
result
)
{
result
=
cudaGetLastError
();
// to clear the error bit
CUTLASS_TRACE_HOST
(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<<
cudaGetErrorString
(
result
));
return
-
1
;
}
CUTLASS_TRACE_HOST
(
" max_active_blocks: "
<<
max_active_blocks
);
return
max_active_blocks
;
}
/// Initializes this handle's Params from `args` and prepares the workspace.
///
/// Steps, in order:
///   1. run both kernels' initialize_workspace() on `workspace` / `stream`;
///   2. build the main kernel's params and the reduction kernel's params, and
///      store them in params_;
///   3. on the first successful call only, raise the main kernel's dynamic
///      shared-memory limit when SharedStorageSize is >= 48 KiB.
///
/// \return Status::kSuccess, the failing status from a workspace
///         initializer, or Status::kErrorInternal if cudaFuncSetAttribute
///         fails.
Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
  CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
    << workspace << ", stream: " << (stream ? "non-null" : "null"));

  // Initialize the workspace
  Status status = Kernel::initialize_workspace(args, workspace, stream);
  if (status != Status::kSuccess) {
    return status;
  }
  status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
  if (status != Status::kSuccess) {
    return status;
  }
  KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);

  ReductionArguments reduction_args = to_reduction_args(args);
  // With more than one split, the reduction reads the partial results whose
  // pointers the main kernel's params expose in its epilogue.
  if (reduction_args.split_kv > 1) {
    reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
    reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
  }
  ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
  // Initialize the Params structure
  params_ = Params {kernel_params, reduction_params};

  // The shared-memory attribute below only needs to be set once; the flag is
  // a function-local static inside is_initialized().
  if (is_initialized()) return Status::kSuccess;

  // account for dynamic smem capacity if needed
  // no dynamic smem is needed for reduction kernel
  int smem_size = Kernel::SharedStorageSize;
  if (smem_size >= (48 << 10)) {
    CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
    cudaError_t result = cudaFuncSetAttribute(
        device_kernel<Kernel>,
        cudaFuncAttributeMaxDynamicSharedMemorySize,
        smem_size);
    if (cudaSuccess != result) {
      result = cudaGetLastError(); // to clear the error bit
      CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
      return Status::kErrorInternal;
    }
  }

  // Latch the flag only after the attribute was set (or was not needed).
  is_initialized(true);

  return Status::kSuccess;
}
/// Rebuilds params_ from `args` without re-initializing the workspace.
/// The update API is preserved in 3.0, but a lightweight update of params is
/// not guaranteed.
/// \return Status::kErrorWorkspaceNull if a workspace is required but
///         `workspace` is null; Status::kSuccess otherwise.
Status update(Arguments const& args, void* workspace = nullptr) {
  CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);

  bool const needs_workspace = get_workspace_size(args) > 0;
  if (needs_workspace && workspace == nullptr) {
    return Status::kErrorWorkspaceNull;
  }

  KernelParams fmha_params = Kernel::to_underlying_arguments(args, workspace);

  ReductionArguments reduction_args = to_reduction_args(args);
  bool const uses_split_kv = reduction_args.split_kv > 1;
  if (uses_split_kv) {
    // The reduction reads the partial results whose pointers the main
    // kernel's params expose in its epilogue.
    reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
    reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
  }

  // Initialize the Params structure
  params_ = Params{
      fmha_params,
      ReductionKernel::to_underlying_arguments(reduction_args, workspace)};

  return Status::kSuccess;
}
/// Primary run() entry point. It is static so users can create and manage
/// their own params.
///
/// Launches the main kernel on `stream` and, when
/// `params.reduction_params.split_kv > 1`, the reduction kernel after it on
/// the same stream. The supplied params must have been constructed with the
/// kernels' to_underlying_arguments() (as initialize() and update() do).
///
/// \return Status::kSuccess, or Status::kErrorInternal if either launch fails.
static Status run(Params& params, cudaStream_t stream = nullptr) {
  CUTLASS_TRACE_HOST("MLA::run()");
  dim3 const block = Kernel::get_block_shape();
  dim3 const grid = Kernel::get_grid_shape(params.fmha_params);

  // configure smem size and carveout
  int smem_size = Kernel::SharedStorageSize;

  Status launch_result;
  // Use extended launch API only for mainloops that use it
  if constexpr (Kernel::ArchTag::kMinComputeCapability >= 90) {
    dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
                 cute::size<1>(typename Kernel::ClusterShape{}),
                 cute::size<2>(typename Kernel::ClusterShape{}));
    void const* kernel = (void const*) device_kernel<Kernel>;
    void* kernel_params[] = {&params.fmha_params};
    launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
  }
  else {
    launch_result = Status::kSuccess;
    device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
  }

  cudaError_t result = cudaGetLastError();
  if (cudaSuccess != result or Status::kSuccess != launch_result) {
    // Report the error as text, as the other methods of this class do.
    CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cudaGetErrorString(result));
    return Status::kErrorInternal;
  }

  // A single split leaves nothing to combine.
  if (params.reduction_params.split_kv <= 1) {
    return Status::kSuccess;
  }

  // launch reduction kernel (it uses no dynamic shared memory)
  dim3 const reduction_block = ReductionKernel::get_block_shape();
  dim3 const reduction_grid = ReductionKernel::get_grid_shape(params.reduction_params);
  device_kernel<ReductionKernel><<<reduction_grid, reduction_block, 0, stream>>>(params.reduction_params);

  cudaError_t const reduction_result = cudaGetLastError();
  if (cudaSuccess != reduction_result) {
    CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << cudaGetErrorString(reduction_result));
    return Status::kErrorInternal;
  }
  return Status::kSuccess;
}
//
// Non-static launch overloads. They operate on the params struct stored in
// this kernel handle.
//

/// Builds the internal Params from `args` via initialize(), then launches.
/// Returns the initialize() status unchanged if initialization fails.
Status run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
  Status const init_status = initialize(args, workspace, stream);
  if (init_status != Status::kSuccess) {
    return init_status;
  }
  return run(params_, stream);
}

/// Call-operator form of run(args, workspace, stream).
Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
  return run(args, workspace, stream);
}

/// Re-launches with the stored Params, without updating them.
Status run(cudaStream_t stream = nullptr) {
  return run(params_, stream);
}

/// Call-operator form of run(stream).
Status operator()(cudaStream_t stream = nullptr) {
  return run(stream);
}
};
////////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////
csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_reduction.hpp
0 → 100644
View file @
711aa9d5
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cute/tensor.hpp"
namespace
cutlass
::
fmha
::
kernel
{
using
namespace
cute
;
/// Combines the per-split partial outputs of a split-KV MLA run into the
/// final output.
///
/// Launch geometry (see get_grid_shape / get_block_shape): one block of
/// MaxThreadsPerBlock threads per (head, batch) pair, with blockIdx.x = head
/// and blockIdx.z = batch.
///
/// Memory layout implied by the indexing in operator():
///   ptr_lseaccum : [batch][split][head]
///   ptr_oaccum   : [batch][head][split][kHeadDimLatent]
///   ptr_lse      : [batch][head]
///   ptr_o        : [batch][head][kHeadDimLatent]
///
/// \tparam ElementOut      element type of the final output
/// \tparam ElementAcc      element type of the partial outputs and LSE values
/// \tparam ElementScale    type of Arguments::scale
/// \tparam kNumHeads       number of heads (grid extent in x)
/// \tparam kHeadDimLatent  output elements per head; must be a multiple of
///                         MaxThreadsPerBlock
/// \tparam kMaxSplits      capacity of the shared per-split scale table
template<
  class ElementOut,
  class ElementAcc,
  class ElementScale,
  size_t kNumHeads,
  size_t kHeadDimLatent,
  int kMaxSplits
>
struct Sm100FmhaMlaReductionKernel {

  static const int SharedStorageSize = 0;
  static const int MaxThreadsPerBlock = 128;
  static const int MinBlocksPerMultiprocessor = 1;

  using ArchTag = cutlass::arch::Sm100;

  // Each thread handles kHeadDimLatent / MaxThreadsPerBlock output elements.
  static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);

  struct Arguments {
    ElementAcc* ptr_oaccum = nullptr;    // partial outputs, one slab per split
    ElementOut* ptr_o = nullptr;         // final output
    ElementAcc* ptr_lseaccum = nullptr;  // partial log-sum-exp, one per split
    ElementAcc* ptr_lse = nullptr;       // final log-sum-exp (optional, may be null)
    ElementScale scale = 1.f;            // not read by operator()
    int num_batches = 0;
    int split_kv = -1;                   // split count used for strides and as the default per-batch count
    int dim_k = -1;                      // K extent used when ptr_seq is null
    int* ptr_seq = nullptr;              // optional per-batch K extent
    int* ptr_split_kv = nullptr;         // optional per-batch split count
    int tile_shape_s = 128;              // K tile size used to count tiles
  };
  using Params = Arguments;

  /// Params and Arguments are the same type; this is a field-wise copy.
  static Params to_underlying_arguments(Arguments const& args, void* workspace) {
    return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
            args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
            args.ptr_split_kv, args.tile_shape_s};
  }

  /// This kernel needs no workspace.
  static size_t get_workspace_size(Arguments const& /*args*/) {
    return 0;
  }

  /// Nothing to initialize; always succeeds.
  static Status initialize_workspace(
      Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
    return Status::kSuccess;
  }

  /// One block per (head, batch) pair.
  static dim3 get_grid_shape(Params const& params) {
    return dim3(kNumHeads, 1, params.num_batches);
  }

  static dim3 get_block_shape() {
    return dim3(MaxThreadsPerBlock, 1, 1);
  }

  /// Rejects empty problems and split counts the kernel cannot hold.
  static bool can_implement(Arguments const& args) {
    if (args.num_batches <= 0) return false;
    if (args.split_kv <= 0) return false;
    // operator() indexes a shared table of kMaxSplits entries by split id and
    // only loads kMaxSplits (rounded up to a multiple of 32) LSE values, so a
    // larger split count would read past the table.
    if (args.split_kv > kMaxSplits) return false;
    return true;
  }

  /// Device entry point. `smem_raw` is unused.
  CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
    // With at most one split there is nothing to combine.
    if (params.split_kv <= 1) return;
    auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);

    // Per-split weights exp(lse_split - lse_global), written by warp 0 and
    // read by all threads after the barrier below.
    __shared__ ElementAcc sLseScale[kMaxSplits];
    const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
    const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);

    // Consecutive splits of the same head are kNumHeads elements apart.
    Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
                                   make_shape(params.split_kv), Stride<Int<kNumHeads>>{});

    Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
                              Shape<_1>{}, Stride<_1>{});

    // Per-batch overrides take precedence over the uniform values.
    auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
    auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
    // Recompute the split count from the tile count, so a split that would
    // receive no K tiles is not counted.
    auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
    auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
    local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);

    int warp_idx = cutlass::canonical_warp_idx_sync();
    if (warp_idx == 0) {
      // Each of the 32 lanes owns every 32nd split.
      constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);

      ElementAcc local_lse[kNLsePerThread];

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        const int split = i * 32 + threadIdx.x;
        // Splits beyond the active count contribute -inf (weight 0).
        local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
      }

      // Maximum LSE: first per lane, then across the warp.
      ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        lse_max = max(lse_max, local_lse[i]);
      }
      CUTLASS_PRAGMA_UNROLL
      for (int offset = 16; offset >= 1; offset /= 2) {
        lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
      }
      lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max;  // In case all local LSEs are -inf
      lse_max = __shfl_sync(0xffffffff, lse_max, 0);

      // Sum of exp(lse - max): first per lane, then across the warp.
      ElementAcc sum_lse = 0;
      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        sum_lse = sum_lse + expf(local_lse[i] - lse_max);
      }

      CUTLASS_PRAGMA_UNROLL
      for (int offset = 16; offset >= 1; offset /= 2) {
        sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
      }

      sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);

      // A zero or NaN sum yields +inf, which makes every weight below
      // exp(-inf) = 0.
      ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
      if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
        gLSE(0) = global_lse;
      }

      CUTLASS_PRAGMA_UNROLL
      for (int i = 0; i < kNLsePerThread; ++i) {
        const int split = i * 32 + threadIdx.x;
        if (split < local_split_kv) {
          sLseScale[split] = expf(local_lse[i] - global_lse);
        }
      }
    }
    // Make the weights written by warp 0 visible to the whole block.
    __syncthreads();

    constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
    const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
    Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
                                 Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
    ElementAcc local_val[Elements] = {0};
    for (int split = 0; split < local_split_kv; ++split) {
      ElementAcc lse_scale = sLseScale[split];
      CUTLASS_PRAGMA_UNROLL
      for(int i = 0; i < Elements; ++i) {
        local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
      }
      // Advance to the next split's slab for this (head, batch).
      gOaccum.data() = gOaccum.data() + kHeadDimLatent;
    }
    auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
    Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});

    CUTLASS_PRAGMA_UNROLL
    for(int i = 0; i < Elements; ++i) {
      gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
    }
  }
};
}
// namespace cutlass::fmha::kernel
csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_fmha_mla_tma_warpspecialized.hpp
0 → 100644
View file @
711aa9d5
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "gather_tensor.hpp" // from examples/common
#include "common/pow_2.hpp"
namespace
cutlass
::
fmha
::
kernel
{
using
namespace
cute
;
template
<
class
TileShape
,
class
Element_
,
class
ElementAcc_
,
class
ElementOut_
,
class
ElementLSE_
,
class
TileScheduler
,
#ifdef CPASYNC
bool
kIsCpAsync
=
true
#else
bool
kIsCpAsync
=
false
#endif
>
struct
Sm100FmhaMlaKernelTmaWarpspecialized
{
// Element types forwarded from the template parameters.
using Element = Element_;
using ElementAcc = ElementAcc_;
using ElementOut = ElementOut_;
using ElementLSE = ElementLSE_;

// only 2Sm mode is supported
static constexpr bool kIs2Sm = true;
static constexpr int MaxThreadsPerBlock = 256;
static constexpr int MinBlocksPerMultiprocessor = 1;
// Stage counts used for the S (Q@K result) and P (softmax result) pipelines.
static constexpr int TotalSNum = 2;
static constexpr int TotalPNum = 2;

using ArchTag = cutlass::arch::Sm100;
// A cluster is two CTAs along the first mode in 2Sm mode, a single CTA otherwise.
using ClusterShape = cute::conditional_t<kIs2Sm, Shape<_2, _1, _1>, Shape<_1, _1, _1>>;

// TileShape is unpacked as (H, S, (L, R)); L and R are the two components of the D mode.
using TileShapeH = tuple_element_t<0, TileShape>;
using TileShapeS = tuple_element_t<1, TileShape>;
using TileShapeD = tuple_element_t<2, TileShape>;
using TileShapeL = tuple_element_t<0, TileShapeD>;
using TileShapeR = tuple_element_t<1, TileShapeD>;
static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim");

// Problem shape: H and D are static (taken from the tile shape), modes 1 and 3 are runtime ints.
using ProblemShape = Shape<TileShapeH, int, TileShapeD, int>;
// Strides with a compile-time unit stride in the middle mode.
using TensorStride = Stride<int64_t, _1, int64_t>;
using TmemAllocator = cute::conditional_t<kIs2Sm, cute::TMEM::Allocator2Sm, cute::TMEM::Allocator1Sm>;

static_assert(TileShapeH{} == 128);
static constexpr int kWarpsInN = kIs2Sm ? 2 : 1;
static constexpr int kNumComputeWarps = 4;
// The cp.async path uses two load warps, the TMA path one.
static constexpr int kNumLoadWarps = kIsCpAsync ? 2 : 1;
// Role played by a warp inside the CTA. The numeric values are the 4-bit codes
// that warp_idx_to_role() extracts from kWarpAssignment.
enum class WarpRole {
  kEmpty = 0x0,
  kMma = 0x1,
  kLoad = 0x2,
  kCompute = 0x3,
  kLoadPageTable = 0x4
};
// Packed warp-to-role table: one hex digit per warp, warp 0 in the least-significant
// digit (decoded by warp_idx_to_role).
//   TMA path      0x0021'3333: warps 0-3 kCompute, warp 4 kMma, warp 5 kLoad, warps 6-7 kEmpty.
//   cp.async path 0x4221'3333: as above, but warp 6 is a second kLoad warp and
//                              warp 7 is kLoadPageTable.
static constexpr unsigned long long kWarpAssignment =
    kIsCpAsync ? 0x4221'3333ull : 0x0021'3333ull;
// Returns the role of warp `warp_idx` by reading its 4-bit field from kWarpAssignment.
static CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) {
  // Warp i owns bits [4*i, 4*i + 3].
  auto const role_bits = (kWarpAssignment >> (4 * warp_idx)) & 0xF;
  return static_cast<WarpRole>(role_bits);
}
// Number of elements that fit in a 128-bit access, for inputs and for the output.
static constexpr int Alignment = 128 / sizeof_bits_v<Element>;
static constexpr int AlignmentOut = 128 / sizeof_bits_v<ElementOut>;

// Q@K tile: (H, S, K-extent), where the K-extent equals TileShapeR.
using TileShapeQK = Shape<TileShapeH, TileShapeS, decltype(TileShapeR{} / _1{})>;
static constexpr int StagesQK = 24 / sizeof(Element);  // free parameter
// Number of K-extent steps needed to cover the latent and the rope parts of the head dim.
static constexpr int IterationsQKLatent = decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value;
static constexpr int IterationsQKRope = decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value;
static constexpr int IterationsQK = IterationsQKLatent + IterationsQKRope;

using Schedule = cute::conditional_t<kIs2Sm,
    cutlass::gemm::KernelTmaWarpSpecialized2SmSm100,
    cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>;

using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
    Element, TensorStride, Alignment,
    Element, TensorStride, Alignment,
    ElementAcc,
    TileShapeQK, ClusterShape,
    cutlass::gemm::collective::StageCount<StagesQK>,
    Schedule>::CollectiveOp;
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
using CtaShapeQK = typename CollectiveMmaQK::CtaShape_MNK;

// chosen for unified smem staging between K and V
using TileShapePV = Shape<TileShapeH, _256, _32>;
// TensorStride with its first two modes swapped (used for the B operand of P@V).
using TransposeTensorStride = decltype(select<1, 0, 2>(TensorStride{}));
static constexpr int StagesPV = StagesQK;  // not sure why, but must be at least two. check pipes
static constexpr int IterationsPV_K = decltype(TileShapeS{} / get<2>(TileShapePV{}))::value;
static constexpr int IterationsPV_N = decltype(TileShapeL{} / get<1>(TileShapePV{}))::value;

using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
    Element, TensorStride, Alignment,
    Element, TransposeTensorStride, Alignment,
    ElementAcc,
    TileShapePV, ClusterShape,
    cutlass::gemm::collective::StageCount<StagesPV>,
    Schedule>::CollectiveOp;
using CtaShapePV = typename CollectiveMmaPV::CtaShape_MNK;
static_assert(std::is_same_v<TransposeTensorStride, typename CollectiveMmaPV::StrideB>);
using TiledMmaPV = typename CollectiveMmaPV::TiledMma;

// Both collectives must agree on the MMA atom thread shape.
using AtomThrShapeMNK = typename CollectiveMmaQK::AtomThrShapeMNK;
static_assert(typename CollectiveMmaQK::AtomThrShapeMNK{} == typename CollectiveMmaPV::AtomThrShapeMNK{},
              "schedule must match");
// The page-table pipeline is only multi-staged on the cp.async path.
static constexpr int StagesPageTable = kIsCpAsync ? StagesPV : 1;

// pipelines from load to mma, stages tbd
// (PipelineUmmaConsumerAsync on the cp.async path, PipelineTmaUmmaAsync on the TMA path)
// use expect_tx for Q load
using PipelineLoadQK = cute::conditional_t<kIsCpAsync,
    PipelineUmmaConsumerAsync<StagesQK, AtomThrShapeMNK>,
    PipelineTmaUmmaAsync<StagesQK, ClusterShape, AtomThrShapeMNK>>;
using PipelineLoadPV = PipelineLoadQK;
// pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages
using PipelineS = PipelineUmmaAsync<TotalSNum, AtomThrShapeMNK>;
// pipeline from softmax (P) to mma (bmm2), PipelineUmmaConsumerAsync, 2 stages
using PipelineP = PipelineUmmaConsumerAsync<TotalPNum, AtomThrShapeMNK>;
// pipeline from mma to softmax (for rescale), PipelineUmmaAsync, 1 stage
using PipelineO = PipelineUmmaAsync<1, AtomThrShapeMNK>;
// pipeline from the page-table loader warp to the load warps
using PipelinePT = PipelineAsync<StagesPageTable>;
// Barrier storage for every pipeline of the kernel; lives inside SharedStorage.
// Each member is 16-byte aligned.
struct PipelineStorage {
  alignas(16) typename PipelineLoadQK::SharedStorage load_qk;
  alignas(16) typename PipelineS::SharedStorage mma_s;
  alignas(16) typename PipelineP::SharedStorage p_mma;
  alignas(16) typename PipelineO::SharedStorage mma_o;
  alignas(16) typename PipelinePT::SharedStorage load_page_table;
};
// Composes `layout` mode-by-mode: the first three modes are passed through
// unchanged (`_`), and the fourth mode is composed with a layout built from
// `stages` (default: a single stage).
template <class Layout, class Stages = _1>
static CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
  auto const stage_mode = make_layout(stages);
  return composition(layout, make_tuple(_, _, _, stage_mode));
}
// Shared-memory layouts. Q is re-staged over all IterationsQK steps; P over (IterationsPV_K, 2).
using SmemLayoutQ = decltype(unstageSmemLayout(
    typename CollectiveMmaQK::SmemLayoutA{}, Int<IterationsQK>{}));
using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB;
using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB;
using SmemLayoutP = decltype(unstageSmemLayout(
    typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int<IterationsPV_K>{}, _2{})));

// Bytes of one stage (first three layout modes), multiplied by the size of AtomThrShapeMNK.
static constexpr int kBytesLoadQ =
    size(AtomThrShapeMNK{}) *
    cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static constexpr int kBytesLoadKC =
    size(AtomThrShapeMNK{}) *
    cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutKC{})) * cute::sizeof_bits_v<Element>);
static constexpr int kBytesLoadVC =
    size(AtomThrShapeMNK{}) *
    cutlass::bits_to_bytes(cosize(take<0, 3>(SmemLayoutVC{})) * cute::sizeof_bits_v<Element>);

// pre-condition for overlapped smem staging
static_assert(kBytesLoadKC == kBytesLoadVC);
static_assert(StagesQK == StagesPV);

static constexpr int kTransactionsBytesLoadQK = kBytesLoadKC;
static constexpr int kTransactionsBytesLoadExtraQ = kBytesLoadQ;
static constexpr int kTransactionsBytesLoadPV = kBytesLoadVC;

static constexpr int kNamedBarrierExchange =
    static_cast<int>(cutlass::arch::ReservedNamedBarriers::TransformBarrier);
// This named barrier was introduced to solve the Q tile being overwritten during loading
// when the persistent tile scheduler is enabled for FP8 MLA.
static constexpr int kNamedBarrierEpilogue =
    static_cast<int>(cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
// Named barrier used around TMEM deallocation.
static constexpr int kNamedBarrierTmemDealloc =
    static_cast<int>(cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier);
// Sizes and start offsets used to carve up tensor memory: two S buffers
// (kS0, kS1) followed by the O accumulator (kO0). kTotal is the end offset.
enum class TmemAllocation : uint32_t {
  kSizeS = TileShapeS::value / kWarpsInN,
  // Overall
  kSizeO = TileShapeL::value / kWarpsInN,
  // Between accumulators we loop over
  kSizeAccO = decltype(get<1>(TileShapePV{}))::value / kWarpsInN,
  kNumS = TotalSNum,
  kNumP = TotalPNum,
  kNumO = 1,
  kS0 = 0,
  kS1 = kS0 + kSizeS,
  kO0 = kS1 + kSizeS,
  kTotal = kO0 + kSizeO
};

static_assert(static_cast<int>(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns,
              "using too much tmem");
// Shared-memory buffers for tensors and the page table.
struct TensorStorage {
  // to communicate max and row_sum (one slot per compute thread)
  cute::array<ElementAcc, kNumComputeWarps * cutlass::NumThreadsPerWarp> smem_exchange;
  // TileShapeS page-table entries per page-table pipeline stage
  cute::array<int, StagesPageTable * TileShapeS::value> smem_page_table;
  alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
  // K-side and V-side staging share the same storage.
  union {
    alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKC>> smem_kc;
    alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutVC>> smem_vc;
  };
  alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;
};
// Complete shared-memory image of one CTA; operator() reinterprets its raw smem pointer as this.
struct SharedStorage {
  PipelineStorage pipelines;
  TensorStorage tensors;
  // Written by TmemAllocator::allocate() and read back when the allocation is freed.
  uint32_t tmem_base_ptr;
};

static constexpr int SharedStorageSize = sizeof(SharedStorage);
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
// User-facing mainloop arguments (member order matters: the struct is used as an aggregate).
struct MainloopArguments {
  ElementAcc softmax_scale;

  // all tensors strides are (num_heads or seqlen, head_dim, batch)
  // head_dim stride is always 1
  Element* ptr_q_latent;
  TensorStride stride_q_latent;
  Element* ptr_q_rope;
  TensorStride stride_q_rope;
  Element* ptr_c_latent;
  TensorStride stride_c_latent;
  Element* ptr_k_rope;
  TensorStride stride_k_rope;

  // for paged attention, we interpret what was previously [batch, seqlen]
  // as [page_count, page_size], and index according to page_table
  // ptr_seq: when non-null, entry [batch] replaces the seqlen of the problem shape per tile.
  int* ptr_seq = nullptr;
  int* ptr_page_table = nullptr;
  // page table is [batch, seqlen or similar]
  Stride<_1, int> stride_page_table = {};
  int page_count = 0;
  int page_size = TileShapeS{};  // powers of two if kIsCpAsync, otherwise TileShapeS
};
// User-facing epilogue arguments: output tensor, LSE tensor and an output scale.
struct EpilogueArguments {
  ElementOut* ptr_o = nullptr;
  TensorStride stride_o;
  ElementLSE* ptr_lse = nullptr;
  Stride<_1, int> stride_lse;
  ElementAcc output_scale = 1.0f;
};
// Host-side kernel arguments, converted to Params by to_underlying_arguments().
struct Arguments {
  // (num_heads=128, seqlen, (d_latent=512, d_rope=64), batch_count)
  // for paged attention, seqlen is max seqlen
  ProblemShape problem_shape;
  MainloopArguments mainloop;
  EpilogueArguments epilogue;
  KernelHardwareInfo hw_info;
  // Number of KV splits; can_implement() rejects values <= 0, and values > 1
  // make to_underlying_arguments() place accumulators in the workspace.
  int split_kv = -1;
  // Optional per-batch split counts; the kernel reads them only when mainloop.ptr_seq is set.
  int* ptr_split_kv = nullptr;
};
// TMA descriptor types. Q operands use the A-side descriptor of the QK collective,
// C-latent / K-rope its B-side descriptor, and the transposed C-latent load the
// B-side descriptor of the PV collective.
using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A;
using TmaLoadQRope = typename CollectiveMmaQK::Params::TMA_A;
using TmaLoadCLatent = typename CollectiveMmaQK::Params::TMA_B;
using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B;
using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B;
// Device-side TMA descriptors built by to_underlying_arguments()
// (aggregate-initialized there, so member order matters).
struct MainloopParams {
  TmaLoadQLatent tma_load_q_latent;
  TmaLoadQRope tma_load_q_rope;
  TmaLoadCLatent tma_load_c_latent;
  TmaLoadKRope tma_load_k_rope;
  TmaLoadCLatentTranspose tma_load_c_latent_transpose;
};
// Device-side epilogue parameters. The *_acc members point into the workspace and
// are only filled in by to_underlying_arguments() when split_kv > 1.
struct EpilogueParams {
  ElementOut* ptr_o = nullptr;
  ElementAcc* ptr_o_acc = nullptr;
  TensorStride stride_o;
  TensorStride stride_o_acc;
  ElementLSE* ptr_lse = nullptr;
  ElementLSE* ptr_lse_acc = nullptr;
  Stride<_1, int> stride_lse;
  Stride<_1, int> stride_lse_acc;
  ElementAcc output_scale = 1.0f;
};
// Everything the device kernel receives (aggregate-initialized by
// to_underlying_arguments(), so member order matters).
struct Params {
  ProblemShape problem_shape;
  MainloopArguments mainloop;
  EpilogueParams epilogue;
  MainloopParams mainloop_params;
  typename TileScheduler::Params tile_scheduler;
  int split_kv = -1;
  int* ptr_split_kv = nullptr;
};
// Converts host Arguments into device Params.
//
// Builds the TMA descriptors (Q against the plain [seqlen, batch] shape, the
// K/V-side operands against [page_size, page_count] when a page table is given),
// copies the epilogue arguments, and - when split_kv > 1 - points the accumulator
// tensors into `workspace`, which is laid out as
//   [ o_acc   : H * L * split_kv * B  ElementAcc ]
//   [ lse_acc : immediately after o_acc          ]
// `workspace` is only dereferenced-by-offset here when split_kv > 1.
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
  //workspace = nullptr; // let's get an error if one of these needs workspace

  auto [H, K, D, B] = args.problem_shape;
  auto [L, R] = D;

  int paged_B = B;
  int paged_K = K;
  if (args.mainloop.ptr_page_table != nullptr) {
    paged_B = args.mainloop.page_count;
    paged_K = args.mainloop.page_size;
  }

  auto params_qk_latent = CollectiveMmaQK::to_underlying_arguments(
      make_shape(H, K, L, B),
      typename CollectiveMmaQK::Arguments{
          args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent,
          args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent,
      },
      nullptr);

  auto params_qk_latent_paged = CollectiveMmaQK::to_underlying_arguments(
      make_shape(H, paged_K, L, paged_B),
      typename CollectiveMmaQK::Arguments{
          args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent,
          args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent,
      },
      nullptr);

  auto params_qk_rope = CollectiveMmaQK::to_underlying_arguments(
      make_shape(H, K, R, B),
      typename CollectiveMmaQK::Arguments{
          args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope,
          args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope,
      },
      nullptr);

  auto params_qk_rope_paged = CollectiveMmaQK::to_underlying_arguments(
      make_shape(H, paged_K, R, paged_B),
      typename CollectiveMmaQK::Arguments{
          args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope,
          args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope,
      },
      nullptr);

  auto stride_c_latent_transpose = select<1, 0, 2>(args.mainloop.stride_c_latent);
  auto params_pv_latent = CollectiveMmaPV::to_underlying_arguments(
      make_shape(H, L, paged_K, paged_B),
      typename CollectiveMmaPV::Arguments{
          args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent,  // dummy, never used
          args.mainloop.ptr_c_latent, stride_c_latent_transpose,
      },
      nullptr);

  // Q descriptors come from the unpaged shapes, K/V-side descriptors from the paged ones.
  MainloopParams mainloop_params{
      params_qk_latent.tma_load_a,
      params_qk_rope.tma_load_a,
      params_qk_latent_paged.tma_load_b,
      params_qk_rope_paged.tma_load_b,
      params_pv_latent.tma_load_b
  };

  EpilogueParams epilogue_params;
  epilogue_params.ptr_o = args.epilogue.ptr_o;
  epilogue_params.stride_o = args.epilogue.stride_o;
  epilogue_params.ptr_lse = args.epilogue.ptr_lse;
  epilogue_params.stride_lse = args.epilogue.stride_lse;
  epilogue_params.output_scale = args.epilogue.output_scale;

  if (args.split_kv > 1) {
    ElementAcc* ptr_o_acc = reinterpret_cast<ElementAcc*>(workspace);
    // Element count of o_acc. Computed in 64 bits: H * L * split_kv * B does not
    // fit a 32-bit int for large batch/split products (the strides below are
    // widened the same way).
    int64_t const o_acc_elems = static_cast<int64_t>(0 + H * L) * args.split_kv * B;
    ElementLSE* ptr_lse_acc = reinterpret_cast<ElementLSE*>(ptr_o_acc + o_acc_elems);
    epilogue_params.ptr_o_acc = ptr_o_acc;
    epilogue_params.ptr_lse_acc = ptr_lse_acc;

    epilogue_params.stride_o_acc = make_tuple(
        static_cast<int64_t>(0 + L) * args.split_kv,
        _1{},
        static_cast<int64_t>(0 + H * L) * args.split_kv);
    epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv);
  }

  return {
      args.problem_shape,
      args.mainloop,
      epilogue_params,
      mainloop_params,
      TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv),
      args.split_kv,
      args.ptr_split_kv
  };
}
// Returns the workspace size in bytes needed for the split-KV accumulators:
// per (head, split, batch) one row of D_latent ElementAcc values plus one ElementLSE.
// A non-positive split_kv (the Arguments default is -1, and can_implement()
// rejects such values) needs no workspace; returning 0 also avoids the
// negative count being converted to a huge size_t.
static size_t get_workspace_size(Arguments const& args) {
  ProblemShape problem_shape = args.problem_shape;
  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;
  auto split_kv = args.split_kv;
  if (split_kv <= 0) {
    return 0;
  }
  // sizeof() makes the whole product size_t arithmetic.
  return (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B;
}
// The workspace needs no initialization; always reports success.
static Status initialize_workspace(Arguments const& /*args*/,
                                   void* /*ws*/,
                                   cudaStream_t /*stream*/) {
  return Status::kSuccess;
}
// Launch grid, as decided by the tile scheduler.
static dim3 get_grid_shape(Params const& params) {
  dim3 const grid = TileScheduler::get_grid_shape(params.tile_scheduler);
  return grid;
}
// Launch block: a 1-D block of MaxThreadsPerBlock threads.
static dim3 get_block_shape() {
  return dim3(MaxThreadsPerBlock, 1, 1);
}
// Returns true when the kernel supports `args`.
//
// Requirements checked:
//  * cp.async path: page_size is a positive power of two not larger than TileShapeS.
//    (Non-positive values are rejected explicitly: 0 would pass the power-of-two
//    bit test and the page-table loader later divides TileShapeS by page_size.)
//  * TMA path: if a page table is given, page_size equals TileShapeS.
//  * num_heads (mode 0) is 128, seqlen (mode 1) is positive, split_kv is positive.
static bool can_implement(Arguments const& args) {
  int const page_size = args.mainloop.page_size;

  if constexpr (kIsCpAsync) {
    if (page_size <= 0) {
      return false;
    }
    if ((page_size & (page_size - 1)) != 0) {
      return false;
    }
    if (page_size > TileShapeS{}) {
      return false;
    }
  }
  else {
    if (args.mainloop.ptr_page_table != nullptr && page_size != TileShapeS{}) {
      return false;
    }
  }

  if (get<0>(args.problem_shape) != 128) {
    return false;
  }
  if (get<1>(args.problem_shape) <= 0) {
    return false;
  }
  if (args.split_kv <= 0) {
    return false;
  }
  return true;
}
// Kernel body. Every warp looks up its role, all warps cooperatively construct the
// pipelines, and then each role runs its own persistent loop over the tiles handed
// out by the tile scheduler:
//   kLoadPageTable -> load_page_table()
//   kLoad          -> load_cpasync() or load_tma<paged>()
//   kMma           -> allocates TMEM, then mma() (leader CTA only)
//   kCompute       -> compute()
// After a cluster-wide sync the MMA warp releases the TMEM allocation.
CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) {
  TileScheduler tile_scheduler(params.tile_scheduler);

  int const warp_idx = cutlass::canonical_warp_idx_sync();
  auto const role = warp_idx_to_role(warp_idx);
  uint32_t const lane_predicate = cute::elect_one_sync();

  uint32_t const cta_rank_in_cluster = cute::block_rank_in_cluster();
  int const cta_coord_v = cta_rank_in_cluster % size<0>(AtomThrShapeMNK{});
  bool const is_mma_leader_cta = (cta_coord_v == 0);

  bool const is_load_warp = (role == WarpRole::kLoad);
  bool const is_mma_warp = (role == WarpRole::kMma);
  bool const is_compute_warp = (role == WarpRole::kCompute);
  bool const is_page_table_warp = (role == WarpRole::kLoadPageTable);

  // TMA path only: the elected lane of the load warp prefetches every descriptor.
  if (is_load_warp && lane_predicate && !kIsCpAsync) {
    auto const& tma = params.mainloop_params;
    prefetch_tma_descriptor(tma.tma_load_q_latent.get_tma_descriptor());
    prefetch_tma_descriptor(tma.tma_load_c_latent.get_tma_descriptor());
    prefetch_tma_descriptor(tma.tma_load_q_rope.get_tma_descriptor());
    prefetch_tma_descriptor(tma.tma_load_k_rope.get_tma_descriptor());
    prefetch_tma_descriptor(tma.tma_load_c_latent_transpose.get_tma_descriptor());
  }

  SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_raw);

  // ---- load -> mma pipeline (shared by the QK and PV loads) ----
  typename PipelineLoadQK::Params pipeline_load_qk_params;
  if (is_load_warp) {
    pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Producer;
  }
  else if (is_mma_warp) {
    pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Consumer;
  }
  if constexpr (kIsCpAsync) {
    // we can make our life easier by unconditionally loading blocks
    // since we know it'll always be legal
    pipeline_load_qk_params.producer_arv_count =
        kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{});
  }
  else {
    pipeline_load_qk_params.is_leader = lane_predicate && is_load_warp && is_mma_leader_cta;
    pipeline_load_qk_params.transaction_bytes = kTransactionsBytesLoadQK;
  }
  pipeline_load_qk_params.initializing_warp = 0;
  PipelineLoadQK pipeline_load_qk(
      shared_storage.pipelines.load_qk, pipeline_load_qk_params, ClusterShape{},
      /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{});

  // ---- mma (Q@K) -> softmax pipeline ----
  typename PipelineS::Params pipeline_mma_s_params;
  if (is_mma_warp) {
    pipeline_mma_s_params.role = PipelineS::ThreadCategory::Producer;
  }
  else if (is_compute_warp) {
    pipeline_mma_s_params.role = PipelineS::ThreadCategory::Consumer;
  }
  pipeline_mma_s_params.consumer_arv_count =
      kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{});
  pipeline_mma_s_params.initializing_warp = 1;
  PipelineS pipeline_mma_s(
      shared_storage.pipelines.mma_s, pipeline_mma_s_params, ClusterShape{},
      /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{});

  // ---- softmax (P) -> mma pipeline ----
  typename PipelineP::Params pipeline_p_mma_params;
  if (is_mma_warp) {
    pipeline_p_mma_params.role = PipelineP::ThreadCategory::Consumer;
  }
  else if (is_compute_warp) {
    pipeline_p_mma_params.role = PipelineP::ThreadCategory::Producer;
  }
  pipeline_p_mma_params.producer_arv_count =
      kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{});
  pipeline_p_mma_params.consumer_arv_count = 1;
  pipeline_p_mma_params.initializing_warp = 2;
  PipelineP pipeline_p_mma(
      shared_storage.pipelines.p_mma, pipeline_p_mma_params, ClusterShape{},
      /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{});

  // ---- mma (O) -> softmax pipeline ----
  typename PipelineO::Params pipeline_mma_o_params;
  if (is_mma_warp) {
    pipeline_mma_o_params.role = PipelineO::ThreadCategory::Producer;
  }
  else if (is_compute_warp) {
    pipeline_mma_o_params.role = PipelineO::ThreadCategory::Consumer;
  }
  pipeline_mma_o_params.consumer_arv_count =
      kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{});
  pipeline_mma_o_params.initializing_warp = 3;
  PipelineO pipeline_mma_o(
      shared_storage.pipelines.mma_o, pipeline_mma_o_params, ClusterShape{},
      /*barrier init*/ cute::true_type{}, /*mask calc*/ cute::false_type{});

  // ---- page-table loader -> load warps pipeline ----
  typename PipelinePT::Params pipeline_pt_params;
  if (is_load_warp) {
    pipeline_pt_params.role = PipelinePT::ThreadCategory::Consumer;
  }
  else if (is_page_table_warp) {
    pipeline_pt_params.role = PipelinePT::ThreadCategory::Producer;
  }
  pipeline_pt_params.consumer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp;
  pipeline_pt_params.producer_arv_count = cutlass::NumThreadsPerWarp;
  pipeline_pt_params.initializing_warp = 4;
  PipelinePT pipeline_page_table(shared_storage.pipelines.load_page_table, pipeline_pt_params);

  TmemAllocator tmem_allocator;

  pipeline_init_arrive_relaxed(size(ClusterShape{}));

  pipeline_load_qk.init_masks(ClusterShape{});  // do we need an update here for 2Sm?
  pipeline_mma_s.init_masks(ClusterShape{});
  pipeline_p_mma.init_masks(ClusterShape{});
  pipeline_mma_o.init_masks(ClusterShape{});

  typename PipelineLoadQK::PipelineState pipeline_load_qk_consumer_state;
  typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state =
      cutlass::make_producer_start_state<PipelineLoadQK>();

  typename PipelineS::PipelineState pipeline_mma_s_consumer_state;
  typename PipelineS::PipelineState pipeline_mma_s_producer_state =
      cutlass::make_producer_start_state<PipelineS>();

  typename PipelineP::PipelineState pipeline_p_mma_consumer_state;
  typename PipelineP::PipelineState pipeline_p_mma_producer_state =
      cutlass::make_producer_start_state<PipelineP>();

  typename PipelineO::PipelineState pipeline_mma_o_consumer_state;
  typename PipelineO::PipelineState pipeline_mma_o_producer_state =
      cutlass::make_producer_start_state<PipelineO>();

  typename PipelinePT::PipelineState pipeline_pt_consumer_state;
  typename PipelinePT::PipelineState pipeline_pt_producer_state =
      cutlass::make_producer_start_state<PipelinePT>();

  pipeline_init_wait(size(ClusterShape{}));

  // Per-tile setup shared by every role: if per-batch sequence lengths are given,
  // patch the seqlen mode of `problem_shape` (and, if present, take the per-batch
  // split count). Returns false when this tile's split index has no work.
  auto resolve_tile = [&](auto const& blk_coord, ProblemShape& problem_shape, int& local_split_kv) -> bool {
    if (params.mainloop.ptr_seq != nullptr) {
      get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
      if (params.ptr_split_kv != nullptr) {
        local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
      }
    }
    return local_split_kv > get<3>(blk_coord);
  };

  if (is_page_table_warp) {
    CUTLASS_PRAGMA_NO_UNROLL
    for (; tile_scheduler.is_valid(); ++tile_scheduler) {
      auto blk_coord = tile_scheduler.get_block_coord();
      auto problem_shape = params.problem_shape;
      auto local_split_kv = params.split_kv;
      if (!resolve_tile(blk_coord, problem_shape, local_split_kv)) {
        continue;
      }
      load_page_table(
          blk_coord,
          problem_shape,
          params.mainloop,
          shared_storage.tensors,
          pipeline_page_table, pipeline_pt_producer_state,
          local_split_kv);
    }
  }
  else if (is_load_warp) {
    if constexpr (kIsCpAsync) {
      CUTLASS_PRAGMA_NO_UNROLL
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();
        auto problem_shape = params.problem_shape;
        auto local_split_kv = params.split_kv;
        if (!resolve_tile(blk_coord, problem_shape, local_split_kv)) {
          continue;
        }
        load_cpasync(
            blk_coord,
            problem_shape,
            params.mainloop,
            params.mainloop_params,
            shared_storage.tensors,
            pipeline_load_qk, pipeline_load_qk_producer_state,
            local_split_kv,
            /* must be shared pipe */
            pipeline_page_table, pipeline_pt_consumer_state);
        // Rendezvous between the load warps and the compute warps after each processed tile.
        cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
                                    kNamedBarrierEpilogue).arrive_and_wait();
      }
    }
    else {
      if (params.mainloop.ptr_page_table != nullptr) {
        CUTLASS_PRAGMA_NO_UNROLL
        for (; tile_scheduler.is_valid(); ++tile_scheduler) {
          auto blk_coord = tile_scheduler.get_block_coord();
          auto problem_shape = params.problem_shape;
          auto local_split_kv = params.split_kv;
          if (!resolve_tile(blk_coord, problem_shape, local_split_kv)) {
            continue;
          }
          load_tma</* paged= */ true>(
              blk_coord,
              problem_shape,
              params.mainloop,
              params.mainloop_params,
              shared_storage.tensors,
              pipeline_load_qk, pipeline_load_qk_producer_state,
              pipeline_load_qk, pipeline_load_qk_producer_state,
              local_split_kv);
          cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
                                      kNamedBarrierEpilogue).arrive_and_wait();
        }
      }
      else {
        CUTLASS_PRAGMA_NO_UNROLL
        for (; tile_scheduler.is_valid(); ++tile_scheduler) {
          auto blk_coord = tile_scheduler.get_block_coord();
          auto problem_shape = params.problem_shape;
          auto local_split_kv = params.split_kv;
          if (!resolve_tile(blk_coord, problem_shape, local_split_kv)) {
            continue;
          }
          load_tma<false>(
              blk_coord,
              problem_shape,
              params.mainloop,
              params.mainloop_params,
              shared_storage.tensors,
              pipeline_load_qk, pipeline_load_qk_producer_state,
              pipeline_load_qk, pipeline_load_qk_producer_state,
              local_split_kv);
          cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
                                      kNamedBarrierEpilogue).arrive_and_wait();
        }
      }
    }
  }
  else if (is_mma_warp) {
    // The MMA warp allocates the full TMEM capacity; the base address is written to smem.
    tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
    __syncwarp();

    if (is_mma_leader_cta) {
      CUTLASS_PRAGMA_NO_UNROLL
      for (; tile_scheduler.is_valid(); ++tile_scheduler) {
        auto blk_coord = tile_scheduler.get_block_coord();
        auto problem_shape = params.problem_shape;
        auto local_split_kv = params.split_kv;
        if (!resolve_tile(blk_coord, problem_shape, local_split_kv)) {
          continue;
        }
        mma(
            blk_coord,
            problem_shape,
            shared_storage.tensors,
            pipeline_load_qk, pipeline_load_qk_consumer_state,
            pipeline_load_qk, pipeline_load_qk_consumer_state,
            pipeline_mma_s, pipeline_mma_s_producer_state,
            pipeline_p_mma, pipeline_p_mma_consumer_state,
            pipeline_mma_o, pipeline_mma_o_producer_state,
            local_split_kv);
      }
    }
    // (An in-branch TMEM dealloc handshake used to live here; it is disabled and
    //  TMEM is freed after the cluster sync at the end of this function.)
  }
  else if (is_compute_warp) {
    CUTLASS_PRAGMA_NO_UNROLL
    for (; tile_scheduler.is_valid(); ++tile_scheduler) {
      auto blk_coord = tile_scheduler.get_block_coord();
      auto problem_shape = params.problem_shape;
      auto local_split_kv = params.split_kv;
      if (!resolve_tile(blk_coord, problem_shape, local_split_kv)) {
        continue;
      }
      compute(
          blk_coord,
          problem_shape,
          params.mainloop,  // for softmax_scale
          params.epilogue,
          shared_storage.tensors,  // for smem_comm
          pipeline_mma_s, pipeline_mma_s_consumer_state,
          pipeline_p_mma, pipeline_p_mma_producer_state,
          pipeline_mma_o, pipeline_mma_o_consumer_state,
          local_split_kv);
    }
  }

  cute::cluster_sync();
  // NOTE(review): every thread of the CTA arrives here, while the barrier is sized
  // for (kNumComputeWarps + 1) warps and nobody waits on it - confirm this is intended.
  cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp,
                              kNamedBarrierTmemDealloc).arrive();
  if (is_mma_warp) {
    uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
    tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
  }
}
// Producer side of the page-table pipeline (run by the kLoadPageTable warp).
//
// For every K tile assigned to this CTA's split, copies that tile's page-table
// entries of batch get<2>(blk_coord) from global memory into the smem slot of
// the current pipeline stage using cp.async, then commits the stage.
//
//   blk_coord                   block coordinate; mode 2 selects the batch, mode 3 the split index
//   problem_shape               (H, K, D, B); only K (seqlen) and B (batch) are used
//   mainloop_args               supplies ptr_page_table, stride_page_table, page_count, page_size
//   shared_tensors              smem_page_table is the destination
//   pipeline_page_table         pipeline to produce into
//   pipeline_pt_producer_state  advanced once per K tile
//   split_kv                    number of splits the K tiles are divided into
template <class BlkCoord>
CUTLASS_DEVICE void load_page_table(
    BlkCoord const& blk_coord,
    ProblemShape const& problem_shape,
    MainloopArguments const& mainloop_args,
    TensorStorage& shared_tensors,
    PipelinePT& pipeline_page_table,
    typename PipelinePT::PipelineState& pipeline_pt_producer_state,
    int const& split_kv) {
  int const K = get<1>(problem_shape);
  int const B = get<3>(problem_shape);
  int const batch_coord = get<2>(blk_coord);

  // Page table viewed as (page_count, B); keep only this batch's column.
  auto mPT_l = make_tensor(
      make_gmem_ptr(mainloop_args.ptr_page_table),
      make_shape(mainloop_args.page_count, B),
      mainloop_args.stride_page_table);
  auto mPT = mPT_l(_, batch_coord);

  // K tiles [first_tile, first_tile + tiles_to_load) belong to this split.
  int const k_tile_total = ceil_div(K, TileShapeS{});
  int const k_tile_per_cta = ceil_div(k_tile_total, split_kv);
  int const first_tile = get<3>(blk_coord) * k_tile_per_cta;  // lower limit
  int const tiles_to_load = max(0, min(k_tile_total, first_tile + k_tile_per_cta) - first_tile);
  if (tiles_to_load == 0) {
    return;
  }

  auto page_size = Pow2{mainloop_args.page_size};
  auto pages_per_tile = Pow2{TileShapeS{} / page_size};
  int const lane = threadIdx.x % cutlass::NumThreadsPerWarp;

  for (int k_tile = first_tile; k_tile < first_tile + tiles_to_load; ++k_tile) {
    pipeline_page_table.producer_acquire(pipeline_pt_producer_state);

    // assume a single warp
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < TileShapeS{}; i += cutlass::NumThreadsPerWarp) {
      int idx = i + lane;
      // Entries past pages_per_tile are not read; cp_async_zfill is called with guard == false.
      bool guard = idx < pages_per_tile;
      int smem_idx = pipeline_pt_producer_state.index() * TileShapeS::value + idx;
      int pt_idx = pages_per_tile * k_tile + idx;

      cutlass::arch::cp_async_zfill<sizeof(int), cutlass::arch::CacheOperation::Always>(
          &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard);
    }

    pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive);
    ++pipeline_pt_producer_state;
  }
}
// Index functor that looks up a page number in the shared-memory page table.
// `page_table_stage` is a reference, so the stage can be changed after the
// functor has been embedded in a layout. Member order matters (aggregate init).
struct Gather {
  int& page_table_stage;
  Pow2 pages_per_tile;
  const int* __restrict__ smem_page_table;

  // Returns entry (idx modulo pages_per_tile) of the current stage's page-table slot.
  CUTLASS_DEVICE int operator()(int idx) const {
    int const stage_base = page_table_stage * TileShapeS::value;
    int const page_in_tile = idx % pages_per_tile;
    return smem_page_table[stage_base + page_in_tile];
  }

  CUTLASS_DEVICE friend void print(Gather const&) {
    printf("<gather>");
  }
};
template
<
class
BlkCoord
>
CUTLASS_DEVICE
void
load_cpasync
(
BlkCoord
const
&
blk_coord
,
ProblemShape
const
&
problem_shape
,
MainloopArguments
const
&
mainloop_args
,
MainloopParams
const
&
mainloop_params
,
TensorStorage
&
shared_tensors
,
PipelineLoadQK
&
pipeline_load
,
typename
PipelineLoadQK
::
PipelineState
&
pipeline_load_producer_state
,
int
const
&
split_kv
,
PipelinePT
&
pipeline_page_table
,
typename
PipelinePT
::
PipelineState
&
pipeline_pt_consumer_state
)
{
auto
[
H
,
K
,
D
,
B
]
=
problem_shape
;
auto
[
D_latent
,
D_rope
]
=
D
;
using
X
=
Underscore
;
int
k_tile_total
=
ceil_div
(
K
,
TileShapeS
{});
int
k_tile_per_cta
=
ceil_div
(
k_tile_total
,
split_kv
);
int
k_index
=
get
<
3
>
(
blk_coord
)
*
k_tile_per_cta
;
// lower limit
int
k_tile_count
=
max
(
0
,
min
(
k_tile_total
,
k_index
+
k_tile_per_cta
)
-
k_index
);
if
(
k_tile_count
==
0
)
{
return
;
}
// partition all tensors
auto
mQL
=
make_tensor
(
make_gmem_ptr
(
mainloop_args
.
ptr_q_latent
),
make_shape
(
H
,
D_latent
,
B
),
mainloop_args
.
stride_q_latent
);
auto
mQR
=
make_tensor
(
make_gmem_ptr
(
mainloop_args
.
ptr_q_rope
),
make_shape
(
H
,
D_rope
,
B
),
mainloop_args
.
stride_q_rope
);
int
paged_B
=
mainloop_args
.
page_count
;
auto
paged_K
=
Pow2
{
mainloop_args
.
page_size
};
auto
mPT_l
=
make_tensor
(
make_gmem_ptr
(
mainloop_args
.
ptr_page_table
),
make_shape
(
paged_B
,
B
),
mainloop_args
.
stride_page_table
);
int
batch_coord
=
get
<
2
>
(
blk_coord
);
auto
mPT
=
mPT_l
(
_
,
batch_coord
);
auto
gQL
=
local_tile
(
mQL
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
_1
,
X
,
_1
>
{});
auto
gQR
=
local_tile
(
mQR
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
_1
,
X
,
_1
>
{});
ThrMMA
cta_mma_qk
=
TiledMmaQK
{}.
get_slice
(
get
<
0
>
(
blk_coord
)
%
size
(
AtomThrShapeMNK
{}));
ThrMMA
cta_mma_pv
=
TiledMmaPV
{}.
get_slice
(
get
<
0
>
(
blk_coord
)
%
size
(
AtomThrShapeMNK
{}));
auto
tSgQL
=
cta_mma_qk
.
partition_A
(
gQL
);
auto
tSgQR
=
cta_mma_qk
.
partition_A
(
gQR
);
Tensor
sQ
=
make_tensor
(
make_smem_ptr
(
shared_tensors
.
smem_q
.
begin
()),
SmemLayoutQ
{});
Tensor
sKC
=
make_tensor
(
make_smem_ptr
(
shared_tensors
.
smem_kc
.
begin
()),
SmemLayoutKC
{});
Tensor
sVC
=
make_tensor
(
make_smem_ptr
(
shared_tensors
.
smem_vc
.
begin
()),
SmemLayoutVC
{});
auto
make_copy_for
=
[](
auto
sT
)
{
auto
rT_a
=
sT
.
layout
()(
_
,
_
,
_
,
_0
{});
auto
rT
=
make_ordered_layout
(
shape
(
rT_a
),
stride
(
rT_a
));
auto
threads
=
Int
<
kNumLoadWarps
*
cutlass
::
NumThreadsPerWarp
>
{};
auto
values
=
Int
<
sizeof
(
uint128_t
)
/
sizeof
(
Element
)
>
{};
return
make_cotiled_copy
(
Copy_Atom
<
SM80_CP_ASYNC_CACHEALWAYS
<
uint128_t
>
,
Element
>
{},
make_ordered_layout
(
make_shape
(
threads
,
values
),
make_stride
(
_1
{},
_0
{})),
rT
);
};
// like cute::copy, but makes sure we do all page table lookups first
auto
copy_split
=
[](
auto
atom
,
auto
src
,
auto
dst
)
{
auto
src_v
=
group_modes
<
1
,
rank_v
<
decltype
(
src
)
>>
(
src
);
auto
dst_v
=
group_modes
<
1
,
rank_v
<
decltype
(
dst
)
>>
(
dst
);
auto
src_v_ptrs
=
make_tensor
<
Element
*>
(
size
<
1
>
(
src_v
));
for
(
int
i
=
0
;
i
<
size
<
1
>
(
src_v
);
i
++
)
{
src_v_ptrs
(
i
)
=
&
src_v
(
_0
{},
i
);
}
for
(
int
i
=
0
;
i
<
size
<
1
>
(
src_v
);
i
++
)
{
auto
src_v_i
=
make_tensor
(
make_gmem_ptr
(
src_v_ptrs
(
i
)),
make_shape
(
shape
<
0
>
(
src_v
)),
make_stride
(
make_stride
(
_1
{},
_0
{}))
);
atom
.
call
(
src_v_i
,
dst_v
(
_
,
i
));
}
};
auto
tiled_copy_q
=
make_copy_for
(
sQ
);
auto
tiled_copy_kc
=
make_copy_for
(
sKC
);
auto
tiled_copy_vc
=
make_copy_for
(
sVC
);
auto
thr_copy_q
=
tiled_copy_q
.
get_thread_slice
(
threadIdx
.
x
%
(
kNumLoadWarps
*
cutlass
::
NumThreadsPerWarp
));
auto
thr_copy_kc
=
tiled_copy_kc
.
get_thread_slice
(
threadIdx
.
x
%
(
kNumLoadWarps
*
cutlass
::
NumThreadsPerWarp
));
auto
thr_copy_vc
=
tiled_copy_vc
.
get_thread_slice
(
threadIdx
.
x
%
(
kNumLoadWarps
*
cutlass
::
NumThreadsPerWarp
));
auto
tQsQ
=
thr_copy_q
.
partition_D
(
sQ
);
auto
tQgQL
=
thr_copy_q
.
partition_S
(
tSgQL
);
auto
tQgQR
=
thr_copy_q
.
partition_S
(
tSgQR
);
auto
tKCsKC
=
thr_copy_kc
.
partition_D
(
sKC
);
auto
tVCsVC
=
thr_copy_vc
.
partition_D
(
sVC
);
auto
pipeline_pt_release_state
=
pipeline_pt_consumer_state
;
int
page_table_stage
=
-
1
;
Pow2
pages_per_tile
{
TileShapeS
{}
/
paged_K
};
const
int
*
__restrict__
smem_page_table
=
shared_tensors
.
smem_page_table
.
begin
();
Gather
gather
{
page_table_stage
,
pages_per_tile
,
smem_page_table
};
auto
mCL
=
make_tensor
(
make_gmem_ptr
(
mainloop_args
.
ptr_c_latent
),
ComposedLayout
{
make_layout
(
make_shape
(
make_shape
(
paged_K
,
paged_B
),
_1
{}),
make_stride
(
make_stride
(
get
<
0
>
(
mainloop_args
.
stride_c_latent
),
example
::
CustomStride
(
gather
,
get
<
2
>
(
mainloop_args
.
stride_c_latent
))),
get
<
1
>
(
mainloop_args
.
stride_c_latent
))),
make_coord
(
_0
{},
_0
{}),
make_identity_layout
(
make_shape
(
paged_K
*
paged_B
,
D_latent
))});
auto
mKR
=
make_tensor
(
make_gmem_ptr
(
mainloop_args
.
ptr_k_rope
),
ComposedLayout
{
make_layout
(
make_shape
(
make_shape
(
paged_K
,
paged_B
),
_1
{}),
make_stride
(
make_stride
(
get
<
0
>
(
mainloop_args
.
stride_k_rope
),
example
::
CustomStride
(
gather
,
get
<
2
>
(
mainloop_args
.
stride_k_rope
))),
get
<
1
>
(
mainloop_args
.
stride_k_rope
))),
make_coord
(
_0
{},
_0
{}),
make_identity_layout
(
make_shape
(
paged_K
*
paged_B
,
D_latent
))});
auto
mCLT
=
make_tensor
(
make_gmem_ptr
(
mainloop_args
.
ptr_c_latent
),
ComposedLayout
{
make_layout
(
make_shape
(
_1
{},
make_shape
(
paged_K
,
paged_B
)),
make_stride
(
get
<
1
>
(
mainloop_args
.
stride_c_latent
),
make_stride
(
get
<
0
>
(
mainloop_args
.
stride_c_latent
),
example
::
CustomStride
(
gather
,
get
<
2
>
(
mainloop_args
.
stride_c_latent
))))),
make_coord
(
_0
{},
_0
{}),
make_identity_layout
(
make_shape
(
D_latent
,
paged_K
*
paged_B
))});
auto
gCL
=
local_tile
(
mCL
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
auto
gKR
=
local_tile
(
mKR
,
TileShapeQK
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
auto
gCLT
=
local_tile
(
mCLT
,
TileShapePV
{},
make_coord
(
_
,
_
,
_
),
Step
<
X
,
_1
,
_1
>
{});
auto
tSgCL
=
cta_mma_qk
.
partition_B
(
gCL
);
auto
tSgKR
=
cta_mma_qk
.
partition_B
(
gKR
);
auto
tOgCLT
=
cta_mma_pv
.
partition_B
(
gCLT
);
auto
tKCgCL
=
thr_copy_kc
.
partition_S
(
tSgCL
);
auto
tKCgKR
=
thr_copy_kc
.
partition_S
(
tSgKR
);
auto
tVCgCLT
=
thr_copy_vc
.
partition_S
(
tOgCLT
);
// latent is first in memory, so let's load it first always
// startup: alternate Q and K, set tx count appropriately, for k_idx = 0
auto
&
pipeline_acquire_state
=
pipeline_load_producer_state
;
auto
pipeline_commit_state
=
pipeline_acquire_state
;
int
pipeline_offset
=
0
;
for
(
int
i
=
0
;
i
<
StagesPV
;
i
++
)
{
cutlass
::
arch
::
cp_async_fence
();
}
auto
load_stage
=
[
&
](
auto
fn
)
{
pipeline_load
.
producer_acquire
(
pipeline_acquire_state
);
fn
(
pipeline_acquire_state
.
index
());
cutlass
::
arch
::
cp_async_fence
();
++
pipeline_acquire_state
;
++
pipeline_offset
;
if
(
pipeline_offset
==
StagesPV
-
1
)
{
cutlass
::
arch
::
cp_async_wait
<
StagesPV
-
1
>
();
pipeline_load
.
producer_commit
(
pipeline_commit_state
);
++
pipeline_commit_state
;
--
pipeline_offset
;
}
};
pipeline_page_table
.
consumer_wait
(
pipeline_pt_consumer_state
);
page_table_stage
=
pipeline_pt_consumer_state
.
index
();
++
pipeline_pt_consumer_state
;
// each Q/K tile consists of rope and latent
for
(
int
i
=
0
;
i
<
IterationsQKLatent
;
i
++
)
{
load_stage
([
&
](
int
index
)
{
cute
::
copy
(
tiled_copy_q
,
tQgQL
(
_
,
_
,
_
,
_
,
_0
{},
i
,
batch_coord
),
tQsQ
(
_
,
_
,
_
,
_
,
i
));
copy_split
(
tiled_copy_kc
,
tKCgCL
(
_
,
_
,
_
,
_
,
k_index
,
i
),
tKCsKC
(
_
,
_
,
_
,
_
,
index
));
});
}
for
(
int
i
=
0
;
i
<
IterationsQKRope
;
i
++
)
{
load_stage
([
&
](
int
index
)
{
cute
::
copy
(
tiled_copy_q
,
tQgQR
(
_
,
_
,
_
,
_
,
_0
{},
i
,
batch_coord
),
tQsQ
(
_
,
_
,
_
,
_
,
IterationsQKLatent
+
i
));
copy_split
(
tiled_copy_kc
,
tKCgKR
(
_
,
_
,
_
,
_
,
k_index
,
i
),
tKCsKC
(
_
,
_
,
_
,
_
,
index
));
});
}
k_index
+=
1
;
k_tile_count
-=
1
;
// assume k_tile_count >= 1
// perform K+Q load here
CUTLASS_PRAGMA_NO_UNROLL
while
(
k_tile_count
>
0
)
{
pipeline_page_table
.
consumer_wait
(
pipeline_pt_consumer_state
);
page_table_stage
=
pipeline_pt_consumer_state
.
index
();
++
pipeline_pt_consumer_state
;
for
(
int
i
=
0
;
i
<
IterationsQKLatent
;
i
++
)
{
load_stage
([
&
](
int
index
)
{
copy_split
(
tiled_copy_kc
,
tKCgCL
(
_
,
_
,
_
,
_
,
k_index
,
i
),
tKCsKC
(
_
,
_
,
_
,
_
,
index
));
});
}
for
(
int
i
=
0
;
i
<
IterationsQKRope
;
i
++
)
{
load_stage
([
&
](
int
index
)
{
copy_split
(
tiled_copy_kc
,
tKCgKR
(
_
,
_
,
_
,
_
,
k_index
,
i
),
tKCsKC
(
_
,
_
,
_
,
_
,
index
));
});
}
page_table_stage
=
pipeline_pt_release_state
.
index
();
for
(
int
i
=
0
;
i
<
IterationsPV_K
;
i
++
)
{
for
(
int
j
=
0
;
j
<
IterationsPV_N
;
j
++
)
{
load_stage
([
&
](
int
index
)
{
copy_split
(
tiled_copy_vc
,
tVCgCLT
(
_
,
_
,
_
,
_
,
j
,
IterationsPV_K
*
(
k_index
-
1
)
+
i
),
tVCsVC
(
_
,
_
,
_
,
_
,
index
));
});
}
}
pipeline_page_table
.
consumer_release
(
pipeline_pt_release_state
);
++
pipeline_pt_release_state
;
k_index
+=
1
;
k_tile_count
-=
1
;
}
page_table_stage
=
pipeline_pt_release_state
.
index
();
for
(
int
i
=
0
;
i
<
IterationsPV_K
;
i
++
)
{
for
(
int
j
=
0
;
j
<
IterationsPV_N
;
j
++
)
{
load_stage
([
&
](
int
index
)
{
copy_split
(
tiled_copy_vc
,
tVCgCLT
(
_
,
_
,
_
,
_
,
j
,
IterationsPV_K
*
(
k_index
-
1
)
+
i
),
tVCsVC
(
_
,
_
,
_
,
_
,
index
));
});
}
}
pipeline_page_table
.
consumer_release
(
pipeline_pt_release_state
);
++
pipeline_pt_release_state
;
while
(
pipeline_offset
>
0
)
{
cutlass
::
arch
::
cp_async_fence
();
cutlass
::
arch
::
cp_async_wait
<
StagesPV
-
1
>
();
pipeline_load
.
producer_commit
(
pipeline_commit_state
);
++
pipeline_commit_state
;
--
pipeline_offset
;
}
cutlass
::
arch
::
cp_async_wait
<
0
>
();
}
// Producer-side load routine that issues TMA copies into shared memory.
//
// For the range of K tiles assigned to this block coordinate it loads:
//   * the Q tiles (latent part, then rope part) into smem_q -- only once,
//     together with the first K tile;
//   * the latent-cache (C) and rope-K tiles into smem_kc through
//     pipeline_load_qk;
//   * the transposed latent-cache tiles into smem_vc through pipeline_load_pv,
//     always one K tile behind the QK loads.
//
// Template parameters:
//   kIsPaged  - when true, the cache tensors are indexed through the page table
//               (mainloop_args.ptr_page_table); the page index for K tile k is
//               read as mPT(k). When false, the K tile index and batch_coord
//               are used directly as tensor coordinates.
//   BlkCoord  - type of the block coordinate tuple; mode 0 selects the MMA
//               slice, mode 2 is used as the batch coordinate and mode 3 as
//               the split index along K.
//
// Parameters:
//   blk_coord        - block coordinate (see BlkCoord above).
//   problem_shape    - unpacked as (H, K, D, B) with D = (D_latent, D_rope).
//   mainloop_args    - supplies page_count, page_size, ptr_page_table and
//                      stride_page_table.
//   mainloop_params  - supplies the TMA descriptors (tma_load_*).
//   shared_tensors   - shared-memory storage holding smem_q / smem_kc / smem_vc.
//   pipeline_load_qk / pipeline_load_qk_producer_state
//                    - pipeline and (caller-owned, advanced here) producer
//                      state for the Q/K stages.
//   pipeline_load_pv / pipeline_load_pv_producer_state
//                    - pipeline and (caller-owned, advanced here) producer
//                      state for the V stages.
//   split_kv         - number of splits along K; each split handles
//                      ceil_div(k_tile_total, split_kv) tiles.
//
// Returns early without touching any pipeline if this split has no K tiles.
template <bool kIsPaged = false, class BlkCoord>
CUTLASS_DEVICE void load_tma(
    BlkCoord const& blk_coord,
    ProblemShape const& problem_shape,
    MainloopArguments const& mainloop_args,
    MainloopParams const& mainloop_params,
    TensorStorage& shared_tensors,
    PipelineLoadQK& pipeline_load_qk,
    typename PipelineLoadQK::PipelineState& pipeline_load_qk_producer_state,
    PipelineLoadPV& pipeline_load_pv,
    typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state,
    int const& split_kv) {

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;

  // Work split along K: k_index is the first tile of this split (the lower
  // limit); k_tile_count is clamped so the last split does not run past
  // k_tile_total and is never negative.
  int k_tile_total = ceil_div(K, TileShapeS{});
  int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
  int k_index = get<3>(blk_coord) * k_tile_per_cta;
  int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
  if (k_tile_count == 0) {
    return;
  }

  using X = Underscore;

  // partition all tensors
  auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B));
  auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B));

  // Non-paged: the cache is addressed as (K, D, B). Paged: it is addressed as
  // (page_size, D, page_count) and the page table selects the last coordinate.
  int paged_B = B;
  int paged_K = K;
  if constexpr (kIsPaged) {
    paged_B = mainloop_args.page_count;
    paged_K = mainloop_args.page_size;
  }
  // Page table tensor of shape (paged_B, B); it is only dereferenced in the
  // kIsPaged branches below.
  auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table),
      make_shape(paged_B, B), mainloop_args.stride_page_table);

  auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B));
  auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B));

  // Same latent cache, viewed with the first two modes swapped (D_latent first).
  auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor(make_shape(D_latent, paged_K, paged_B));

  auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{});
  auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{});

  auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step<X, _1, _1>{});
  auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step<X, _1, _1>{});
  auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step<X, _1, _1>{});

  // MMA slices are selected by mode 0 of blk_coord modulo the atom thread shape.
  ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{}));
  ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{}));

  auto tSgQL = cta_mma_qk.partition_A(gQL);
  auto tSgQR = cta_mma_qk.partition_A(gQR);

  auto tSgCL = cta_mma_qk.partition_B(gCL);
  auto tSgKR = cta_mma_qk.partition_B(gKR);

  auto tOgCLT = cta_mma_pv.partition_B(gCLT);

  Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
  Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{});
  Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{});

  // Build matching (gmem, smem) TMA partitions. Latent and rope parts share the
  // same smem buffers (sQ, sKC), so the smem half of the rope partitions is
  // bound to *_ignore names and not used.
  auto [tQLgQL_mkl, tQsQ] = tma_partition(
      mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}),
      group_modes<0,3>(sQ), group_modes<0,3>(tSgQL));

  auto [tQRgQR_mkl, tQsQ_ignore] = tma_partition(
      mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}),
      group_modes<0,3>(sQ), group_modes<0,3>(tSgQR));

  auto [tCLgCL_nkl, tKCsKC] = tma_partition(
      mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}),
      group_modes<0,3>(sKC), group_modes<0,3>(tSgCL));

  auto [tKRgKR_nkl, tKCsKC_ignore] = tma_partition(
      mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}),
      group_modes<0,3>(sKC), group_modes<0,3>(tSgKR));

  auto [tCLTgCLT_nkl, tVCsVC] = tma_partition(
      mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}),
      group_modes<0,3>(sVC), group_modes<0,3>(tOgCLT));

  // A zero mask is passed to every .with(...) call below.
  uint16_t mcast_mask = 0;

  int batch_coord = get<2>(blk_coord);
  // Q and the page table are fixed to this batch entry up front.
  Tensor tQLgQL = tQLgQL_mkl(_, _, _, batch_coord);
  Tensor tQRgQR = tQRgQR_mkl(_, _, _, batch_coord);

  auto mPT = mPT_l(_, batch_coord);

  // The cache tensors keep all modes; the last coordinate is chosen per copy
  // (page index when paged, batch_coord otherwise).
  Tensor tCLgCL = tCLgCL_nkl(_, _, _, _);
  Tensor tKRgKR = tKRgKR_nkl(_, _, _, _);

  // careful: stage and k are swapped here!
  Tensor tCLTgCLT = tCLTgCLT_nkl(_, _, _, _);

  // latent is first in memory, so let's load it first always
  // startup: alternate Q and K, set tx count appropriately, for k_idx = 0
  // each Q/K tile consists of rope and latent
  for (int i = 0; i < IterationsQKLatent; i++) {
    // Each startup stage carries a Q tile in addition to the K tile, so the
    // extra Q bytes are added to the stage's expected transaction count.
    // NOTE(review): expect_transaction is issued before producer_acquire;
    // keep this order unless the pipeline contract is re-checked.
    pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ);
    pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
    auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);

    if (cute::elect_one_sync()) {
      // load_qk ql: latent Q tile i into Q slot i
      cute::copy(
          mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask),
          tQLgQL(_, _0{}, i),
          tQsQ(_, i));
      // load_qk cl: latent cache tile i of the first K tile into the current stage
      if constexpr (kIsPaged) {
        cute::copy(
            mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
            tCLgCL(_, _0{}, i, mPT(k_index)),
            tKCsKC(_, pipeline_load_qk_producer_state.index())
        );
      }
      else {
        cute::copy(
            mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
            tCLgCL(_, k_index, i, batch_coord),
            tKCsKC(_, pipeline_load_qk_producer_state.index()));
      }
    }
    ++pipeline_load_qk_producer_state;
  }

  for (int i = 0; i < IterationsQKRope; i++) {
    pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ);
    pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
    auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);

    if (cute::elect_one_sync()) {
      // load_qk qr: rope Q tile i goes into the Q slots after the latent ones
      cute::copy(
          mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask),
          tQRgQR(_, _0{}, i),
          tQsQ(_, i + IterationsQKLatent));
      // load_qk kr: rope K tile i of the first K tile into the current stage
      if constexpr (kIsPaged) {
        cute::copy(
            mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
            tKRgKR(_, _0{}, i, mPT(k_index)),
            tKCsKC(_, pipeline_load_qk_producer_state.index())
        );
      }
      else {
        cute::copy(
            mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
            tKRgKR(_, k_index, i, batch_coord),
            tKCsKC(_, pipeline_load_qk_producer_state.index()));
      }
    }
    ++pipeline_load_qk_producer_state;
  }

  k_index += 1;
  k_tile_count -= 1;

  // At this point the first K tile (with Q) has been issued; k_index now names
  // the next K tile, so the V load for the previous tile uses k_index - 1.
  // Steady state: for each remaining K tile, load K (no Q), prefetch the next
  // K tile, then load V for the previous K tile.
  CUTLASS_PRAGMA_NO_UNROLL
  while (k_tile_count > 0) {

    // perform K load
    for (int i = 0; i < IterationsQKLatent; i++) {
      pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
      auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);

      if (cute::elect_one_sync()) {
        // load_qk cl
        if constexpr (kIsPaged) {
          cute::copy(
              mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
              tCLgCL(_, _0{}, i, mPT(k_index)),
              tKCsKC(_, pipeline_load_qk_producer_state.index())
          );
        }
        else {
          cute::copy(
              mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
              tCLgCL(_, k_index, i, batch_coord),
              tKCsKC(_, pipeline_load_qk_producer_state.index()));
        }
      }
      ++pipeline_load_qk_producer_state;
    }

    for (int i = 0; i < IterationsQKRope; i++) {
      pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
      auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);

      if (cute::elect_one_sync()) {
        // load_qk kr
        if constexpr (kIsPaged) {
          cute::copy(
              mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
              tKRgKR(_, _0{}, i, mPT(k_index)),
              tKCsKC(_, pipeline_load_qk_producer_state.index())
          );
        }
        else {
          cute::copy(
              mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
              tKRgKR(_, k_index, i, batch_coord),
              tKCsKC(_, pipeline_load_qk_producer_state.index()));
        }
      }
      ++pipeline_load_qk_producer_state;
    }

    // prefetch next K load to keep busy while we transpose-load from cache
    const int kPrefetchDistance = 1;
    for (int i = 0; i < IterationsQKLatent; i++) {
      if (cute::elect_one_sync()) {
        if constexpr (kIsPaged) {
          // Guarded: the paged path reads mPT(k_index + kPrefetchDistance), so
          // it is only done while another tile remains for this split.
          if (k_tile_count > kPrefetchDistance) {
            cute::prefetch(
                mainloop_params.tma_load_c_latent,
                tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance))
            );
          }
        }
        else {
          // NOTE(review): the non-paged prefetch is not guarded by
          // k_tile_count, unlike the paged branch above -- confirm that
          // prefetching one tile past this split's range is intended.
          cute::prefetch(
              mainloop_params.tma_load_c_latent,
              tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord)
          );
        }
      }
    }

    for (int i = 0; i < IterationsQKRope; i++) {
      if (cute::elect_one_sync()) {
        if constexpr (kIsPaged) {
          if (k_tile_count > kPrefetchDistance) {
            cute::prefetch(
                mainloop_params.tma_load_k_rope,
                tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance))
            );
          }
        }
        else {
          cute::prefetch(
              mainloop_params.tma_load_k_rope,
              tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord)
          );
        }
      }
    }

    // perform V load (k_idx - 1)
    for (int i = 0; i < IterationsPV_K; i++) {
      for (int j = 0; j < IterationsPV_N; j++) {
        pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state);
        auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state);

        if (cute::elect_one_sync()) {
          // load_pv cl
          // note the transpose in indices!
          // note we are off-by-one on k_index
          // The transposed loads pass the EVICT_FIRST cache hint.
          if constexpr (kIsPaged) {
            cute::copy(
                mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
                tCLTgCLT(_, j, i, mPT(k_index - 1)),
                tVCsVC(_, pipeline_load_pv_producer_state.index())
            );
          }
          else {
            cute::copy(
                mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
                tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord),
                tVCsVC(_, pipeline_load_pv_producer_state.index())
            );
          }
        }
        ++pipeline_load_pv_producer_state;
      }
    }

    k_index += 1;
    k_tile_count -= 1;
  }

  // Drain: V load for the last K tile (k_index - 1), which has no following
  // K load to pair with.
  for (int i = 0; i < IterationsPV_K; i++) {
    for (int j = 0; j < IterationsPV_N; j++) {
      pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state);
      auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state);

      if (cute::elect_one_sync()) {
        // load_pv cl
        // note the transpose in indices
        // note we are off-by-one on k_index
        if constexpr (kIsPaged) {
          cute::copy(
              mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
              tCLTgCLT(_, j, i, mPT(k_index - 1)),
              tVCsVC(_, pipeline_load_pv_producer_state.index())
          );
        }
        else {
          cute::copy(
              mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
              tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord),
              tVCsVC(_, pipeline_load_pv_producer_state.index())
          );
        }
      }
      ++pipeline_load_pv_producer_state;
    }
  }
}
// MMA issue routine: consumes the loaded tiles and issues the two GEMMs of the
// attention mainloop for the K tiles assigned to this block coordinate.
//
//   S = Q * KC   (TiledMmaQK)  -> accumulator at TmemAllocation::kS0 or kS1,
//                                 chosen by the pipeline_mma_s stage index
//   O += P * VC  (TiledMmaPV)  -> accumulators at kO0 + j * kSizeAccO
//
// The issue order is S(0), then for every further K tile S(k) followed by
// O(k-1), and finally O(last); see the schedule comment in the body.
//
// Parameters:
//   blk_coord        - block coordinate; mode 3 is the split index along K.
//   problem_shape    - unpacked as (H, K, D, B); only K is used here.
//   shared_tensors   - shared-memory storage holding smem_q / smem_kc /
//                      smem_vc / smem_p.
//   pipeline_load_qk / ..._consumer_state - Q/K stages produced by the loader.
//   pipeline_load_pv / ..._consumer_state - V stages produced by the loader.
//   pipeline_mma_s   / ..._producer_state - hands finished S accumulators on.
//   pipeline_p_mma   / ..._consumer_state - waits for P to be available in smem_p.
//   pipeline_mma_o   / ..._producer_state - hands finished O accumulators on.
//   split_kv         - number of splits along K.
//
// All pipeline states are caller-owned references and are advanced here.
// Returns early without touching any pipeline if this split has no K tiles.
template <class BlkCoord>
CUTLASS_DEVICE void mma(
    BlkCoord const& blk_coord,
    ProblemShape const& problem_shape,
    TensorStorage& shared_tensors,
    PipelineLoadQK& pipeline_load_qk,
    typename PipelineLoadQK::PipelineState& pipeline_load_qk_consumer_state,
    PipelineLoadPV& pipeline_load_pv,
    typename PipelineLoadPV::PipelineState& pipeline_load_pv_consumer_state,
    PipelineS& pipeline_mma_s,
    typename PipelineS::PipelineState& pipeline_mma_s_producer_state,
    PipelineP& pipeline_p_mma,
    typename PipelineP::PipelineState& pipeline_p_mma_consumer_state,
    PipelineO& pipeline_mma_o,
    typename PipelineO::PipelineState& pipeline_mma_o_producer_state,
    int const& split_kv) {

  auto [H, K, D, B] = problem_shape;

  // Same K split computation as in the load routines: k_index is the lower
  // limit of this split and is only needed to derive k_tile_count.
  int k_tile_total = ceil_div(K, TileShapeS{});
  int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
  int k_index = get<3>(blk_coord) * k_tile_per_cta;
  int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
  if (k_tile_count == 0) {
    return;
  }

  // mma init
  Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
  Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{});
  Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{});
  // smem_p is viewed as Element for use as the A operand of the PV MMA.
  Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{});

  Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ);
  Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC);
  Tensor tOrP = TiledMmaPV::make_fragment_A(sP);
  Tensor tOrVC = TiledMmaPV::make_fragment_B(sVC);

  TiledMmaQK tiled_mma_qk;
  TiledMmaPV tiled_mma_pv;

  Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{}));
  Tensor tItI = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{}));

  // The very first PV GEMM must overwrite O. The flag flips to One after the
  // first gemm call and is never reset to Zero again in this function, so all
  // later K tiles accumulate into O.
  tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero;

  pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state);

  // Issue schedule (NOTE(review): the column alignment of the ownership rows
  // below was lost in this copy of the file; verify against upstream):
  // Mma S0 S1 O0 S2 O1 ... Sn On-1 On
  // S0 ownership -- ----- -- --
  // S1 ownership -- ----- ----
  // O ownership -- -- ---- --

  // ---- S for the first K tile ----
  // Each new S tile starts from zero; after the first gemm the flag is One so
  // the remaining k-blocks and QK iterations accumulate.
  tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;
  for (int i = 0; i < IterationsQK; i++) {
    pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state);
    int read_stage = pipeline_load_qk_consumer_state.index();

    // Point the accumulator at the S buffer that matches the current S stage.
    tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1);

    CUTLASS_PRAGMA_UNROLL
    for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) {
      // Q is indexed by the QK iteration i, KC by the pipeline stage.
      cute::gemm(tiled_mma_qk,
                 tSrQ(_,_,k_block,i),
                 tSrKC(_,_,k_block,read_stage),
                 tStS);
      tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;
    }

    pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state);
    ++pipeline_load_qk_consumer_state;
  }

  pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state);
  ++pipeline_mma_s_producer_state;

  k_tile_count -= 1;

  // ---- steady state: S for tile k, then O for tile k-1 ----
  CUTLASS_PRAGMA_NO_UNROLL
  while (k_tile_count > 0) {

    pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state);
    tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;
    for (int i = 0; i < IterationsQK; i++) {
      pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state);
      int read_stage = pipeline_load_qk_consumer_state.index();

      tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1);

      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) {
        cute::gemm(tiled_mma_qk,
                   tSrQ(_,_,k_block,i),
                   tSrKC(_,_,k_block,read_stage),
                   tStS);
        tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;
      }

      pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state);
      ++pipeline_load_qk_consumer_state;
    }

    pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state);
    ++pipeline_mma_s_producer_state;

    // O for the previous tile: needs the O accumulator and the P tile.
    pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state);
    pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state);

    for (int i = 0; i < IterationsPV_K; i++) {
      // Every j below writes a different O accumulator, so each one must start
      // this i-step with the same accumulate flag; it is captured here and
      // restored before each j.
      auto acc_flag = tiled_mma_pv.accumulate_;
      for (int j = 0; j < IterationsPV_N; j++) {
        pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state);
        int read_stage = pipeline_load_pv_consumer_state.index();

        // Select the j-th O accumulator.
        tItI.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO);
        tiled_mma_pv.accumulate_ = acc_flag;

        CUTLASS_PRAGMA_UNROLL
        for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) {
          // P is indexed by (PV k-iteration, P pipeline stage).
          cute::gemm(tiled_mma_pv,
                     tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())),
                     tOrVC(_,_,k_block,read_stage),
                     tItI);
          tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One;
        }

        pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state);
        ++pipeline_load_pv_consumer_state;
      }
    }

    pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state);
    ++pipeline_p_mma_consumer_state;
    pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state);
    ++pipeline_mma_o_producer_state;

    --k_tile_count;
  }

  // ---- drain: O for the last K tile ----
  pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state);
  pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state);

  for (int i = 0; i < IterationsPV_K; i++) {
    auto acc_flag = tiled_mma_pv.accumulate_;
    for (int j = 0; j < IterationsPV_N; j++) {
      pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state);
      int read_stage = pipeline_load_pv_consumer_state.index();

      tItI.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO);
      tiled_mma_pv.accumulate_ = acc_flag;

      CUTLASS_PRAGMA_UNROLL
      for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) {
        cute::gemm(tiled_mma_pv,
                   tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())),
                   tOrVC(_,_,k_block,read_stage),
                   tItI);
        tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One;
      }

      pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state);
      ++pipeline_load_pv_consumer_state;
    }
  }

  pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state);
  ++pipeline_p_mma_consumer_state;
  pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state);
  ++pipeline_mma_o_producer_state;
}
// One online-softmax step for a single S tile.
//
// Reads this thread's part of the S accumulator at `tmem_s`, optionally masks
// columns past the sequence end, folds the tile into the running maximum,
// and (unless B2B is defined) produces the correction factor and the
// exponentiated values. The values are then converted to Element, written to
// slot `smem_p_index` of smem_p, and their sum is folded into `row_sum`.
//
// In/out:
//   row_max            - running maximum; updated only when B2B is not defined.
//   row_sum            - running sum; scaled by correction_factor, then the
//                        sum of this tile's values is added.
//   correction_factor  - written only when B2B is not defined; otherwise the
//                        incoming value is used unchanged to scale row_sum.
// In:
//   is_last_tile       - when true, entries whose column index plus
//                        TileShapeS * k_index is >= get<1>(problem_shape) are
//                        set to -infinity before the max is taken.
//   mainloop_args      - supplies softmax_scale.
//   shared_tensors     - provides smem_exchange (cross-thread max) and smem_p.
//   k_index            - index of the K tile this S tile belongs to.
template <class IsLastTile>
CUTLASS_DEVICE void softmax(
    IsLastTile const& is_last_tile,
    ElementAcc& row_max,
    ElementAcc& row_sum,
    ElementAcc& correction_factor,
    ProblemShape const& problem_shape,
    MainloopArguments const& mainloop_args,
    TensorStorage& shared_tensors,
    int k_index,
    uint32_t tmem_s,
    int smem_p_index) {

  // Width of the packed Element/ElementAcc arrays used for conversion.
  constexpr int kConvertWidth = 4;

  auto tmem_ld_atom = cute::SM100_TMEM_LOAD_32dp32b32x{};

  // Accumulator view of S, rebased onto the address handed in by the caller.
  TiledMmaQK qk_mma;
  Tensor s_frag = partition_fragment_C(qk_mma, select<0,1>(TileShapeQK{}));
  s_frag.data() = tmem_s;

  CUTE_STATIC_ASSERT_V(shape<1>(s_frag) == _1{});
  CUTE_STATIC_ASSERT_V(shape<2>(s_frag) == _1{});
  Tensor s_acc = s_frag(make_coord(_,_),_0{},_0{});

  // Identity tensor used to recover each register's coordinate in the tile.
  Tensor s_coords = make_identity_tensor(take<0,2>(CtaShapeQK{}));

  auto t2r_copy = make_tmem_copy(tmem_ld_atom, s_acc);
  auto slice_idx = threadIdx.x % size(t2r_copy);
  auto my_t2r = t2r_copy.get_slice(slice_idx);

  Tensor my_coords = my_t2r.partition_D(s_coords);
  Tensor my_tmem_src = my_t2r.partition_S(s_acc);

  // Register staging: accumulator-precision values and their Element copies,
  // plus packed views of both for the array converter.
  Tensor acc_regs = make_tensor<ElementAcc>(shape(my_coords));
  Tensor p_regs = make_tensor<Element>(shape(acc_regs));
  Tensor acc_packs = recast<Array<ElementAcc, kConvertWidth>>(acc_regs);
  Tensor p_packs = recast<Array<Element, kConvertWidth>>(p_regs);

  // Bring S into registers.
  copy(t2r_copy, my_tmem_src, acc_regs);

  // Tail masking: only needed for the last tile of the sequence.
  if (is_last_tile) {
    const auto tile_col_base = TileShapeS{} * k_index;
    const auto kv_extent = get<1>(problem_shape);
    for (int e = 0; e < size(acc_regs); ++e) {
      if (get<1>(my_coords(e)) + tile_col_base >= kv_extent) {
        acc_regs(e) = -std::numeric_limits<ElementAcc>::infinity();
      }
    }
  }

  // Running maximum over this thread's values.
  ElementAcc running_max = row_max;
  CUTLASS_PRAGMA_UNROLL
  for (int e = 0; e < size(acc_regs); ++e) {
    running_max = ::fmax(running_max, acc_regs(e));
  }

  // With more than one warp along N, combine the max with the partner thread
  // through shared memory (threads are paired 64 apart within 128).
  if constexpr (kWarpsInN > 1) {
    shared_tensors.smem_exchange[threadIdx.x] = running_max;
    cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync();
    int partner = (threadIdx.x + 64) % 128;
    running_max = cutlass::max(running_max, shared_tensors.smem_exchange[partner]);
  }

#ifndef B2B
  // Work in base 2: fold log2(e) into the softmax scale once.
  ElementAcc scale_log2 = mainloop_args.softmax_scale * static_cast<ElementAcc>(M_LOG2E);

  // Factor that rescales previously accumulated results to the new maximum.
  correction_factor = ::exp2f(scale_log2 * (row_max - running_max));
  row_max = running_max;

  // Exponentiate relative to the (new) maximum.
  ElementAcc max_times_scale = row_max * scale_log2;
  CUTLASS_PRAGMA_UNROLL
  for (int e = 0; e < size(acc_regs); ++e) {
    acc_regs(e) = ::exp2f(scale_log2 * acc_regs(e) - max_times_scale);
  }
#endif

  // Convert to Element, kConvertWidth values at a time.
  cutlass::NumericArrayConverter<Element, ElementAcc, kConvertWidth> to_element;
  CUTLASS_PRAGMA_UNROLL
  for (int p = 0; p < size(acc_packs); ++p) {
    p_packs(p) = to_element(acc_packs(p));
  }

  // Destination: slot smem_p_index of the P buffer.
  Tensor p_smem = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})(_, _, _, make_coord(_, smem_p_index));

  // Coordinate partition as seen by the PV MMA's A operand. It is not
  // referenced again in this function.
  Tensor mma_a_coords = TiledMmaPV{}.get_slice(_0{}).partition_A(s_coords);

  // (thread, value) -> position map chosen to match the MMA's view of P.
  auto thread_value_map = make_ordered_layout(
      make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})),
      make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{})));

  auto p_smem_pi = as_position_independent_swizzle_tensor(p_smem);
  copy_aligned(p_regs, p_smem_pi.compose(thread_value_map)(threadIdx.x, _));

  // Bring the old sum to the new maximum before adding this tile.
  row_sum *= correction_factor;

  // Sum this thread's values using float2 adds into four partial sums,
  // followed by a pairwise reduction of the partials.
  static_assert(cute::is_same_v<ElementAcc, float>);
  auto acc_pairs = recast<float2>(acc_regs);
  auto partial = make_tensor<float2>(_4{});
  static_assert(size(acc_pairs) % size(partial) == 0);

  CUTLASS_PRAGMA_UNROLL
  for (int s = 0; s < size(partial); ++s) {
    partial(s) = acc_pairs(s);
  }

  CUTLASS_PRAGMA_UNROLL
  for (int base = size(partial); base < size(acc_pairs); base += size(partial)) {
    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < size(partial); ++s) {
      cute::add(partial(s), partial(s), acc_pairs(base + s));
    }
  }

  CUTLASS_PRAGMA_UNROLL
  for (int stride = 1; stride < size(partial); stride *= 2) {
    CUTLASS_PRAGMA_UNROLL
    for (int s = 0; s < size(partial); s += 2*stride) {
      cute::add(partial(s), partial(s), partial(s + stride));
    }
  }

  row_sum += partial(0).x + partial(0).y;
}
// Multiplies the O accumulator located at `tmem_o` by `correction_factor`,
// in place: the values are read into registers, scaled two at a time, and
// written back to the same location.
//
// When B2B is defined this function does nothing.
CUTLASS_DEVICE void rescale(
    ElementAcc correction_factor,
    uint32_t tmem_o) {

#ifndef B2B
  // Load atom and the store atom derived from it.
  auto tmem_ld_atom = cute::SM100_TMEM_LOAD_32dp32b32x{};
  auto tmem_st_atom = TMEM::tmem_load_to_store(tmem_ld_atom);

  // Accumulator view of O, rebased onto the address handed in by the caller.
  TiledMmaPV pv_mma;
  Tensor o_frag = partition_fragment_C(pv_mma, select<0,1>(TileShapePV{}));
  o_frag.data() = tmem_o;

  CUTE_STATIC_ASSERT_V(shape<1>(o_frag) == _1{});
  CUTE_STATIC_ASSERT_V(shape<2>(o_frag) == _1{});
  Tensor o_acc = o_frag(make_coord(_,_),_0{},_0{});

  // A tensor over a null pointer with zero strides: it is only partitioned to
  // obtain the per-thread register shape and is never dereferenced here.
  auto pv_tile_mn = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{});
  Tensor shape_proxy = make_tensor(make_gmem_ptr(static_cast<ElementAcc*>(nullptr)), pv_tile_mn, make_stride(0, 0));

  auto t2r_copy = make_tmem_copy(tmem_ld_atom, o_acc);
  auto r2t_copy = make_tmem_copy(tmem_st_atom, o_acc);
  auto slice_idx = threadIdx.x % size(t2r_copy);

  auto my_t2r = t2r_copy.get_slice(slice_idx);
  // Store-side slice; it is not referenced again in this function.
  auto my_r2t = r2t_copy.get_slice(slice_idx);

  Tensor my_proxy = my_t2r.partition_D(shape_proxy);
  Tensor o_regs = make_tensor<ElementAcc>(shape(my_proxy));
  Tensor my_tmem = my_t2r.partition_S(o_acc);

  // Read O into registers.
  copy(t2r_copy, my_tmem, o_regs);

  // Scale pairs of values with a packed float2 multiply.
  float2 scale2 = make_float2(correction_factor, correction_factor);

  CUTLASS_PRAGMA_UNROLL
  for (int e = 0; e < size(o_regs); e += 2) {
    float2 pair_in = make_float2(o_regs(e), o_regs(e + 1));
    float2 pair_out;
    cute::mul(pair_out, pair_in, scale2);
    o_regs(e) = pair_out.x;
    o_regs(e + 1) = pair_out.y;
  }

  // Write the scaled values back to where they came from.
  copy(r2t_copy, o_regs, my_tmem);
#endif
}
template
<
class
BlkCoord
>
CUTLASS_DEVICE
void
epilogue
(
ElementAcc
&
row_max
,
ElementAcc
&
row_sum
,
BlkCoord
const
&
cta_coord
,
ProblemShape
const
&
problem_shape
,
MainloopArguments
const
&
mainloop_args
,
EpilogueParams
const
&
epilogue_args
,
TensorStorage
&
shared_tensors
,
uint32_t
tmem_o
,
int
const
&
split_kv
)
{
auto
load_op
=
cute
::
SM100_TMEM_LOAD_32dp32b32x
{};
TiledMmaPV
tiled_mma_pv
;
Tensor
tItI
=
TiledMmaPV
::
make_fragment_C
(
partition_shape_C
(
TiledMmaPV
{},
take
<
0
,
2
>
(
TileShapePV
{})));
tItI
.
data
()
=
tmem_o
;
CUTE_STATIC_ASSERT_V
(
shape
<
1
>
(
tItI
)
==
_1
{});
CUTE_STATIC_ASSERT_V
(
shape
<
2
>
(
tItI
)
==
_1
{});
Tensor
tAcc
=
tItI
(
make_coord
(
_
,
_
),
_0
{},
_0
{});
auto
[
H
,
K
,
D
,
B
]
=
problem_shape
;
auto
[
D_latent
,
D_rope
]
=
D
;
if
(
epilogue_args
.
ptr_o_acc
!=
nullptr
)
{
using
ElementOutAcc
=
ElementAcc
;
constexpr
auto
AlignmentOutAcc
=
128
/
cute
::
sizeof_bits_v
<
ElementOutAcc
>
;
Tensor
mO
=
make_tensor
(
make_gmem_ptr
(
epilogue_args
.
ptr_o_acc
+
get
<
3
>
(
cta_coord
)
*
D_latent
),
make_shape
(
H
,
D_latent
,
B
),
epilogue_args
.
stride_o_acc
);
auto
cta_tiler_pv
=
take
<
0
,
2
>
(
typename
CollectiveMmaPV
::
CtaShape_MNK
{});
Tensor
gO
=
local_tile
(
mO
,
cta_tiler_pv
,
take
<
0
,
3
>
(
cta_coord
));
auto
tiled_t2r
=
make_tmem_copy
(
load_op
,
tAcc
);
auto
thread_idx
=
threadIdx
.
x
%
size
(
tiled_t2r
);
auto
thread_t2r
=
tiled_t2r
.
get_slice
(
thread_idx
);
Tensor
tTR_gO
=
thread_t2r
.
partition_D
(
gO
);
Tensor
tTR_rAcc
=
make_tensor
<
ElementAcc
>
(
shape
(
tTR_gO
));
Tensor
tTR_rO_frag
=
make_tensor
<
ElementOutAcc
>
(
shape
(
tTR_rAcc
));
Tensor
tTR_rO_src
=
recast
<
Array
<
ElementOutAcc
,
AlignmentOutAcc
>>
(
coalesce
(
tTR_rO_frag
));
Tensor
tR2G_rO_dst
=
recast
<
Array
<
ElementOutAcc
,
AlignmentOutAcc
>>
(
coalesce
(
tTR_gO
));
Tensor
tTR_tAcc
=
thread_t2r
.
partition_S
(
tAcc
);
copy
(
tiled_t2r
,
tTR_tAcc
,
tTR_rAcc
);
cutlass
::
epilogue
::
thread
::
LinearCombination
<
ElementOutAcc
,
1
,
ElementAcc
,
ElementAcc
,
cutlass
::
epilogue
::
thread
::
ScaleType
::
OnlyAlphaScaling
>
epilogue_op
({
epilogue_args
.
output_scale
/
row_sum
});
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
tTR_rAcc
);
i
++
)
{
tTR_rO_frag
(
i
)
=
epilogue_op
(
tTR_rAcc
(
i
));
}
copy
(
tTR_rO_src
,
tR2G_rO_dst
);
#ifndef B2B
// compute LSE
ElementAcc
lse
=
cutlass
::
fast_log
(
row_sum
)
+
mainloop_args
.
softmax_scale
*
row_max
;
// store LSE
Tensor
mLSE
=
make_tensor
(
make_gmem_ptr
(
epilogue_args
.
ptr_lse_acc
+
H
*
get
<
3
>
(
cta_coord
)),
make_shape
(
H
,
B
),
epilogue_args
.
stride_lse_acc
);
Tensor
gLSE
=
local_tile
(
mLSE
,
append
<
3
>
(
cta_tiler_pv
,
_1
{}),
take
<
0
,
3
>
(
cta_coord
),
Step
<
_1
,
Underscore
,
_1
>
{});
// for 2x2 dp, this must be conditional and the index is wrong
if
(
!
kIs2Sm
||
(
threadIdx
.
x
<
64
))
{
gLSE
(
threadIdx
.
x
)
=
lse
;
}
#endif
}
else
{
Tensor
mO
=
make_tensor
(
make_gmem_ptr
(
epilogue_args
.
ptr_o
),
make_shape
(
H
,
D_latent
,
B
),
epilogue_args
.
stride_o
);
auto
cta_tiler_pv
=
take
<
0
,
2
>
(
typename
CollectiveMmaPV
::
CtaShape_MNK
{});
Tensor
gO
=
local_tile
(
mO
,
cta_tiler_pv
,
take
<
0
,
3
>
(
cta_coord
));
auto
tiled_t2r
=
make_tmem_copy
(
load_op
,
tAcc
);
auto
thread_idx
=
threadIdx
.
x
%
size
(
tiled_t2r
);
auto
thread_t2r
=
tiled_t2r
.
get_slice
(
thread_idx
);
Tensor
tTR_gO
=
thread_t2r
.
partition_D
(
gO
);
Tensor
tTR_rAcc
=
make_tensor
<
ElementAcc
>
(
shape
(
tTR_gO
));
Tensor
tTR_rO_frag
=
make_tensor
<
ElementOut
>
(
shape
(
tTR_rAcc
));
Tensor
tTR_rO_src
=
recast
<
Array
<
ElementOut
,
AlignmentOut
>>
(
coalesce
(
tTR_rO_frag
));
Tensor
tR2G_rO_dst
=
recast
<
Array
<
ElementOut
,
AlignmentOut
>>
(
coalesce
(
tTR_gO
));
Tensor
tTR_tAcc
=
thread_t2r
.
partition_S
(
tAcc
);
copy
(
tiled_t2r
,
tTR_tAcc
,
tTR_rAcc
);
cutlass
::
epilogue
::
thread
::
LinearCombination
<
ElementOut
,
1
,
ElementAcc
,
ElementAcc
,
cutlass
::
epilogue
::
thread
::
ScaleType
::
OnlyAlphaScaling
>
epilogue_op
({
epilogue_args
.
output_scale
/
row_sum
});
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
tTR_rAcc
);
i
++
)
{
tTR_rO_frag
(
i
)
=
epilogue_op
(
tTR_rAcc
(
i
));
}
copy
(
tTR_rO_src
,
tR2G_rO_dst
);
#ifndef B2B
if
(
epilogue_args
.
ptr_lse
!=
nullptr
)
{
// compute LSE
ElementAcc
lse
=
cutlass
::
fast_log
(
row_sum
)
+
mainloop_args
.
softmax_scale
*
row_max
;
// store LSE
Tensor
mLSE
=
make_tensor
(
make_gmem_ptr
(
epilogue_args
.
ptr_lse
),
make_shape
(
H
,
B
),
epilogue_args
.
stride_lse
);
Tensor
gLSE
=
local_tile
(
mLSE
,
append
<
3
>
(
cta_tiler_pv
,
_1
{}),
take
<
0
,
3
>
(
cta_coord
),
Step
<
_1
,
Underscore
,
_1
>
{});
// for 2x2 dp, this must be conditional and the index is wrong
if
(
!
kIs2Sm
||
(
threadIdx
.
x
<
64
))
{
gLSE
(
threadIdx
.
x
)
=
lse
;
}
}
#endif
}
}
// Compute-warp driver for this CTA's share of the K tiles.
//
// For every K tile owned by this CTA it calls softmax() on the current S
// buffer, hands the result to the P->MMA pipeline, and (from the second tile
// on) rescales the O accumulators by `correction_factor`. After the last tile
// it calls epilogue() once per PV N-iteration.
//
// The three *_state arguments are taken by reference and advanced in here, so
// the caller sees the updated pipeline states afterwards.
// Mode 3 of `cta_coord` selects which of the `split_kv` slices this CTA owns.
template <class CtaCoord>
CUTLASS_DEVICE void compute(
    CtaCoord const& cta_coord,
    ProblemShape const& problem_shape,
    MainloopArguments const& mainloop_args,
    EpilogueParams const& epilogue_args,
    TensorStorage& shared_tensors,
    PipelineS& pipeline_mma_s,
    typename PipelineS::PipelineState& pipeline_mma_s_consumer_state,
    PipelineP& pipeline_p_mma,
    typename PipelineP::PipelineState& pipeline_p_mma_producer_state,
    PipelineO& pipeline_mma_o,
    typename PipelineO::PipelineState& pipeline_mma_o_consumer_state,
    int const& split_kv) {
  auto [H, K, D, B] = problem_shape;

  // Split the K tiles evenly (rounding up) across the split_kv CTAs.
  int k_tile_total = ceil_div(K, TileShapeS{});
  int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
  int k_index = get<3>(cta_coord) * k_tile_per_cta;  // lower limit
  // Number of tiles actually left for this CTA; clamped to 0 when the slice
  // starts at or beyond the last tile.
  int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
  if (k_tile_count == 0) {
    // if we return early, we have to make sure we release the load warp
    cutlass::arch::NamedBarrier(
        (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
        kNamedBarrierEpilogue).arrive();
    return;
  }
  // Index of the globally last K tile; compared against k_index to pick the
  // is_last_tile specialisation of softmax().
  int k_index_final = k_tile_total - 1;

  // Running softmax state, passed by name into softmax()/rescale()/epilogue().
  ElementAcc row_max = -std::numeric_limits<ElementAcc>::infinity();
  ElementAcc row_sum = 0;
  ElementAcc correction_factor = 1;

  pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state);
  pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state);

  // Turns a runtime bool into a cute::true_type / cute::false_type argument so
  // the callee can branch at compile time.
  auto dispatch_bool = [](bool b, auto fn) {
    if (b) {
      fn(cute::true_type{});
    }
    else {
      fn(cute::false_type{});
    }
  };

  // softmax s0 -> p0
  dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) {
    softmax(
        is_last_tile,
        row_max, row_sum, correction_factor,
        problem_shape, mainloop_args, shared_tensors, k_index,
        // The S buffer is chosen from the consumer state's stage index.
        uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1),
        pipeline_p_mma_producer_state.index()
    );
  });

  k_index += 1;

  cutlass::arch::fence_view_async_tmem_load();
  cutlass::arch::fence_view_async_shared();
  pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state);
  ++pipeline_mma_s_consumer_state;
  pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state);
  ++pipeline_p_mma_producer_state;

  k_tile_count -= 1;

  // Remaining tiles: same softmax hand-off as above, followed by a rescale of
  // every O accumulator slice.
  CUTLASS_PRAGMA_NO_UNROLL
  while (k_tile_count > 0) {
    pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state);
    pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state);

    // softmax s1 -> p1
    dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) {
      softmax(
          is_last_tile,
          row_max, row_sum, correction_factor,
          problem_shape, mainloop_args, shared_tensors, k_index,
          uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1),
          pipeline_p_mma_producer_state.index()
      );
    });

    cutlass::arch::fence_view_async_tmem_load();
    cutlass::arch::fence_view_async_shared();
    pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state);
    ++pipeline_mma_s_consumer_state;
    pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state);
    ++pipeline_p_mma_producer_state;

    pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state);

    // rescale
    CUTLASS_PRAGMA_UNROLL
    for (int j = 0; j < IterationsPV_N; j++) {
      rescale(correction_factor, uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO));
    }

    cutlass::arch::fence_view_async_tmem_store();
    pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state);
    ++pipeline_mma_o_consumer_state;

    --k_tile_count;
    k_index += 1;
  }

  pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state);

#ifdef B2B
  row_sum = 1;
#else
  if constexpr (kWarpsInN > 1) {
    // reduce row_sum if needed (for 2x2 dp)
    // Each thread publishes its partial sum, waits on the exchange barrier,
    // then adds the value written by the thread 64 slots away (mod 128).
    shared_tensors.smem_exchange[threadIdx.x] = row_sum;
    cutlass::arch::NamedBarrier(kNumComputeWarps * NumThreadsPerWarp, kNamedBarrierExchange).sync();
    // (64, 2) shape
    int peer_index = (threadIdx.x + 64) % 128;
    row_sum += shared_tensors.smem_exchange[peer_index];
  }
#endif

  // Same barrier the early-return path arrives on above.
  cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive();

  // epilogue
  CUTLASS_PRAGMA_UNROLL
  for (int j = 0; j < IterationsPV_N; j++) {
    epilogue(
        row_max, row_sum,
        // Mode 1 of the coordinate is replaced by the N-iteration index.
        replace<1>(cta_coord, j), problem_shape,
        mainloop_args, epilogue_args, shared_tensors,
        uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), split_kv
    );
  }

  cutlass::arch::fence_view_async_tmem_load();
  pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state);
  ++pipeline_mma_o_consumer_state;
}
};
///////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass::fmha::kernel
csrc/attention/mla/cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp
0 → 100644
View file @
711aa9d5
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace
cutlass
::
fmha
::
kernel
{
////////////////////////////////////////////////////////////////////////////////
struct
Sm100MlaIndividualTileScheduler
{
struct
Params
{
dim3
grid
;
};
bool
valid_
=
true
;
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler
(
Params
const
&
)
{}
template
<
class
ProblemShape
,
class
ClusterShape
>
static
Params
to_underlying_arguments
(
ProblemShape
const
&
problem_shape
,
KernelHardwareInfo
hw_info
,
ClusterShape
const
&
cluster_shape
,
int
const
&
split_kv
)
{
using
namespace
cute
;
dim3
grid
(
get
<
0
>
(
cluster_shape
),
get
<
3
>
(
problem_shape
)
/* Batch */
,
split_kv
/*Maximum Split KV*/
);
return
Params
{
grid
};
}
static
dim3
get_grid_shape
(
Params
const
&
params
)
{
return
params
.
grid
;
}
CUTLASS_DEVICE
bool
is_valid
()
{
return
valid_
;
}
CUTLASS_DEVICE
auto
get_block_coord
()
{
using
namespace
cute
;
return
make_coord
(
blockIdx
.
x
,
_0
{},
blockIdx
.
y
,
blockIdx
.
z
);
}
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler
&
operator
++
()
{
valid_
=
false
;
return
*
this
;
}
};
////////////////////////////////////////////////////////////////////////////////
struct
Sm100MlaPersistentTileScheduler
{
struct
Params
{
int
num_blocks
;
FastDivmod
divmod_m_block
;
FastDivmod
divmod_b
;
FastDivmod
divmod_split_kv
;
KernelHardwareInfo
hw_info
;
};
int
block_idx
=
0
;
Params
params
;
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler
(
Params
const
&
params
)
:
block_idx
(
blockIdx
.
x
),
params
(
params
)
{}
template
<
class
ProblemShape
,
class
ClusterShape
>
static
Params
to_underlying_arguments
(
ProblemShape
const
&
problem_shape
,
KernelHardwareInfo
hw_info
,
ClusterShape
const
&
cluster_shape
,
int
const
&
split_kv
)
{
using
namespace
cute
;
// Get SM count if needed, otherwise use user supplied SM count
int
sm_count
=
hw_info
.
sm_count
;
if
(
sm_count
<=
1
||
sm_count
%
size
<
0
>
(
cluster_shape
)
!=
0
)
{
CUTLASS_TRACE_HOST
(
" WARNING: Arguments do not include a valid SM count.
\n
"
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."
);
sm_count
=
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
}
CUTLASS_TRACE_HOST
(
"to_underlying_arguments(): Setting persistent grid SM count to "
<<
sm_count
);
hw_info
.
sm_count
=
sm_count
;
int
num_m_blocks
=
size
<
0
>
(
cluster_shape
);
int
num_blocks
=
num_m_blocks
*
get
<
3
>
(
problem_shape
)
/* Batch */
;
num_blocks
*=
split_kv
;
/* Maximum Split KV*/
return
Params
{
num_blocks
,
{
num_m_blocks
},
{
get
<
3
>
(
problem_shape
)
},
{
split_kv
},
hw_info
};
}
static
dim3
get_grid_shape
(
Params
const
&
params
)
{
dim3
grid
(
std
::
min
(
params
.
num_blocks
,
params
.
hw_info
.
sm_count
),
1
,
1
);
return
grid
;
}
CUTLASS_DEVICE
bool
is_valid
()
{
return
block_idx
<
params
.
num_blocks
;
}
CUTLASS_DEVICE
auto
get_block_coord
()
{
using
namespace
cute
;
int
block_decode
=
block_idx
;
int
m_block
,
bidb
,
n_split_kv
;
params
.
divmod_m_block
(
block_decode
,
m_block
,
block_decode
);
params
.
divmod_b
(
block_decode
,
bidb
,
block_decode
);
params
.
divmod_split_kv
(
block_decode
,
n_split_kv
,
block_decode
);
return
make_coord
(
m_block
,
_0
{},
bidb
,
n_split_kv
);
}
CUTLASS_DEVICE
Sm100MlaPersistentTileScheduler
&
operator
++
()
{
block_idx
+=
gridDim
.
x
;
return
*
this
;
}
};
////////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass::fmha::kernel
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
0 → 100644
View file @
711aa9d5
/*
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
#include "core/registration.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/cutlass.h>
#include <cutlass/kernel_hardware_info.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <iostream>
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
// clang-format off
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
void
sm100_cutlass_mla_decode
(
torch
::
Tensor
const
&
out
,
torch
::
Tensor
const
&
q_nope
,
torch
::
Tensor
const
&
q_pe
,
torch
::
Tensor
const
&
kv_c_and_k_pe_cache
,
torch
::
Tensor
const
&
seq_lens
,
torch
::
Tensor
const
&
page_table
,
torch
::
Tensor
const
&
workspace
,
int64_t
num_kv_splits
)
{
TORCH_CHECK
(
false
,
"CUDA version must be >= 12.4 for cutlass_mla_decode"
);
}
// Fallback compiled when CUDA_VERSION is undefined or older than 12.4: always
// fails with an explanatory error.
int64_t sm100_cutlass_mla_get_workspace_size(
    int64_t max_seq_len,
    int64_t num_batches,
    int64_t sm_count,
    int64_t num_kv_splits) {
  TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
  // Not reached; present so the non-void function has a return on every path.
  return 0;
}
#else
// Evaluates `status` once and raises a TORCH_CHECK failure, using the CUTLASS
// status string as the message, unless it equals cutlass::Status::kSuccess.
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
using
namespace
cute
;
using
namespace
cutlass
::
fmha
::
kernel
;
// Tag type that carries a compile-time boolean; MlaSm100 reads `value` to pick
// between the persistent and the individual tile scheduler.
template <bool kFlag>
struct IsPersistent {
  static constexpr bool value = kFlag;
};
// Bundle of type aliases that configure the SM100 MLA kernel.
//
//   T                 - element type used for both inputs and the output.
//   IsPaged128        - when false the kernel is instantiated with
//                       kIsCpAsync = true (see FmhaKernel below).
//   PersistenceOption - tag with a boolean `value`; true selects the
//                       persistent tile scheduler.
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true> >
struct MlaSm100 {
  using Element = T;
  using ElementAcc = float;
  using ElementOut = T;

  // Tile extents; modes 0 and 2 are re-exported as TileShapeH / TileShapeD.
  using TileShape = Shape<_128, _128, Shape<_512, _64>>;
  using TileShapeH = cute::tuple_element_t<0, TileShape>;
  using TileShapeD = cute::tuple_element_t<2, TileShape>;

  // H K (D_latent D_rope) B
  // H and D are the static tile types above; K and B are runtime ints.
  using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;

  using StrideQ = cute::tuple<int64_t, _1, int64_t>;  // H D B
  using StrideK = cute::tuple<int64_t, _1, int64_t>;  // K D B
  using StrideO = StrideK;                            // H D B
  using StrideLSE = cute::tuple<_1, int>;             // H B

  using TileScheduler =
      std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;

  // ElementAcc is passed twice: once as the accumulator type and once more
  // NOTE(review): presumably as the LSE element type - confirm against the
  // kernel's template parameter list.
  using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
      TileShape,
      Element,
      ElementAcc,
      ElementOut,
      ElementAcc,
      TileScheduler,
      /*kIsCpAsync=*/!IsPaged128>;
  using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};
// Translates the torch tensors of one decode call into the Arguments struct of
// T::Fmha (T is an MlaSm100 instantiation).
//
// Shapes are read as follows: batch = q_nope.size(0), pages per sequence =
// page_table.size(1), total pages = kv_c_and_k_pe_cache.size(0), page size =
// kv_c_and_k_pe_cache.size(1). The K extent of the problem is the maximum
// sequence length the page table can address (page_size * pages per sequence).
//
// NOTE(review): seq_lens and page_table are reinterpreted as int* without a
// dtype check - presumably both are int32 tensors; confirm at the call site.
template <typename T>
typename T::Fmha::Arguments args_from_options(
    at::Tensor const& out,
    at::Tensor const& q_nope,
    at::Tensor const& q_pe,
    at::Tensor const& kv_c_and_k_pe_cache,
    at::Tensor const& seq_lens,
    at::Tensor const& page_table,
    double sm_scale,
    int64_t num_kv_splits) {
  // The SM count is always queried for the device that holds q_nope.
  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = q_nope.device().index();
  hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);

  // sizes() yields int64_t; the values are narrowed to int here.
  int batches = q_nope.sizes()[0];
  int page_count_per_seq = page_table.sizes()[1];
  int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
  int page_size = kv_c_and_k_pe_cache.sizes()[1];
  int max_seq_len = page_size * page_count_per_seq;
  using TileShapeH = typename T::TileShapeH;
  using TileShapeD = typename T::TileShapeD;
  auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);

  auto [H, K, D, B] = problem_shape;
  auto [D_latent, D_rope] = D;

  float scale = float(sm_scale);

  using StrideQ = typename T::StrideQ;
  using StrideK = typename T::StrideK;
  using StrideO = typename T::StrideO;
  using StrideLSE = typename T::StrideLSE;

  // Query strides come straight from the tensors: dim 1 first, then dim 0.
  StrideQ stride_Q_nope = cute::make_tuple(
      static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
  StrideQ stride_Q_pe = cute::make_tuple(
      static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));

  // Cache strides are computed, not read from the tensor: each row holds
  // D_latent + D_rope elements and each page holds page_size rows.
  // The leading "0 +" turns the static cute integers into plain ints.
  StrideK stride_C = cute::make_tuple(
      static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
  StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
  StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
  // Output strides are likewise computed from H and D_latent.
  StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));

  using Element = typename T::Element;
  using ElementOut = typename T::ElementOut;
  using ElementAcc = typename T::ElementAcc;
  auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
  auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
  auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());

  typename T::Fmha::Arguments arguments{
      problem_shape,
      {scale,
       Q_nope_ptr,
       stride_Q_nope,
       Q_pe_ptr,
       stride_Q_pe,
       C_ptr,
       stride_C,
       // Second cache pointer starts D_latent elements into each row and
       // shares stride_C.
       C_ptr + D_latent,
       stride_C,
       static_cast<int*>(seq_lens.data_ptr()),
       static_cast<int*>(page_table.data_ptr()),
       stride_PT,
       page_count_total,
       page_size},
      {static_cast<ElementOut*>(out.data_ptr()),
       stride_O,
       // No LSE buffer is supplied.
       static_cast<ElementAcc*>(nullptr),
       stride_LSE},
      hw_info,
      // TODO(trevor-m): Change split_kv back to -1 when
      // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
      // perform worse with larger context length and smaller batch sizes.
      num_kv_splits,  // split_kv
      nullptr,        // is_var_split_kv
  };
  // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
  // split_kv automatically based on batch size and sequence length to balance
  // workload across available SMs. Consider using var_split_kv for manual
  // control if needed.
  T::Fmha::set_split_kv(arguments);

  return arguments;
}
// Instantiates the MLA operator for the given element type / paging mode /
// persistence option, then validates, initialises and launches it on `stream`.
// Any non-success CUTLASS status is turned into a TORCH_CHECK failure.
template <typename Element, bool IsPaged128, typename PersistenceOption>
void runMla(
    at::Tensor const& out,
    at::Tensor const& q_nope,
    at::Tensor const& q_pe,
    at::Tensor const& kv_c_and_k_pe_cache,
    at::Tensor const& seq_lens,
    at::Tensor const& page_table,
    at::Tensor const& workspace,
    double sm_scale,
    int64_t num_kv_splits,
    cudaStream_t stream) {
  using Config = MlaSm100<Element, IsPaged128, PersistenceOption>;
  using Operator = typename Config::Fmha;

  Operator op;
  auto op_args = args_from_options<Config>(
      out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);

  // Step 1: reject unsupported configurations before touching the device.
  CUTLASS_CHECK(op.can_implement(op_args));

  // Step 2: set up kernel parameters using the caller-provided scratch buffer.
  CUTLASS_CHECK(op.initialize(op_args, workspace.data_ptr(), stream));

  // Step 3: launch.
  CUTLASS_CHECK(op.run(op_args, workspace.data_ptr(), stream));
}
// Lifts the runtime condition `expr` into a `constexpr bool` named
// `const_expr`, then invokes the callable passed as the trailing argument in
// whichever branch matches. The callable must return something convertible to
// bool, because the wrapping lambda is declared `-> bool`.
#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
constexpr bool const_expr = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
return __VA_ARGS__(); \
} \
}()
// Entry point of the SM100 CUTLASS MLA decode op.
//
// Picks the kernel instantiation from three runtime properties:
//   * page size == 128                -> IsPaged128
//   * num_kv_splits <= 1              -> persistent tile scheduler
//   * dtype of q_nope (Half / BFloat16 / Float8_e4m3fn) -> element type
// Any other dtype raises a TORCH_CHECK failure. The kernel is launched on the
// current CUDA stream of q_nope's device.
void sm100_cutlass_mla_decode(
    torch::Tensor const& out,
    torch::Tensor const& q_nope,
    torch::Tensor const& q_pe,
    torch::Tensor const& kv_c_and_k_pe_cache,
    torch::Tensor const& seq_lens,
    torch::Tensor const& page_table,
    torch::Tensor const& workspace,
    double sm_scale,
    int64_t num_kv_splits) {
  auto in_dtype = q_nope.dtype();
  // Named cast to the guard's index type; the former C-style (char) cast
  // relied on the implementation-defined signedness of plain char.
  at::cuda::CUDAGuard device_guard{static_cast<c10::DeviceIndex>(q_nope.get_device())};
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
  const int page_size = kv_c_and_k_pe_cache.sizes()[1];

  // NOTE(alcanderian): IsPersistent has bug with manual split_kv.
  // Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
  // Maybe per batch split kv will fix this.
  DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
    DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
      if (in_dtype == at::ScalarType::Half) {
        runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
            out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
      } else if (in_dtype == at::ScalarType::BFloat16) {
        runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
            out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
      } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
        runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
            out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
      } else {
        TORCH_CHECK(false, "Unsupported input data type of MLA");
      }
      // DISPATCH_BOOL requires a bool-returning callable; the value is unused.
      return true;
    });
    return true;
  });
}
int64_t
sm100_cutlass_mla_get_workspace_size
(
int64_t
max_seq_len
,
int64_t
num_batches
,
int64_t
sm_count
,
int64_t
num_kv_splits
)
{
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
// which are float, so Element type here doesn't matter.
using
MlaSm100Type
=
MlaSm100
<
cutlass
::
half_t
,
true
>
;
// Get split kv. Requires problem shape and sm_count only.
typename
MlaSm100Type
::
Fmha
::
Arguments
arguments
;
using
TileShapeH
=
typename
MlaSm100Type
::
TileShapeH
;
using
TileShapeD
=
typename
MlaSm100Type
::
TileShapeD
;
arguments
.
problem_shape
=
cute
::
make_tuple
(
TileShapeH
{},
static_cast
<
int
>
(
max_seq_len
),
TileShapeD
{},
static_cast
<
int
>
(
num_batches
));
// Assumes device 0 when getting sm_count.
arguments
.
hw_info
.
sm_count
=
sm_count
<=
0
?
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
/*device_id=*/
0
)
:
sm_count
;
arguments
.
split_kv
=
num_kv_splits
;
MlaSm100Type
::
Fmha
::
set_split_kv
(
arguments
);
return
MlaSm100Type
::
Fmha
::
get_workspace_size
(
arguments
);
}
#endif
// Registers the decode implementation under the CUDA dispatch key.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
  m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
}

// Registers the workspace-size query under CatchAll.
// NOTE(review): presumably CatchAll because the function takes only integer
// arguments and so has no tensor to dispatch on - confirm.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
  m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
}
// clang-format on
csrc/attention/paged_attention_v1.cu
View file @
711aa9d5
...
...
@@ -18,12 +18,7 @@
*/
#include "attention_kernels.cuh"
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#include "cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
...
...
@@ -187,7 +182,6 @@ void paged_attention_v1(
CALL_V1_LAUNCHER_BLOCK_SIZE
)
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
csrc/attention/paged_attention_v2.cu
View file @
711aa9d5
...
...
@@ -18,12 +18,7 @@
*/
#include "attention_kernels.cuh"
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#include "cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
...
...
@@ -197,7 +192,6 @@ void paged_attention_v2(
CALL_V2_LAUNCHER_BLOCK_SIZE
)
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
csrc/cpu/cpu_types_arm.hpp
View file @
711aa9d5
...
...
@@ -33,6 +33,8 @@ namespace vec_op {
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
// Number of elements in single ASIMD vector of given Datatype
#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof(vec[0]))
namespace
{
template
<
typename
T
,
T
...
indexes
,
typename
F
>
...
...
@@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
}
void
save
(
void
*
ptr
,
const
int
elem_num
)
const
{
int
full_blocks
=
elem_num
/
8
;
int
remainder
=
elem_num
%
8
;
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
])
;
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
])
;
if
(
full_blocks
>
0
)
{
vst1q_f16
(
reinterpret_cast
<
__fp16
*>
(
ptr
),
reg
.
val
[
0
]);
...
...
@@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
vcvtq_high_bf16_f32
(
vcvtq_low_bf16_f32
(
v
.
val
[
2
]),
v
.
val
[
3
])})
{};
// Writes the whole register pair to `ptr` in one assignment through a
// bfloat16x8x2_t lvalue.
// NOTE(review): presumably requires `ptr` to be suitably aligned for
// bfloat16x8x2_t - confirm against callers.
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
void
save
(
void
*
ptr
,
const
int
elem_num
)
const
{
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
vst1q_bf16
(
reinterpret_cast
<
__bf16
*>
(
ptr
)
+
NUM_ELEMENTS_REG
(
reg
.
val
[
0
])
*
i
,
reg
.
val
[
i
]);
if
(
remainder
>
0
)
{
bfloat16x8_t
temp
=
reg
.
val
[
full_blocks
];
bfloat16_t
*
base
=
reinterpret_cast
<
bfloat16_t
*>
(
ptr
)
+
full_blocks
*
8
;
if
(
remainder
>
0
)
base
[
0
]
=
vgetq_lane_bf16
(
temp
,
0
);
if
(
remainder
>
1
)
base
[
1
]
=
vgetq_lane_bf16
(
temp
,
1
);
if
(
remainder
>
2
)
base
[
2
]
=
vgetq_lane_bf16
(
temp
,
2
);
if
(
remainder
>
3
)
base
[
3
]
=
vgetq_lane_bf16
(
temp
,
3
);
if
(
remainder
>
4
)
base
[
4
]
=
vgetq_lane_bf16
(
temp
,
4
);
if
(
remainder
>
5
)
base
[
5
]
=
vgetq_lane_bf16
(
temp
,
5
);
if
(
remainder
>
6
)
base
[
6
]
=
vgetq_lane_bf16
(
temp
,
6
);
}
};
};
struct
BF16Vec32
:
public
Vec
<
BF16Vec32
>
{
...
...
@@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
:
reg
({
vec8_data
.
reg
,
vec8_data
.
reg
,
vec8_data
.
reg
,
vec8_data
.
reg
})
{};
// Writes all four registers to `ptr` in one assignment through a
// bfloat16x8x4_t lvalue.
// NOTE(review): presumably requires `ptr` to be suitably aligned for
// bfloat16x8x4_t - confirm against callers.
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
void
save
(
void
*
ptr
,
const
int
elem_num
)
const
{
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
vst1q_bf16
(
reinterpret_cast
<
__bf16
*>
(
ptr
)
+
NUM_ELEMENTS_REG
(
reg
.
val
[
0
])
*
i
,
reg
.
val
[
i
]);
if
(
remainder
>
0
)
{
bfloat16x8_t
temp
=
reg
.
val
[
full_blocks
];
bfloat16_t
*
base
=
reinterpret_cast
<
bfloat16_t
*>
(
ptr
)
+
full_blocks
*
8
;
base
[
0
]
=
vgetq_lane_bf16
(
temp
,
0
);
if
(
remainder
>
1
)
base
[
1
]
=
vgetq_lane_bf16
(
temp
,
1
);
if
(
remainder
>
2
)
base
[
2
]
=
vgetq_lane_bf16
(
temp
,
2
);
if
(
remainder
>
3
)
base
[
3
]
=
vgetq_lane_bf16
(
temp
,
3
);
if
(
remainder
>
4
)
base
[
4
]
=
vgetq_lane_bf16
(
temp
,
4
);
if
(
remainder
>
5
)
base
[
5
]
=
vgetq_lane_bf16
(
temp
,
5
);
if
(
remainder
>
6
)
base
[
6
]
=
vgetq_lane_bf16
(
temp
,
6
);
}
};
};
#endif
...
...
@@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
}
};
// Sixteen 32-bit signed integers held as four int32x4_t registers.
struct INT32Vec16 : public Vec<INT32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  // Overlay giving scalar access to the sixteen lanes.
  union AliasReg {
    int32x4x4_t reg;
    int32_t values[VEC_ELEM_NUM];
  };
  int32x4x4_t reg;

  // Loads sixteen consecutive int32 values starting at `ptr`.
  explicit INT32Vec16(const void* ptr) {
    const int32_t* const src = reinterpret_cast<const int32_t*>(ptr);
    for (int blk = 0; blk < 4; ++blk) {
      reg.val[blk] = vld1q_s32(src + 4 * blk);
    }
  }

  // Stores all sixteen lanes to `ptr`.
  void save(int32_t* ptr) const {
    for (int blk = 0; blk < 4; ++blk) {
      vst1q_s32(ptr + 4 * blk, reg.val[blk]);
    }
  };

  // Stores only the first `elem_num` lanes: complete 4-lane registers with a
  // vector store, the trailing partial register lane by lane.
  void save(int32_t* ptr, const int elem_num) const {
    int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
    int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);

    for (int blk = 0; blk < full_blocks; blk++) {
      vst1q_s32(ptr + NUM_ELEMENTS_REG(reg.val[0]) * blk, reg.val[blk]);
    }

    if (remainder <= 0) {
      return;
    }

    // The lane index of vgetq_lane_s32 must be a constant, hence the ladder.
    const int32x4_t tail = reg.val[full_blocks];
    int32_t* const out = ptr + full_blocks * 4;
    switch (remainder) {
      default: out[3] = vgetq_lane_s32(tail, 3); [[fallthrough]];
      case 3: out[2] = vgetq_lane_s32(tail, 2); [[fallthrough]];
      case 2: out[1] = vgetq_lane_s32(tail, 1); [[fallthrough]];
      case 1: out[0] = vgetq_lane_s32(tail, 0); break;
    }
  }
};
struct
FP32Vec16
:
public
Vec
<
FP32Vec16
>
{
constexpr
static
int
VEC_ELEM_NUM
=
16
;
union
AliasReg
{
...
...
@@ -434,7 +516,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
reg
.
val
[
2
]
=
vcvt_f32_f16
(
vget_low_f16
(
v
.
reg
.
val
[
1
]));
reg
.
val
[
3
]
=
vcvt_f32_f16
(
vget_high_f16
(
v
.
reg
.
val
[
1
]));
};
// Converts each of the sixteen int32 lanes of `v` to float.
explicit FP32Vec16(const INT32Vec16& v) {
  for (int blk = 0; blk < 4; ++blk) {
    reg.val[blk] = vcvtq_f32_s32(v.reg.val[blk]);
  }
};
FP32Vec16
operator
+
(
const
FP32Vec16
&
b
)
const
{
return
FP32Vec16
(
float32x4x4_t
({
vaddq_f32
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
vaddq_f32
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]),
...
...
@@ -463,6 +550,85 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vdivq_f32
(
reg
.
val
[
3
],
b
.
reg
.
val
[
3
])}));
};
// Lane-wise clamp: first raises each lane to at least `min`, then lowers the
// result to at most `max`.
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
  float32x4x4_t clamped;
  for (int blk = 0; blk < 4; ++blk) {
    const float32x4_t floored = vmaxq_f32(min.reg.val[blk], reg.val[blk]);
    clamped.val[blk] = vminq_f32(max.reg.val[blk], floored);
  }
  return FP32Vec16(clamped);
};
// Lane-wise maximum of *this and `b` over all sixteen lanes.
FP32Vec16 max(const FP32Vec16& b) const {
  float32x4x4_t larger;
  for (int blk = 0; blk < 4; ++blk) {
    larger.val[blk] = vmaxq_f32(b.reg.val[blk], reg.val[blk]);
  }
  return FP32Vec16(larger);
};
// Lane-wise maximum of *this and `b` restricted to the first `elem_num` lanes.
//
// Lanes at index >= elem_num are copied from *this. Previously `temp` was
// left uninitialised, so those lanes were indeterminate and the partial
// register was read by vsetq_lane_f32 before it had ever been written.
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
  int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
  int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);

  // Start from our own lanes so that every lane of the result is defined.
  float32x4x4_t temp = reg;

  for (int i = 0; i < full_blocks; i++)
    temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);

  // Partial register: at most three lanes, each handled with a constant lane
  // index as the NEON lane intrinsics require.
  if (remainder > 0) {
    float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
                           vgetq_lane_f32(b.reg.val[full_blocks], 0));
    temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0);
  }
  if (remainder > 1) {
    float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
                           vgetq_lane_f32(b.reg.val[full_blocks], 1));
    temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1);
  }
  if (remainder > 2) {
    float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
                           vgetq_lane_f32(b.reg.val[full_blocks], 2));
    temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2);
  }
  return FP32Vec16(temp);
};
// Lane-wise minimum of *this and `b` over all sixteen lanes.
FP32Vec16 min(const FP32Vec16& b) const {
  float32x4x4_t smaller;
  for (int blk = 0; blk < 4; ++blk) {
    smaller.val[blk] = vminq_f32(b.reg.val[blk], reg.val[blk]);
  }
  return FP32Vec16(smaller);
};
// Lane-wise minimum of *this and `b` restricted to the first `elem_num` lanes.
//
// Lanes at index >= elem_num are copied from *this. Previously `temp` was
// left uninitialised, so those lanes were indeterminate and the partial
// register was read by vsetq_lane_f32 before it had ever been written.
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
  int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
  const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);

  // Start from our own lanes so that every lane of the result is defined.
  float32x4x4_t temp = reg;

  for (int i = 0; i < full_blocks; i++)
    temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);

  // Partial register: at most three lanes, each handled with a constant lane
  // index as the NEON lane intrinsics require.
  if (remainder > 0) {
    float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
                           vgetq_lane_f32(b.reg.val[full_blocks], 0));
    temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0);
  }
  if (remainder > 1) {
    float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
                           vgetq_lane_f32(b.reg.val[full_blocks], 1));
    temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1);
  }
  if (remainder > 2) {
    float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
                           vgetq_lane_f32(b.reg.val[full_blocks], 2));
    temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2);
  }
  return FP32Vec16(temp);
};
// Lane-wise absolute value of all sixteen lanes.
FP32Vec16 abs() const {
  float32x4x4_t magnitude;
  for (int blk = 0; blk < 4; ++blk) {
    magnitude.val[blk] = vabsq_f32(reg.val[blk]);
  }
  return FP32Vec16(magnitude);
}
float
reduce_sum
()
const
{
AliasReg
ar
;
ar
.
reg
=
reg
;
...
...
@@ -473,6 +639,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return
answer
;
};
float
reduce_max
()
const
{
AliasReg
ar
;
ar
.
reg
=
reg
;
float
max_v
=
std
::
numeric_limits
<
float
>::
lowest
();
unroll_loop
<
int
,
VEC_ELEM_NUM
>
(
[
&
max_v
,
&
ar
](
int
i
)
{
max_v
=
std
::
max
(
max_v
,
ar
.
values
[
i
]);
});
return
max_v
;
}
float
reduce_min
()
const
{
AliasReg
ar
;
ar
.
reg
=
reg
;
float
min_v
=
std
::
numeric_limits
<
float
>::
max
();
unroll_loop
<
int
,
VEC_ELEM_NUM
>
(
[
&
min_v
,
&
ar
](
int
i
)
{
min_v
=
std
::
min
(
min_v
,
ar
.
values
[
i
]);
});
return
min_v
;
}
template
<
int
group_size
>
float
reduce_sub_sum
(
int
idx
)
{
static_assert
(
VEC_ELEM_NUM
%
group_size
==
0
);
...
...
@@ -493,6 +677,83 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vst1q_f32
(
ptr
+
8
,
reg
.
val
[
2
]);
vst1q_f32
(
ptr
+
12
,
reg
.
val
[
3
]);
};
void
save
(
float
*
ptr
,
const
int
elem_num
)
const
{
int
full_blocks
=
elem_num
/
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
int
remainder
=
elem_num
%
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
for
(
int
i
=
0
;
i
<
full_blocks
;
i
++
)
vst1q_f32
(
reinterpret_cast
<
float32_t
*>
(
ptr
)
+
NUM_ELEMENTS_REG
(
reg
.
val
[
0
])
*
i
,
reg
.
val
[
i
]);
if
(
remainder
>
0
)
{
float32x4_t
temp
=
reg
.
val
[
full_blocks
];
float
*
base
=
reinterpret_cast
<
float32_t
*>
(
ptr
)
+
full_blocks
*
NUM_ELEMENTS_REG
(
reg
.
val
[
0
]);
if
(
remainder
>
0
)
base
[
0
]
=
vgetq_lane_f32
(
temp
,
0
);
if
(
remainder
>
1
)
base
[
1
]
=
vgetq_lane_f32
(
temp
,
1
);
if
(
remainder
>
2
)
base
[
2
]
=
vgetq_lane_f32
(
temp
,
2
);
}
}
};
// Sixteen signed 8-bit integers held in a single 128-bit NEON register.
struct INT8Vec16 : public Vec<INT8Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  // Overlay that exposes the register's lanes as a plain int8_t array.
  union AliasReg {
    int8x16_t reg;
    int8_t values[VEC_ELEM_NUM];
  };

  int8x16_t reg;

  // Converts sixteen floats to sixteen int8 values.
  // Each 4-lane float register is converted to int32 with vcvtq_s32_f32 and
  // then narrowed in two saturating steps (int32 -> int16 -> int8) before the
  // halves are joined into one 128-bit register.
  explicit INT8Vec16(const FP32Vec16& vec) {
    // float32 -> int32, one 128-bit block at a time.
    const int32x4_t quad0 = vcvtq_s32_f32(vec.reg.val[0]);
    const int32x4_t quad1 = vcvtq_s32_f32(vec.reg.val[1]);
    const int32x4_t quad2 = vcvtq_s32_f32(vec.reg.val[2]);
    const int32x4_t quad3 = vcvtq_s32_f32(vec.reg.val[3]);

    // int32 -> int16 (saturating), pairing blocks into 8-lane halves.
    const int16x8_t low_half16 =
        vcombine_s16(vqmovn_s32(quad0), vqmovn_s32(quad1));
    const int16x8_t high_half16 =
        vcombine_s16(vqmovn_s32(quad2), vqmovn_s32(quad3));

    // int16 -> int8 (saturating), then join into the final 16-lane vector.
    reg = vcombine_s8(vqmovn_s16(low_half16), vqmovn_s16(high_half16));
  }

  // Stores all sixteen lanes to `ptr`.
  void save(int8_t* ptr) const { vst1q_s8(ptr, reg); }

  // Stores the first `elem_num` lanes to `ptr`.
  // A full register is written with vst1q_s8; a partial count is written
  // lane by lane so that memory past `ptr + elem_num` is left untouched.
  void save(int8_t* ptr, const int elem_num) const {
    const int whole_regs = elem_num / NUM_ELEMENTS_REG(reg);
    const int tail_lanes = elem_num % NUM_ELEMENTS_REG(reg);

    for (int blk = 0; blk < whole_regs; ++blk) {
      vst1q_s8(ptr + NUM_ELEMENTS_REG(reg) * blk, reg);
    }

    if (tail_lanes <= 0) {
      return;
    }

    // Spill the register to a scratch buffer, then copy only the requested
    // lanes.
    int8_t scratch[VEC_ELEM_NUM];
    vst1q_s8(scratch, reg);

    int8_t* tail_dst = ptr + whole_regs * NUM_ELEMENTS_REG(reg);
    for (int lane = 0; lane < tail_lanes && lane < VEC_ELEM_NUM - 1; ++lane) {
      tail_dst[lane] = scratch[lane];
    }
  }
};
template
<
typename
T
>
...
...
csrc/cpu/dnnl_helper.hpp
View file @
711aa9d5
...
...
@@ -57,6 +57,7 @@ class DNNLPrimitiveHelper {
// Note: Due to the limitation of oneDNN
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
// not supported.
template
<
typename
OutputT
,
typename
BiasT
>
static
void
gemm_s8s8_jit
(
const
int8_t
*
a
,
const
int8_t
*
b
,
OutputT
*
c
,
const
BiasT
*
bias
,
dnnl_dim_t
M
,
dnnl_dim_t
N
,
...
...
@@ -90,6 +91,27 @@ class DNNLPrimitiveHelper {
}
dnnl
::
matmul
::
primitive_desc
matmul_pd
;
// Create memory descriptors with format_tag::any for the primitive. This
// enables the matmul primitive to choose memory layouts for an
// optimized primitive implementation, and these layouts may differ from the
// ones provided by the user.
#ifdef __aarch64__
auto
mat_src_md
=
dnnl
::
memory
::
desc
({
M
,
K
},
dnnl
::
memory
::
data_type
::
s8
,
dnnl
::
memory
::
format_tag
::
any
);
auto
mat_weights_md
=
dnnl
::
memory
::
desc
(
{
K
,
N
},
dnnl
::
memory
::
data_type
::
s8
,
dnnl
::
memory
::
format_tag
::
any
);
auto
mat_dst_md
=
dnnl
::
memory
::
desc
({
M
,
N
},
OutputType
,
dnnl
::
memory
::
format_tag
::
any
);
if
(
bias
)
{
dnnl
::
memory
::
desc
bias_md
({
1
,
N
},
BiasType
,
{
N
,
1
});
matmul_pd
=
dnnl
::
matmul
::
primitive_desc
(
default_engine
(),
mat_src_md
,
mat_weights_md
,
bias_md
,
mat_dst_md
,
attr
);
}
else
{
matmul_pd
=
dnnl
::
matmul
::
primitive_desc
(
default_engine
(),
mat_src_md
,
mat_weights_md
,
mat_dst_md
,
attr
);
}
#else
if
(
bias
)
{
dnnl
::
memory
::
desc
bias_md
({
1
,
N
},
BiasType
,
{
N
,
1
});
matmul_pd
=
dnnl
::
matmul
::
primitive_desc
(
default_engine
(),
a_md
,
b_md
,
...
...
@@ -98,6 +120,7 @@ class DNNLPrimitiveHelper {
matmul_pd
=
dnnl
::
matmul
::
primitive_desc
(
default_engine
(),
a_md
,
b_md
,
c_md
,
attr
);
}
#endif
dnnl
::
matmul
matmul
(
matmul_pd
);
auto
&
engine
=
default_engine
();
...
...
@@ -111,24 +134,34 @@ class DNNLPrimitiveHelper {
(
void
*
)
b_scales
);
auto
&
stream
=
default_stream
();
auto
mat_src_mem
=
a_m
;
auto
mat_weights_mem
=
b_m
;
auto
mat_dst_mem
=
c_m
;
#ifdef __aarch64__
if
(
matmul_pd
.
weights_desc
()
!=
b_m
.
get_desc
())
{
mat_weights_mem
=
dnnl
::
memory
(
matmul_pd
.
weights_desc
(),
engine
);
dnnl
::
reorder
(
b_m
,
mat_weights_mem
).
execute
(
stream
,
b_m
,
mat_weights_mem
);
}
#endif
if
constexpr
(
InputNoScale
)
{
if
(
bias
)
{
dnnl
::
memory
::
desc
bias_md
({
N
},
BiasType
,
{
1
});
dnnl
::
memory
bias_m
(
bias_md
,
engine
,
(
void
*
)
bias
);
matmul
.
execute
(
stream
,
{
{
DNNL_ARG_SRC
,
a_
m
},
{
DNNL_ARG_WEIGHTS
,
b_
m
},
{
DNNL_ARG_SRC
,
mat_src_me
m
},
{
DNNL_ARG_WEIGHTS
,
mat_weights_me
m
},
{
DNNL_ARG_BIAS
,
bias_m
},
{
DNNL_ARG_DST
,
c_
m
},
{
DNNL_ARG_DST
,
mat_dst_me
m
},
{
DNNL_ARG_ATTR_SCALES
|
DNNL_ARG_WEIGHTS
,
b_scales_m
},
});
}
else
{
matmul
.
execute
(
stream
,
{
{
DNNL_ARG_SRC
,
a_
m
},
{
DNNL_ARG_WEIGHTS
,
b_
m
},
{
DNNL_ARG_DST
,
c_
m
},
{
DNNL_ARG_SRC
,
mat_src_me
m
},
{
DNNL_ARG_WEIGHTS
,
mat_weights_me
m
},
{
DNNL_ARG_DST
,
mat_dst_me
m
},
{
DNNL_ARG_ATTR_SCALES
|
DNNL_ARG_WEIGHTS
,
b_scales_m
},
});
}
...
...
@@ -138,19 +171,19 @@ class DNNLPrimitiveHelper {
dnnl
::
memory
bias_m
(
bias_md
,
engine
,
(
void
*
)
bias
);
matmul
.
execute
(
stream
,
{
{
DNNL_ARG_SRC
,
a_
m
},
{
DNNL_ARG_WEIGHTS
,
b_
m
},
{
DNNL_ARG_SRC
,
mat_src_me
m
},
{
DNNL_ARG_WEIGHTS
,
mat_weights_me
m
},
{
DNNL_ARG_BIAS
,
bias_m
},
{
DNNL_ARG_DST
,
c_
m
},
{
DNNL_ARG_DST
,
mat_dst_me
m
},
{
DNNL_ARG_ATTR_SCALES
|
DNNL_ARG_SRC
,
a_scales_m
},
{
DNNL_ARG_ATTR_SCALES
|
DNNL_ARG_WEIGHTS
,
b_scales_m
},
});
}
else
{
matmul
.
execute
(
stream
,
{
{
DNNL_ARG_SRC
,
a_
m
},
{
DNNL_ARG_WEIGHTS
,
b_
m
},
{
DNNL_ARG_DST
,
c_
m
},
{
DNNL_ARG_SRC
,
mat_src_me
m
},
{
DNNL_ARG_WEIGHTS
,
mat_weights_me
m
},
{
DNNL_ARG_DST
,
mat_dst_me
m
},
{
DNNL_ARG_ATTR_SCALES
|
DNNL_ARG_SRC
,
a_scales_m
},
{
DNNL_ARG_ATTR_SCALES
|
DNNL_ARG_WEIGHTS
,
b_scales_m
},
});
...
...
@@ -170,5 +203,4 @@ class DNNLPrimitiveHelper {
return
stream
;
}
};
#endif
Prev
1
2
3
4
5
6
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment