Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dcb5624a
Commit
dcb5624a
authored
Apr 29, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.5' into v0.8.5-dev
parents
55880ca2
ba41cc90
Changes
690
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3796 additions
and
87 deletions
+3796
-87
benchmarks/backend_request_func.py
benchmarks/backend_request_func.py
+107
-0
benchmarks/benchmark_dataset.py
benchmarks/benchmark_dataset.py
+80
-0
benchmarks/benchmark_prefix_caching.py
benchmarks/benchmark_prefix_caching.py
+8
-6
benchmarks/benchmark_serving.py
benchmarks/benchmark_serving.py
+31
-16
benchmarks/benchmark_serving_structured_output.py
benchmarks/benchmark_serving_structured_output.py
+7
-7
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+7
-0
benchmarks/kernels/benchmark_bitblas.py
benchmarks/kernels/benchmark_bitblas.py
+236
-0
benchmarks/kernels/benchmark_lora.py
benchmarks/kernels/benchmark_lora.py
+8
-2
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+7
-12
cmake/external_projects/vllm_flash_attn.cmake
cmake/external_projects/vllm_flash_attn.cmake
+1
-1
csrc/attention/merge_attn_states.cu
csrc/attention/merge_attn_states.cu
+15
-10
csrc/attention/mla/cutlass_mla_entry.cu
csrc/attention/mla/cutlass_mla_entry.cu
+38
-0
csrc/attention/mla/cutlass_mla_kernels.cu
csrc/attention/mla/cutlass_mla_kernels.cu
+225
-0
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+21
-18
csrc/moe/marlin_moe_wna16/generate_kernels.py
csrc/moe/marlin_moe_wna16/generate_kernels.py
+103
-0
csrc/moe/marlin_moe_wna16/kernel.h
csrc/moe/marlin_moe_wna16/kernel.h
+44
-0
csrc/moe/marlin_moe_wna16/marlin_template.h
csrc/moe/marlin_moe_wna16/marlin_template.h
+1917
-0
csrc/moe/marlin_moe_wna16/ops.cu
csrc/moe/marlin_moe_wna16/ops.cu
+927
-0
csrc/moe/moe_wna16.cu
csrc/moe/moe_wna16.cu
+3
-7
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+11
-8
No files found.
benchmarks/backend_request_func.py
View file @
dcb5624a
# SPDX-License-Identifier: Apache-2.0
import
io
import
json
import
os
import
sys
...
...
@@ -32,6 +33,7 @@ class RequestFuncInput:
extra_body
:
Optional
[
dict
]
=
None
multi_modal_content
:
Optional
[
dict
]
=
None
ignore_eos
:
bool
=
False
language
:
Optional
[
str
]
=
None
@
dataclass
...
...
@@ -436,6 +438,110 @@ async def async_request_openai_chat_completions(
return
output
async
def
async_request_openai_audio
(
request_func_input
:
RequestFuncInput
,
pbar
:
Optional
[
tqdm
]
=
None
,
)
->
RequestFuncOutput
:
# Lazy import without PlaceholderModule to avoid vllm dep.
import
soundfile
api_url
=
request_func_input
.
api_url
assert
api_url
.
endswith
(
(
"transcriptions"
,
"translations"
)),
"OpenAI Chat Completions API URL must end with 'transcriptions' "
"or `translations`."
async
with
aiohttp
.
ClientSession
(
trust_env
=
True
,
timeout
=
AIOHTTP_TIMEOUT
)
as
session
:
content
=
[{
"type"
:
"text"
,
"text"
:
request_func_input
.
prompt
}]
payload
=
{
"model"
:
request_func_input
.
model_name
\
if
request_func_input
.
model_name
else
request_func_input
.
model
,
"temperature"
:
0.0
,
"max_completion_tokens"
:
request_func_input
.
output_len
,
"stream"
:
True
,
"language"
:
"en"
,
# Flattened due to multipart/form-data
"stream_include_usage"
:
True
,
"stream_continuous_usage_stats"
:
True
}
if
request_func_input
.
extra_body
:
payload
.
update
(
request_func_input
.
extra_body
)
headers
=
{
"Authorization"
:
f
"Bearer
{
os
.
environ
.
get
(
'OPENAI_API_KEY'
)
}
"
,
}
# Send audio file
def
to_bytes
(
y
,
sr
):
buffer
=
io
.
BytesIO
()
soundfile
.
write
(
buffer
,
y
,
sr
,
format
=
"WAV"
)
buffer
.
seek
(
0
)
return
buffer
with
to_bytes
(
*
request_func_input
.
multi_modal_content
[
'audio'
])
as
f
:
form
=
aiohttp
.
FormData
()
form
.
add_field
(
'file'
,
f
,
content_type
=
'audio/wav'
)
for
key
,
value
in
payload
.
items
():
form
.
add_field
(
key
,
str
(
value
))
output
=
RequestFuncOutput
()
output
.
prompt_len
=
request_func_input
.
prompt_len
generated_text
=
""
ttft
=
0.0
st
=
time
.
perf_counter
()
most_recent_timestamp
=
st
try
:
async
with
session
.
post
(
url
=
api_url
,
data
=
form
,
headers
=
headers
)
as
response
:
if
response
.
status
==
200
:
async
for
chunk_bytes
in
response
.
content
:
chunk_bytes
=
chunk_bytes
.
strip
()
if
not
chunk_bytes
:
continue
chunk
=
chunk_bytes
.
decode
(
"utf-8"
).
removeprefix
(
"data: "
)
if
chunk
!=
"[DONE]"
:
timestamp
=
time
.
perf_counter
()
data
=
json
.
loads
(
chunk
)
if
choices
:
=
data
.
get
(
"choices"
):
content
=
choices
[
0
][
"delta"
].
get
(
"content"
)
# First token
if
ttft
==
0.0
:
ttft
=
timestamp
-
st
output
.
ttft
=
ttft
# Decoding phase
else
:
output
.
itl
.
append
(
timestamp
-
most_recent_timestamp
)
generated_text
+=
content
or
""
elif
usage
:
=
data
.
get
(
"usage"
):
output
.
output_tokens
=
usage
.
get
(
"completion_tokens"
)
most_recent_timestamp
=
timestamp
output
.
generated_text
=
generated_text
output
.
success
=
True
output
.
latency
=
most_recent_timestamp
-
st
else
:
output
.
error
=
response
.
reason
or
""
output
.
success
=
False
except
Exception
:
output
.
success
=
False
exc_info
=
sys
.
exc_info
()
output
.
error
=
""
.
join
(
traceback
.
format_exception
(
*
exc_info
))
if
pbar
:
pbar
.
update
(
1
)
return
output
def
get_model
(
pretrained_model_name_or_path
:
str
)
->
str
:
if
os
.
getenv
(
'VLLM_USE_MODELSCOPE'
,
'False'
).
lower
()
==
'true'
:
from
modelscope
import
snapshot_download
...
...
@@ -493,6 +599,7 @@ ASYNC_REQUEST_FUNCS = {
"deepspeed-mii"
:
async_request_deepspeed_mii
,
"openai"
:
async_request_openai_completions
,
"openai-chat"
:
async_request_openai_chat_completions
,
"openai-audio"
:
async_request_openai_audio
,
"tensorrt-llm"
:
async_request_trt_llm
,
"scalellm"
:
async_request_openai_completions
,
"sglang"
:
async_request_openai_completions
,
...
...
benchmarks/benchmark_dataset.py
View file @
dcb5624a
...
...
@@ -64,6 +64,7 @@ class SampleRequest:
class
BenchmarkDataset
(
ABC
):
DEFAULT_SEED
=
0
IS_MULTIMODAL
=
False
def
__init__
(
self
,
...
...
@@ -621,6 +622,7 @@ class ConversationDataset(HuggingFaceDataset):
SUPPORTED_DATASET_PATHS
=
{
'lmms-lab/LLaVA-OneVision-Data'
,
'Aeala/ShareGPT_Vicuna_unfiltered'
}
IS_MULTIMODAL
=
True
def
sample
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
...
...
@@ -685,6 +687,7 @@ class VisionArenaDataset(HuggingFaceDataset):
"lmarena-ai/vision-arena-bench-v0.1"
:
lambda
x
:
x
[
"turns"
][
0
][
0
][
"content"
]
}
IS_MULTIMODAL
=
True
def
sample
(
self
,
...
...
@@ -815,3 +818,80 @@ class AIMODataset(HuggingFaceDataset):
))
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
return
sampled_requests
# -----------------------------------------------------------------------------
# ASR Dataset Implementation
# -----------------------------------------------------------------------------
class
ASRDataset
(
HuggingFaceDataset
):
"""
Dataset class for processing a ASR dataset for transcription.
Tested on the following set:
+----------------+----------------------------------------+--------------------------+-----------------------------+
| Dataset | Domain | Speaking Style | hf-subset |
+----------------+----------------------------------------+--------------------------+-----------------------------+
| TED-LIUM | TED talks | Oratory | release1, release2, release3|
| | | | release3-speaker-adaptation |
| VoxPopuli | European Parliament | Oratory | en, de, it, fr, ... |
| LibriSpeech | Audiobook | Narrated | "LIUM/tedlium" |
| GigaSpeech | Audiobook, podcast, YouTube | Narrated, spontaneous | xs, s, m, l, xl, dev, test |
| SPGISpeech | Financial meetings | Oratory, spontaneous | S, M, L, dev, test |
| AMI | Meetings | Spontaneous | ihm, sdm |
+----------------+----------------------------------------+--------------------------+-----------------------------+
"""
# noqa: E501
SUPPORTED_DATASET_PATHS
=
{
"openslr/librispeech_asr"
,
"facebook/voxpopuli"
,
"LIUM/tedlium"
,
"edinburghcstr/ami"
,
"speechcolab/gigaspeech"
,
"kensho/spgispeech"
}
DEFAULT_OUTPUT_LEN
=
128
IS_MULTIMODAL
=
True
# TODO Whisper-specific. Abstract interface when more models are supported.
TRANSCRIPTION_PREAMBLE
=
"<|startoftranscript|><|en|><|transcribe|>"
\
"<|notimestamps|>"
skip_long_audios
:
bool
=
True
def
sample
(
self
,
tokenizer
:
PreTrainedTokenizerBase
,
num_requests
:
int
,
output_len
:
Optional
[
int
]
=
None
,
**
kwargs
,
)
->
list
:
import
librosa
output_len
=
(
output_len
if
output_len
is
not
None
else
self
.
DEFAULT_OUTPUT_LEN
)
prompt
=
ASRDataset
.
TRANSCRIPTION_PREAMBLE
prompt_len
=
len
(
tokenizer
(
prompt
).
input_ids
)
sampled_requests
=
[]
skipped
=
0
for
item
in
self
.
data
:
if
len
(
sampled_requests
)
>=
num_requests
:
break
audio
=
item
[
"audio"
]
y
,
sr
=
audio
[
"array"
],
audio
[
"sampling_rate"
]
duration_s
=
librosa
.
get_duration
(
y
=
y
,
sr
=
sr
)
# Whisper max supported duration
if
self
.
skip_long_audios
and
duration_s
>
30
:
skipped
+=
1
continue
mm_content
=
{
"audio"
:
(
y
,
sr
)}
sampled_requests
.
append
(
SampleRequest
(
prompt
=
prompt
,
prompt_len
=
prompt_len
,
expected_output_len
=
output_len
,
multi_modal_data
=
mm_content
,
))
if
skipped
:
logger
.
warning
(
"%d samples discarded from dataset due to"
\
" their length being greater than"
\
" what Whisper supports."
,
skipped
)
self
.
maybe_oversample_requests
(
sampled_requests
,
num_requests
)
return
sampled_requests
benchmarks/benchmark_prefix_caching.py
View file @
dcb5624a
...
...
@@ -78,14 +78,16 @@ class Request:
output_len
:
int
def
sample_tokens
(
tokenizer
:
PreTrainedTokenizerBase
,
length
:
int
)
->
str
:
def
sample_tokens
(
tokenizer
:
PreTrainedTokenizerBase
,
length
:
int
)
->
list
[
int
]:
vocab
=
tokenizer
.
get_vocab
()
all_special_ids
=
set
(
tokenizer
.
all_special_ids
)
# Remove the special tokens.
vocab
=
{
k
:
v
for
k
,
v
in
vocab
.
items
()
if
k
not
in
tokenizer
.
all_special_ids
}
return
random
.
choices
(
list
(
vocab
.
values
()),
k
=
length
)
return
random
.
choices
(
[
v
for
k
,
v
in
vocab
.
items
()
if
k
not
in
all_special_ids
],
k
=
length
,
)
def
sample_requests_from_dataset
(
...
...
benchmarks/benchmark_serving.py
View file @
dcb5624a
...
...
@@ -50,7 +50,7 @@ try:
except
ImportError
:
from
argparse
import
ArgumentParser
as
FlexibleArgumentParser
from
benchmark_dataset
import
(
AIMODataset
,
BurstGPTDataset
,
from
benchmark_dataset
import
(
AIMODataset
,
ASRDataset
,
BurstGPTDataset
,
ConversationDataset
,
HuggingFaceDataset
,
InstructCoderDataset
,
RandomDataset
,
SampleRequest
,
ShareGPTDataset
,
SonnetDataset
,
...
...
@@ -274,10 +274,6 @@ async def benchmark(
input_requests
[
0
].
expected_output_len
,
\
input_requests
[
0
].
multi_modal_data
if
backend
!=
"openai-chat"
and
test_mm_content
is
not
None
:
# multi-modal benchmark is only available on OpenAI Chat backend.
raise
ValueError
(
"Multi-modal content is only supported on 'openai-chat' backend."
)
assert
test_mm_content
is
None
or
isinstance
(
test_mm_content
,
dict
)
test_input
=
RequestFuncInput
(
model
=
model_id
,
...
...
@@ -604,6 +600,9 @@ def main(args: argparse.Namespace):
elif
args
.
dataset_path
in
AIMODataset
.
SUPPORTED_DATASET_PATHS
:
dataset_class
=
AIMODataset
args
.
hf_split
=
"train"
elif
args
.
dataset_path
in
ASRDataset
.
SUPPORTED_DATASET_PATHS
:
dataset_class
=
ASRDataset
args
.
hf_split
=
"train"
else
:
supported_datasets
=
set
([
dataset_name
for
cls
in
HuggingFaceDataset
.
__subclasses__
()
...
...
@@ -615,6 +614,13 @@ def main(args: argparse.Namespace):
f
" from one of following:
{
supported_datasets
}
. "
"Please consider contributing if you would "
"like to add support for additional dataset formats."
)
if
(
dataset_class
.
IS_MULTIMODAL
and
backend
not
in
\
[
"openai-chat"
,
"openai-audio"
]):
# multi-modal benchmark is only available on OpenAI Chat backend.
raise
ValueError
(
"Multi-modal content is only supported on 'openai-chat' and "
\
"'openai-audio' backend."
)
input_requests
=
dataset_class
(
dataset_path
=
args
.
dataset_path
,
dataset_subset
=
args
.
hf_subset
,
...
...
@@ -707,7 +713,7 @@ def main(args: argparse.Namespace):
))
# Save config and results to json
if
args
.
save_result
:
if
args
.
save_result
or
args
.
append_result
:
result_json
:
dict
[
str
,
Any
]
=
{}
# Setup
...
...
@@ -728,6 +734,14 @@ def main(args: argparse.Namespace):
raise
ValueError
(
"Invalid metadata format. Please use KEY=VALUE format."
)
# Traffic
result_json
[
"request_rate"
]
=
(
args
.
request_rate
if
args
.
request_rate
<
float
(
"inf"
)
else
"inf"
)
result_json
[
"burstiness"
]
=
args
.
burstiness
result_json
[
"max_concurrency"
]
=
args
.
max_concurrency
# Merge with benchmark result
result_json
=
{
**
result_json
,
**
benchmark_result
}
if
not
args
.
save_detailed
:
# Remove fields with too many data points
...
...
@@ -738,15 +752,6 @@ def main(args: argparse.Namespace):
if
field
in
result_json
:
del
result_json
[
field
]
# Traffic
result_json
[
"request_rate"
]
=
(
args
.
request_rate
if
args
.
request_rate
<
float
(
"inf"
)
else
"inf"
)
result_json
[
"burstiness"
]
=
args
.
burstiness
result_json
[
"max_concurrency"
]
=
args
.
max_concurrency
# Merge with benchmark result
result_json
=
{
**
result_json
,
**
benchmark_result
}
# Save to file
base_model_id
=
model_id
.
split
(
"/"
)[
-
1
]
max_concurrency_str
=
(
f
"-concurrency
{
args
.
max_concurrency
}
"
...
...
@@ -756,7 +761,12 @@ def main(args: argparse.Namespace):
file_name
=
args
.
result_filename
if
args
.
result_dir
:
file_name
=
os
.
path
.
join
(
args
.
result_dir
,
file_name
)
with
open
(
file_name
,
"w"
,
encoding
=
'utf-8'
)
as
outfile
:
with
open
(
file_name
,
mode
=
"a+"
if
args
.
append_result
else
"w"
,
encoding
=
'utf-8'
)
as
outfile
:
# Append a newline.
if
args
.
append_result
and
outfile
.
tell
()
!=
0
:
outfile
.
write
(
"
\n
"
)
json
.
dump
(
result_json
,
outfile
)
save_to_pytorch_benchmark_format
(
args
,
result_json
,
file_name
)
...
...
@@ -888,6 +898,11 @@ if __name__ == "__main__":
help
=
"When saving the results, whether to include per request "
"information such as response, error, ttfs, tpots, etc."
,
)
parser
.
add_argument
(
"--append-result"
,
action
=
"store_true"
,
help
=
"Append the benchmark result to the existing json file."
,
)
parser
.
add_argument
(
"--metadata"
,
metavar
=
"KEY=VALUE"
,
...
...
benchmarks/benchmark_serving_structured_output.py
View file @
dcb5624a
...
...
@@ -51,7 +51,7 @@ try:
except
ImportError
:
from
argparse
import
ArgumentParser
as
FlexibleArgumentParser
from
vllm.v1.structured_output.
utils
import
(
from
vllm.v1.structured_output.
backend_xgrammar
import
(
has_xgrammar_unsupported_json_features
)
MILLISECONDS_TO_SECONDS_CONVERSION
=
1000
...
...
@@ -150,17 +150,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
elif
args
.
dataset
==
"grammar"
:
schema
=
"""
?start:
select_statement
root ::=
select_statement
?
select_statement
:
"SELECT " column
_list " FROM " table_name
select_statement
::=
"SELECT " column
" from " table " where " condition
?
column
_list: column_name ("," column_name)*
column
::= "col_1 " | "col_2 "
?table_name: identifier
table ::= "table_1 " | "table_2 "
?column_name: identifi
er
condition ::= column "= " numb
er
?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
number ::= "1 " | "2 "
"""
prompt
=
"Generate an SQL query to show the 'username'
\
and 'email' from the 'users' table."
...
...
benchmarks/benchmark_throughput.py
View file @
dcb5624a
...
...
@@ -571,6 +571,13 @@ def validate_args(args):
raise
ValueError
(
"Tokenizer must be the same as the model for MII backend."
)
# --data-parallel is not supported currently.
# https://github.com/vllm-project/vllm/issues/16222
if
args
.
data_parallel_size
>
1
:
raise
ValueError
(
"Data parallel is not supported in offline benchmark,
\
please use benchmark serving instead"
)
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the throughput."
)
...
...
benchmarks/kernels/benchmark_bitblas.py
0 → 100644
View file @
dcb5624a
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from
vllm.model_executor.layers.quantization.utils.bitblas_utils
import
(
MINIMUM_BITBLAS_VERSION
)
try
:
import
bitblas
if
bitblas
.
__version__
<
MINIMUM_BITBLAS_VERSION
:
raise
ImportError
(
"bitblas version is wrong. Please "
f
"install bitblas>=
{
MINIMUM_BITBLAS_VERSION
}
"
)
except
ImportError
as
e
:
bitblas_import_exception
=
e
raise
ValueError
(
"Trying to use the bitblas backend, but could not import"
f
"with the following error:
{
bitblas_import_exception
}
. "
"Please install bitblas through the following command: "
f
"`pip install bitblas>=
{
MINIMUM_BITBLAS_VERSION
}
`"
)
from
bitblas_import_exception
from
bitblas
import
Matmul
,
MatmulConfig
,
auto_detect_nvidia_target
from
vllm.utils
import
FlexibleArgumentParser
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark BitBLAS int4 on a specific target."
)
# Add arguments to the parser
parser
.
add_argument
(
"--target"
,
type
=
str
,
default
=
auto_detect_nvidia_target
(),
help
=
"Specify the target device for benchmarking."
,
)
parser
.
add_argument
(
"--group_size"
,
type
=
int
,
default
=
None
,
help
=
"Group size for grouped quantization."
)
parser
.
add_argument
(
"--A_dtype"
,
type
=
str
,
default
=
"float16"
,
choices
=
[
"float16"
,
"float32"
,
"float64"
,
"int32"
,
"int8"
],
help
=
"Data type of activation A."
,
)
parser
.
add_argument
(
"--W_dtype"
,
type
=
str
,
default
=
"int4"
,
choices
=
[
"float16"
,
"float32"
,
"float64"
,
"int32"
,
"int8"
,
"int4"
,
"int2"
,
"int1"
,
"nf4"
,
"fp4_e2m1"
,
],
help
=
"Data type of weight W."
,
)
parser
.
add_argument
(
"--accum_dtype"
,
type
=
str
,
default
=
"float16"
,
choices
=
[
"float16"
,
"int32"
],
help
=
"Data type for accumulation."
,
)
parser
.
add_argument
(
"--out_dtype"
,
type
=
str
,
default
=
"float16"
,
choices
=
[
"float16"
,
"float32"
,
"int32"
,
"int8"
],
help
=
"Data type for output."
,
)
parser
.
add_argument
(
"--layout"
,
type
=
str
,
default
=
"nt"
,
choices
=
[
"nt"
,
"nn"
],
help
=
"Matrix layout, 'nt' for non-transpose A and transpose W."
,
)
parser
.
add_argument
(
"--with_bias"
,
action
=
"store_true"
,
help
=
"Include bias in the benchmark."
)
parser
.
add_argument
(
"--with_scaling"
,
action
=
"store_true"
,
help
=
"Include scaling factor in the quantization."
,
)
parser
.
add_argument
(
"--with_zeros"
,
action
=
"store_true"
,
help
=
"Include zeros in the quantization."
)
parser
.
add_argument
(
"--zeros_mode"
,
type
=
str
,
default
=
None
,
choices
=
[
"original"
,
"rescale"
,
"quantized"
],
help
=
"Specify the mode for calculating zeros."
,
)
# Parse the arguments
args
=
parser
.
parse_args
()
# Assign arguments to variables
target
=
args
.
target
A_dtype
=
args
.
A_dtype
W_dtype
=
args
.
W_dtype
accum_dtype
=
args
.
accum_dtype
out_dtype
=
args
.
out_dtype
layout
=
args
.
layout
with_bias
=
args
.
with_bias
group_size
=
args
.
group_size
with_scaling
=
args
.
with_scaling
with_zeros
=
args
.
with_zeros
zeros_mode
=
args
.
zeros_mode
# Define a list of shared arguments that repeat in every config
shared_args
=
[
A_dtype
,
W_dtype
,
out_dtype
,
accum_dtype
,
layout
,
with_bias
,
group_size
,
with_scaling
,
with_zeros
,
zeros_mode
,
]
# Define just the (M, K, N) shapes in a more compact list
shapes
=
[
# square test
(
1
,
16384
,
16384
),
# BLOOM-176B
(
1
,
43008
,
14336
),
(
1
,
14336
,
14336
),
(
1
,
57344
,
14336
),
(
1
,
14336
,
57344
),
# OPT-65B
(
1
,
9216
,
9216
),
(
1
,
36864
,
9216
),
(
1
,
9216
,
36864
),
(
1
,
22016
,
8192
),
# LLAMA-70B/65B
(
1
,
8192
,
22016
),
(
1
,
8192
,
8192
),
(
1
,
28672
,
8192
),
(
1
,
8192
,
28672
),
# square test
(
16384
,
16384
,
16384
),
# BLOOM-176B
(
8192
,
43008
,
14336
),
(
8192
,
14336
,
14336
),
(
8192
,
57344
,
14336
),
(
8192
,
14336
,
57344
),
# OPT-65B
(
8192
,
9216
,
9216
),
(
8192
,
36864
,
9216
),
(
8192
,
9216
,
36864
),
(
8192
,
22016
,
8192
),
# LLAMA-70B/65B
(
8192
,
8192
,
22016
),
(
8192
,
8192
,
8192
),
(
8192
,
28672
,
8192
),
(
8192
,
8192
,
28672
),
]
# Build test shapes with all the shared arguments
test_shapes
=
[(
MatmulConfig
,
Matmul
,
(
*
shape
,
*
shared_args
))
for
shape
in
shapes
]
benchmark_sets
=
[]
benchmark_sets
.
extend
(
test_shapes
)
benchmark_results
=
{}
for
config_class
,
operator
,
input_args
in
benchmark_sets
:
config
=
config_class
(
*
input_args
)
matmul
=
operator
(
config
,
target
=
target
,
enable_tuning
=
True
)
kernel_latency
=
matmul
.
profile_latency
()
print
(
"Time cost is: {:.3f} ms"
.
format
(
kernel_latency
))
profile_config
=
{
f
"
{
operator
.
__name__
}
-
{
'-'
.
join
([
str
(
i
)
for
i
in
input_args
])
}
"
:
{
"BitBLAS_top20_latency"
:
kernel_latency
,
}
}
benchmark_results
.
update
(
profile_config
)
# Define headers for the table
headers
=
[
"PrimFunc"
,
"Input Arguments"
,
"BitBLAS Top20 Latency"
,
]
# Calculate column widths for pretty printing
col_widths
=
[
0
,
0
,
0
]
for
config_key
,
values
in
benchmark_results
.
items
():
args_split
=
config_key
.
split
(
"-"
)
func_name
=
args_split
[
0
]
input_args_str
=
"-"
.
join
(
args_split
[
1
:])
col_widths
[
0
]
=
max
(
col_widths
[
0
],
len
(
func_name
)
+
2
,
len
(
headers
[
0
])
+
2
)
col_widths
[
1
]
=
max
(
col_widths
[
1
],
len
(
input_args_str
)
+
2
,
len
(
headers
[
1
])
+
2
)
col_widths
[
2
]
=
max
(
col_widths
[
2
],
len
(
f
"
{
values
[
'BitBLAS_top20_latency'
]:.
3
f
}
ms"
)
+
2
,
len
(
headers
[
2
])
+
2
)
# break only if you want to measure widths from a single example;
# otherwise, let it loop over all items.
# Print header
for
i
,
header
in
enumerate
(
headers
):
headers
[
i
]
=
header
.
ljust
(
col_widths
[
i
])
print
(
""
.
join
(
headers
))
print
(
"-"
*
sum
(
col_widths
))
# Print rows
for
config_key
,
values
in
benchmark_results
.
items
():
args_split
=
config_key
.
split
(
"-"
)
func_name
=
args_split
[
0
]
input_args_str
=
"-"
.
join
(
args_split
[
1
:])
row
=
[
func_name
,
input_args_str
,
f
"
{
values
[
'BitBLAS_top20_latency'
]:.
3
f
}
ms"
,
]
row_str
=
""
.
join
(
[
str
(
cell
).
ljust
(
col_widths
[
idx
])
for
idx
,
cell
in
enumerate
(
row
)])
print
(
row_str
)
benchmarks/kernels/benchmark_lora.py
View file @
dcb5624a
...
...
@@ -17,8 +17,14 @@ from torch.utils.benchmark import Measurement as TMeasurement
from
utils
import
ArgPool
,
Bench
,
CudaGraphBenchParams
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm.lora.ops.triton_ops
import
LoRAKernelMeta
,
lora_expand
,
lora_shrink
from
vllm.lora.ops.triton_ops.utils
import
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
from
vllm.triton_utils
import
HAS_TRITON
if
HAS_TRITON
:
from
vllm.lora.ops.triton_ops
import
(
LoRAKernelMeta
,
lora_expand
,
lora_shrink
)
from
vllm.lora.ops.triton_ops.utils
import
(
_LORA_A_PTR_DICT
,
_LORA_B_PTR_DICT
)
from
vllm.utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
...
...
benchmarks/kernels/benchmark_moe.py
View file @
dcb5624a
...
...
@@ -576,11 +576,10 @@ def get_weight_block_size_safety(config, default_value=None):
def
main
(
args
:
argparse
.
Namespace
):
print
(
args
)
block_quant_shape
=
None
tp_size
=
args
.
tp_size
config
=
AutoConfig
.
from_pretrained
(
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
...
...
@@ -599,21 +598,16 @@ def main(args: argparse.Namespace):
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
elif
config
.
architectures
[
0
]
in
[
"Qwen2MoeForCausalLM"
,
"Qwen3MoeForCausalLM"
]:
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
block_quant_shape
=
get_weight_block_size_safety
(
config
)
elif
config
.
architectures
[
0
]
==
"Qwen2MoeForCausalLM"
:
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
else
:
if
not
hasattr
(
config
,
"hidden_size"
):
# Support for llama4
config
=
config
.
text_config
# Support for llama4
config
=
config
.
get_text_config
()
# Default: Mixtral.
E
=
config
.
num_local_experts
topk
=
config
.
num_experts_per_tok
...
...
@@ -624,6 +618,7 @@ def main(args: argparse.Namespace):
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
block_quant_shape
=
get_weight_block_size_safety
(
config
)
if
args
.
batch_size
is
None
:
batch_sizes
=
[
...
...
cmake/external_projects/vllm_flash_attn.cmake
View file @
dcb5624a
...
...
@@ -38,7 +38,7 @@ else()
FetchContent_Declare
(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG
dc9d410b3e2d6534a4c70724c2515f4def670a22
GIT_TAG
8798f27777fb57f447070301bf33a9f9c607f491
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR
${
CMAKE_BINARY_DIR
}
/vllm-flash-attn
...
...
csrc/attention/merge_attn_states.cu
View file @
dcb5624a
...
...
@@ -107,13 +107,14 @@ __global__ void merge_attn_states_kernel(
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size); \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size); \
}
/*@brief Merges the attention states from prefix and suffix
...
...
@@ -122,10 +123,10 @@ __global__ void merge_attn_states_kernel(
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [h,
d
] The log-sum-exp values for the prefix attention
* @param prefix_lse [h,
n
] The log-sum-exp values for the prefix attention
* states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,
d
] The log-sum-exp values for the suffix attention
* @param suffix_lse [h,
n
] The log-sum-exp values for the suffix attention
* states.
*/
template
<
typename
scalar_t
>
...
...
@@ -146,13 +147,17 @@ void merge_attn_states_launcher(torch::Tensor& output,
if
(
output_lse
.
has_value
())
{
output_lse_ptr
=
output_lse
.
value
().
data_ptr
<
float
>
();
}
// process one pack elements per thread. float -> 4, half/bf16 -> 8
// Process one pack elements per thread. for float, the
// pack_size is 4 for half/bf16, the pack_size is 8.
const
uint
threads_per_head
=
head_size
/
pack_size
;
const
uint
total_threads
=
num_tokens
*
num_heads
*
threads_per_head
;
dim3
block
(
NUM_THREADS
);
dim3
grid
((
total_threads
+
NUM_THREADS
-
1
)
/
NUM_THREADS
);
const
c10
::
cuda
::
OptionalCUDAGuard
device_guard
(
prefix_output
.
device
());
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
LAUNCH_MERGE_ATTN_STATES
(
scalar_t
,
NUM_THREADS
);
}
...
...
csrc/attention/mla/cutlass_mla_entry.cu
0 → 100644
View file @
dcb5624a
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
void
cutlass_mla_decode_sm100a
(
torch
::
Tensor
const
&
out
,
torch
::
Tensor
const
&
q_nope
,
torch
::
Tensor
const
&
q_pe
,
torch
::
Tensor
const
&
kv_c_and_k_pe_cache
,
torch
::
Tensor
const
&
seq_lens
,
torch
::
Tensor
const
&
page_table
,
double
scale
);
#endif
void
cutlass_mla_decode
(
torch
::
Tensor
const
&
out
,
torch
::
Tensor
const
&
q_nope
,
torch
::
Tensor
const
&
q_pe
,
torch
::
Tensor
const
&
kv_c_and_k_pe_cache
,
torch
::
Tensor
const
&
seq_lens
,
torch
::
Tensor
const
&
page_table
,
double
scale
)
{
#if defined ENABLE_CUTLASS_MLA && ENABLE_CUTLASS_MLA
return
cutlass_mla_decode_sm100a
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
);
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass MLA"
);
}
csrc/attention/mla/cutlass_mla_kernels.cu
0 → 100644
View file @
dcb5624a
/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cute/tensor.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/kernel_hardware_info.h"
#include "cutlass_extensions/common.hpp"
#include "device/sm100_mla.hpp"
#include "kernel/sm100_mla_tile_scheduler.hpp"
using
namespace
cute
;
using
namespace
cutlass
::
fmha
::
kernel
;
template
<
typename
T
,
bool
PersistenceOption
=
true
>
struct
MlaSm100
{
using
Element
=
T
;
using
ElementAcc
=
float
;
using
ElementOut
=
T
;
using
TileShape
=
Shape
<
_128
,
_128
,
Shape
<
_512
,
_64
>>
;
using
TileShapeH
=
cute
::
tuple_element_t
<
0
,
TileShape
>
;
using
TileShapeD
=
cute
::
tuple_element_t
<
2
,
TileShape
>
;
// H K (D_latent D_rope) B
using
ProblemShape
=
cute
::
tuple
<
TileShapeH
,
int
,
TileShapeD
,
int
>
;
using
StrideQ
=
cute
::
tuple
<
int64_t
,
_1
,
int64_t
>
;
// H D B
using
StrideK
=
cute
::
tuple
<
int64_t
,
_1
,
int64_t
>
;
// K D B
using
StrideO
=
StrideK
;
// H D B
using
StrideLSE
=
cute
::
tuple
<
_1
,
int
>
;
// H B
using
TileScheduler
=
std
::
conditional_t
<
PersistenceOption
,
Sm100MlaPersistentTileScheduler
,
Sm100MlaIndividualTileScheduler
>
;
using
FmhaKernel
=
cutlass
::
fmha
::
kernel
::
Sm100FmhaMlaKernelTmaWarpspecialized
<
TileShape
,
Element
,
ElementAcc
,
ElementOut
,
ElementAcc
,
TileScheduler
,
/*kIsCpAsync=*/
true
>
;
using
Fmha
=
cutlass
::
fmha
::
device
::
MLA
<
FmhaKernel
>
;
};
template
<
typename
T
>
typename
T
::
Fmha
::
Arguments
args_from_options
(
at
::
Tensor
const
&
out
,
at
::
Tensor
const
&
q_nope
,
at
::
Tensor
const
&
q_pe
,
at
::
Tensor
const
&
kv_c_and_k_pe_cache
,
at
::
Tensor
const
&
seq_lens
,
at
::
Tensor
const
&
page_table
,
double
scale
)
{
cutlass
::
KernelHardwareInfo
hw_info
;
hw_info
.
device_id
=
q_nope
.
device
().
index
();
hw_info
.
sm_count
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
int
batches
=
q_nope
.
sizes
()[
0
];
int
page_count_per_seq
=
page_table
.
sizes
()[
1
];
int
page_count_total
=
kv_c_and_k_pe_cache
.
sizes
()[
0
];
int
page_size
=
kv_c_and_k_pe_cache
.
sizes
()[
1
];
int
max_seq_len
=
page_size
*
page_count_per_seq
;
using
TileShapeH
=
typename
T
::
TileShapeH
;
using
TileShapeD
=
typename
T
::
TileShapeD
;
auto
problem_shape
=
cute
::
make_tuple
(
TileShapeH
{},
max_seq_len
,
TileShapeD
{},
batches
);
auto
[
H
,
K
,
D
,
B
]
=
problem_shape
;
auto
[
D_latent
,
D_rope
]
=
D
;
using
StrideQ
=
typename
T
::
StrideQ
;
using
StrideK
=
typename
T
::
StrideK
;
using
StrideO
=
typename
T
::
StrideO
;
using
StrideLSE
=
typename
T
::
StrideLSE
;
StrideQ
stride_Q_latent
=
cute
::
make_tuple
(
static_cast
<
int64_t
>
(
D_latent
),
_1
{},
static_cast
<
int64_t
>
(
H
*
D_latent
));
StrideQ
stride_Q_rope
=
cute
::
make_tuple
(
static_cast
<
int64_t
>
(
D_rope
),
_1
{},
static_cast
<
int64_t
>
(
H
*
D_rope
));
StrideK
stride_C
=
cute
::
make_tuple
(
static_cast
<
int64_t
>
(
D_latent
+
D_rope
),
_1
{},
static_cast
<
int64_t
>
(
page_size
*
(
D_latent
+
D_rope
)));
StrideLSE
stride_PT
=
cute
::
make_stride
(
_1
{},
page_count_per_seq
);
StrideLSE
stride_LSE
=
cute
::
make_tuple
(
_1
{},
static_cast
<
int
>
(
H
));
StrideO
stride_O
=
cute
::
make_tuple
(
static_cast
<
int64_t
>
(
D_latent
),
_1
{},
static_cast
<
int64_t
>
(
H
*
D_latent
));
using
Element
=
typename
T
::
Element
;
using
ElementOut
=
typename
T
::
ElementOut
;
using
ElementAcc
=
typename
T
::
ElementAcc
;
auto
Q_latent_ptr
=
static_cast
<
Element
*>
(
q_nope
.
data_ptr
());
auto
Q_rope_ptr
=
static_cast
<
Element
*>
(
q_pe
.
data_ptr
());
auto
C_ptr
=
static_cast
<
Element
*>
(
kv_c_and_k_pe_cache
.
data_ptr
());
auto
scale_f
=
static_cast
<
float
>
(
scale
);
typename
T
::
Fmha
::
Arguments
arguments
{
problem_shape
,
{
scale_f
,
Q_latent_ptr
,
stride_Q_latent
,
Q_rope_ptr
,
stride_Q_rope
,
C_ptr
,
stride_C
,
C_ptr
+
D_latent
,
stride_C
,
static_cast
<
int
*>
(
seq_lens
.
data_ptr
()),
static_cast
<
int
*>
(
page_table
.
data_ptr
()),
stride_PT
,
page_count_total
,
page_size
},
{
static_cast
<
ElementOut
*>
(
out
.
data_ptr
()),
stride_O
,
static_cast
<
ElementAcc
*>
(
nullptr
),
stride_LSE
},
hw_info
,
-
1
,
// split_kv
nullptr
,
// is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
// split_kv automatically based on batch size and sequence length to balance
// workload across available SMs. Consider using var_split_kv for manual
// control if needed.
T
::
Fmha
::
set_split_kv
(
arguments
);
return
arguments
;
}
template
<
typename
Element
>
void
runMla
(
at
::
Tensor
const
&
out
,
at
::
Tensor
const
&
q_nope
,
at
::
Tensor
const
&
q_pe
,
at
::
Tensor
const
&
kv_c_and_k_pe_cache
,
at
::
Tensor
const
&
seq_lens
,
at
::
Tensor
const
&
page_table
,
float
scale
,
cudaStream_t
stream
)
{
using
MlaSm100Type
=
MlaSm100
<
Element
>
;
typename
MlaSm100Type
::
Fmha
fmha
;
auto
arguments
=
args_from_options
<
MlaSm100Type
>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
);
size_t
workspace_size
=
MlaSm100Type
::
Fmha
::
get_workspace_size
(
arguments
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
q_nope
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
CUTLASS_CHECK
(
fmha
.
can_implement
(
arguments
));
CUTLASS_CHECK
(
fmha
.
initialize
(
arguments
,
workspace
.
data_ptr
(),
stream
));
CUTLASS_CHECK
(
fmha
.
run
(
arguments
,
workspace
.
data_ptr
(),
stream
));
}
void
cutlass_mla_decode_sm100a
(
torch
::
Tensor
const
&
out
,
torch
::
Tensor
const
&
q_nope
,
torch
::
Tensor
const
&
q_pe
,
torch
::
Tensor
const
&
kv_c_and_k_pe_cache
,
torch
::
Tensor
const
&
seq_lens
,
torch
::
Tensor
const
&
page_table
,
double
scale
)
{
TORCH_CHECK
(
q_nope
.
device
().
is_cuda
(),
"q_nope must be on CUDA"
);
TORCH_CHECK
(
q_nope
.
dim
()
==
3
,
"q_nope must be a 3D tensor"
);
TORCH_CHECK
(
q_pe
.
dim
()
==
3
,
"q_pe must be a 3D tensor"
);
TORCH_CHECK
(
kv_c_and_k_pe_cache
.
dim
()
==
3
,
"kv_c_and_k_pe_cache must be a 3D tensor"
);
TORCH_CHECK
(
seq_lens
.
dim
()
==
1
,
"seq_lens must be a 1D tensor"
);
TORCH_CHECK
(
page_table
.
dim
()
==
2
,
"page_table must be a 2D tensor"
);
TORCH_CHECK
(
out
.
dim
()
==
3
,
"out must be a 3D tensor"
);
auto
B_q_nope
=
q_nope
.
size
(
0
);
auto
H_q_nope
=
q_nope
.
size
(
1
);
auto
D_q_nope
=
q_nope
.
size
(
2
);
auto
B_q_pe
=
q_pe
.
size
(
0
);
auto
H_q_pe
=
q_pe
.
size
(
1
);
auto
D_q_pe
=
q_pe
.
size
(
2
);
auto
B_pt
=
page_table
.
size
(
0
);
auto
PAGE_NUM
=
page_table
.
size
(
1
);
auto
PAGE_SIZE
=
kv_c_and_k_pe_cache
.
size
(
1
);
auto
D_ckv
=
kv_c_and_k_pe_cache
.
size
(
2
);
auto
B_o
=
out
.
size
(
0
);
auto
H_o
=
out
.
size
(
1
);
auto
D_o
=
out
.
size
(
2
);
TORCH_CHECK
(
D_q_nope
==
512
,
"D_q_nope must be equal to 512"
);
TORCH_CHECK
(
D_q_pe
==
64
,
"D_q_pe must be equal to 64"
);
TORCH_CHECK
(
D_ckv
==
576
,
"D_ckv must be equal to 576"
);
TORCH_CHECK
(
H_q_nope
==
H_q_pe
&&
H_q_nope
==
H_o
&&
H_o
==
128
,
"H_q_nope, H_q_pe, and H_o must be equal to 128"
);
TORCH_CHECK
(
PAGE_SIZE
>
0
&&
(
PAGE_SIZE
&
(
PAGE_SIZE
-
1
))
==
0
,
"PAGE_SIZE must be a power of 2"
);
TORCH_CHECK
(
B_q_nope
==
B_q_pe
&&
B_q_nope
==
B_pt
&&
B_q_nope
==
B_o
,
"Batch dims must be same for page_table, q_nope and q_pe, and out"
);
TORCH_CHECK
(
PAGE_NUM
%
(
128
/
PAGE_SIZE
)
==
0
,
"PAGE_NUM must be divisible by 128 / PAGE_SIZE"
);
TORCH_CHECK
(
D_o
==
512
,
"D_o must be equal to 512"
);
TORCH_CHECK
(
q_nope
.
dtype
()
==
at
::
ScalarType
::
Half
||
q_nope
.
dtype
()
==
at
::
ScalarType
::
BFloat16
||
q_nope
.
dtype
()
==
at
::
ScalarType
::
Float8_e4m3fn
,
"q_nope must be a half, bfloat16, or float8_e4m3fn tensor"
);
TORCH_CHECK
(
kv_c_and_k_pe_cache
.
dtype
()
==
q_nope
.
dtype
()
&&
q_nope
.
dtype
()
==
q_pe
.
dtype
(),
"kv_c_and_k_pe_cache, q_nope, and q_pe must be the same type"
);
TORCH_CHECK
(
seq_lens
.
dtype
()
==
torch
::
kInt32
,
"seq_lens must be a 32-bit integer tensor"
);
TORCH_CHECK
(
page_table
.
dtype
()
==
torch
::
kInt32
,
"page_table must be a 32-bit integer tensor"
);
auto
in_dtype
=
q_nope
.
dtype
();
at
::
cuda
::
CUDAGuard
device_guard
{(
char
)
q_nope
.
get_device
()};
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
q_nope
.
get_device
());
if
(
in_dtype
==
at
::
ScalarType
::
Half
)
{
runMla
<
cutlass
::
half_t
>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
,
stream
);
}
else
if
(
in_dtype
==
at
::
ScalarType
::
BFloat16
)
{
runMla
<
cutlass
::
bfloat16_t
>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
,
stream
);
}
else
if
(
in_dtype
==
at
::
ScalarType
::
Float8_e4m3fn
)
{
runMla
<
cutlass
::
float_e4m3_t
>
(
out
,
q_nope
,
q_pe
,
kv_c_and_k_pe_cache
,
seq_lens
,
page_table
,
scale
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported input data type of MLA"
);
}
}
csrc/cache_kernels.cu
View file @
dcb5624a
...
...
@@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
cache_t
*
__restrict__
value_cache
,
// [num_blocks, block_size, num_heads,
// head_size]
const
int64_t
*
__restrict__
slot_mapping
,
// [num_tokens]
const
int
block_stride
,
const
int
key_stride
,
const
int
value_stride
,
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
const
int64_t
block_stride
,
const
int64_t
page_stride
,
const
int64_t
head_stride
,
const
int64_t
key_stride
,
const
int64_t
value_stride
,
const
int
num_heads
,
const
int
head_size
,
const
int
block_size
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
const
int64_t
token_idx
=
blockIdx
.
x
;
const
int64_t
slot_idx
=
slot_mapping
[
token_idx
];
// NOTE: slot_idx can be -1 if the token is padded
...
...
@@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
const
int
head_idx
=
i
/
head_size
;
const
int
head_offset
=
i
%
head_size
;
const
int64_t
tgt_key_value_idx
=
block_idx
*
block_stride
+
block_offset
*
num_heads
*
head_siz
e
+
head_idx
*
head_s
iz
e
+
head_offset
;
block_offset
*
page_strid
e
+
head_idx
*
head_s
trid
e
+
head_offset
;
scalar_t
tgt_key
=
key
[
src_key_idx
];
scalar_t
tgt_value
=
value
[
src_value_idx
];
if
constexpr
(
kv_dt
==
Fp8KVCacheDataType
::
kAuto
)
{
...
...
@@ -524,16 +525,16 @@ void reshape_and_cache(
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride,
key
_stride, \
value_stride, num_heads, head_size,
block_size,
\
reinterpret_cast<const float*>(k_scale.data_ptr()),
\
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)
\
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>
\
<<<grid, block, 0, stream>>>(
\
reinterpret_cast<KV_T*>(key.data_ptr()),
\
reinterpret_cast<KV_T*>(value.data_ptr()),
\
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),
\
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),
\
slot_mapping.data_ptr<int64_t>(), block_stride,
page
_stride,
\
head_stride, key_stride,
value_stride, num_heads, head_size, \
block_size,
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void
reshape_and_cache_flash
(
...
...
@@ -560,9 +561,11 @@ void reshape_and_cache_flash(
int
head_size
=
key
.
size
(
2
);
int
block_size
=
key_cache
.
size
(
1
);
int
key_stride
=
key
.
stride
(
0
);
int
value_stride
=
value
.
stride
(
0
);
int
block_stride
=
key_cache
.
stride
(
0
);
int64_t
key_stride
=
key
.
stride
(
0
);
int64_t
value_stride
=
value
.
stride
(
0
);
int64_t
block_stride
=
key_cache
.
stride
(
0
);
int64_t
page_stride
=
key_cache
.
stride
(
1
);
int64_t
head_stride
=
key_cache
.
stride
(
2
);
TORCH_CHECK
(
key_cache
.
stride
(
0
)
==
value_cache
.
stride
(
0
));
dim3
grid
(
num_tokens
);
...
...
csrc/moe/marlin_moe_wna16/generate_kernels.py
0 → 100644
View file @
dcb5624a
# SPDX-License-Identifier: Apache-2.0
import
glob
import
itertools
import
os
import
subprocess
import
jinja2
FILE_HEAD
=
"""
// auto generated by generate.py
// clang-format off
#include "kernel.h"
#include "marlin_template.h"
namespace MARLIN_NAMESPACE_NAME {
"""
.
strip
()
TEMPLATE
=
(
"template __global__ void Marlin<"
"{{scalar_t}}, "
"{{w_type_id}}, "
"{{threads}}, "
"{{thread_m_blocks}}, "
"{{thread_n_blocks}}, "
"{{thread_k_blocks}}, "
"{{'true' if m_block_size_8 else 'false'}}, "
"{{stages}}, "
"{{'true' if has_act_order else 'false'}}, "
"{{'true' if has_zp else 'false'}}, "
"{{group_blocks}}, "
"{{'true' if is_zp_float else 'false'}}>"
"( MARLIN_KERNEL_PARAMS );"
)
# int8 with zero point case (vllm::kU8) is also supported,
# we don't add it to reduce wheel size.
SCALAR_TYPES
=
[
"vllm::kU4"
,
"vllm::kU4B8"
,
"vllm::kU8B128"
]
THREAD_CONFIGS
=
[(
128
,
128
,
256
),
(
64
,
256
,
256
),
(
64
,
128
,
128
)]
THREAD_M_BLOCKS
=
[
0.5
,
1
,
2
,
3
,
4
]
# group_blocks:
# = 0 : act order case
# = -1 : channelwise quantization
# > 0 : group_size=16*group_blocks
GROUP_BLOCKS
=
[
0
,
-
1
,
2
,
4
,
8
]
DTYPES
=
[
"fp16"
,
"bf16"
]
def
remove_old_kernels
():
for
filename
in
glob
.
glob
(
os
.
path
.
dirname
(
__file__
)
+
"/kernel_*.cu"
):
subprocess
.
call
([
"rm"
,
"-f"
,
filename
])
def
generate_new_kernels
():
for
scalar_type
,
dtype
in
itertools
.
product
(
SCALAR_TYPES
,
DTYPES
):
has_zp
=
"B"
not
in
scalar_type
all_template_str_list
=
[]
for
group_blocks
,
m_blocks
,
thread_configs
in
itertools
.
product
(
GROUP_BLOCKS
,
THREAD_M_BLOCKS
,
THREAD_CONFIGS
):
has_act_order
=
group_blocks
==
0
if
has_zp
and
has_act_order
:
continue
if
thread_configs
[
2
]
==
256
:
if
m_blocks
<=
1
and
thread_configs
[
0
]
!=
128
:
continue
if
m_blocks
>
1
and
thread_configs
[
0
]
!=
64
:
continue
k_blocks
=
thread_configs
[
0
]
//
16
n_blocks
=
thread_configs
[
1
]
//
16
threads
=
thread_configs
[
2
]
c_dtype
=
"half"
if
dtype
==
"fp16"
else
"nv_bfloat16"
template_str
=
jinja2
.
Template
(
TEMPLATE
).
render
(
scalar_t
=
c_dtype
,
w_type_id
=
scalar_type
+
".id()"
,
threads
=
threads
,
thread_m_blocks
=
max
(
m_blocks
,
1
),
thread_n_blocks
=
n_blocks
,
thread_k_blocks
=
k_blocks
,
m_block_size_8
=
m_blocks
==
0.5
,
stages
=
"pipe_stages"
,
has_act_order
=
has_act_order
,
has_zp
=
has_zp
,
group_blocks
=
group_blocks
,
is_zp_float
=
False
,
)
all_template_str_list
.
append
(
template_str
)
file_content
=
FILE_HEAD
+
"
\n\n
"
file_content
+=
"
\n\n
"
.
join
(
all_template_str_list
)
+
"
\n\n
}
\n
"
filename
=
f
"kernel_
{
dtype
}
_
{
scalar_type
[
6
:].
lower
()
}
.cu"
with
open
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
filename
),
"w"
)
as
f
:
f
.
write
(
file_content
)
if
__name__
==
"__main__"
:
remove_old_kernels
()
generate_new_kernels
()
csrc/moe/marlin_moe_wna16/kernel.h
0 → 100644
View file @
dcb5624a
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "quantization/gptq_marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
const int4 *__restrict__ A, const int4 *__restrict__ B, \
int4 *__restrict__ C, int4 *__restrict__ C_tmp, \
const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \
const int *__restrict__ g_idx, \
const int32_t *__restrict__ sorted_token_ids_ptr, \
const int32_t *__restrict__ expert_ids_ptr, \
const int32_t *__restrict__ num_tokens_past_padded_ptr, \
const float *__restrict__ topk_weights_ptr, int top_k, \
bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \
int prob_n, int prob_k, int *locks, bool use_atomic_add, \
bool use_fp32_reduce
namespace
MARLIN_NAMESPACE_NAME
{
template
<
typename
scalar_t
,
// compute dtype, half or nv_float16
const
vllm
::
ScalarTypeId
w_type_id
,
// weight ScalarType id
const
int
threads
,
// number of threads in a threadblock
const
int
thread_m_blocks
,
// number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const
int
thread_n_blocks
,
// same for n dimension (output)
const
int
thread_k_blocks
,
// same for k dimension (reduction)
const
bool
m_block_size_8
,
// whether m_block_size == 8
// only works when thread_m_blocks == 1
const
int
stages
,
// number of stages for the async global->shared
// fetch pipeline
const
bool
has_act_order
,
// whether act_order is enabled
const
bool
has_zp
,
// whether zero-points are enabled
const
int
group_blocks
,
// number of consecutive 16x16 blocks
// with a separate quantization scale
const
bool
is_zp_float
// is zero point of float16 type?
>
__global__
void
Marlin
(
MARLIN_KERNEL_PARAMS
);
}
csrc/moe/marlin_moe_wna16/marlin_template.h
0 → 100644
View file @
dcb5624a
This diff is collapsed.
Click to expand it.
csrc/moe/marlin_moe_wna16/ops.cu
0 → 100644
View file @
dcb5624a
This diff is collapsed.
Click to expand it.
csrc/moe/moe_wna16.cu
View file @
dcb5624a
...
...
@@ -13,7 +13,6 @@
template
<
typename
scalar_t
,
int
bit
,
int
GROUPS
>
__global__
void
moe_wna16_gemm_kernel
(
const
scalar_t
*
__restrict__
input
,
scalar_t
*
__restrict__
output
,
const
uint32_t
*
__restrict__
qweight
,
const
scalar_t
*
__restrict__
scales
,
const
uint32_t
*
__restrict__
qzeros
,
...
...
@@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
if
(
token_index
/
top_k
>=
size_m
)
break
;
num_valid_tokens
=
m
+
1
;
if
(
blockIdx
.
z
==
0
&&
offset_n
<
size_n
)
output
[
token_index
*
size_n
+
offset_n
]
=
Dtype
::
int2num
(
0
);
if
(
expert_id
!=
-
1
)
{
int
k_per_thread
=
DIVIDE
(
BLOCK_SIZE_K
,
BLOCK_SIZE_N
);
...
...
@@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
int64_t
BLOCK_SIZE_M
,
int64_t
BLOCK_SIZE_N
,
int64_t
BLOCK_SIZE_K
,
int64_t
bit
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
input
.
dtype
()).
device
(
input
.
device
());
output
.
zero_
();
const
int
num_experts
=
b_qweight
.
size
(
0
);
const
int
size_m
=
input
.
size
(
0
);
...
...
@@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
const
uint32_t
*
b_qzeros_ptr
;
if
(
b_qzeros
.
has_value
())
b_qzeros_ptr
=
(
const
uint32_t
*
)
b_qzeros
.
value
().
data_ptr
<
uint8_t
>
();
const
float
*
topk_weights_ptr
;
const
float
*
topk_weights_ptr
=
nullptr
;
if
(
topk_weights
.
has_value
())
topk_weights_ptr
=
(
const
float
*
)
topk_weights
.
value
().
data_ptr
();
topk_weights_ptr
=
(
const
float
*
)
topk_weights
.
value
().
data_ptr
<
float
>
();
int
groups_per_block_row
=
BLOCK_SIZE_K
/
group_size
;
TORCH_CHECK
(
bit
==
4
||
bit
==
8
,
"bit must be 4 or 8"
);
...
...
csrc/moe/torch_bindings.cpp
View file @
dcb5624a
...
...
@@ -43,14 +43,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
m
.
impl
(
"moe_wna16_gemm"
,
torch
::
kCUDA
,
&
moe_wna16_gemm
);
m
.
def
(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"
);
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
"Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace,"
"Tensor sorted_token_ids,"
"Tensor! expert_ids, Tensor! num_tokens_past_padded,"
"Tensor! topk_weights, int moe_block_size, int top_k, "
"bool mul_topk_weights, bool is_ep, int b_q_type_id,"
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor"
);
// conditionally compiled so impl registration is in source file
#endif
...
...
Prev
1
2
3
4
5
6
7
…
35
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment