Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
38364a7e
Unverified
Commit
38364a7e
authored
Mar 23, 2026
by
Kyle Sayers
Committed by
GitHub
Mar 23, 2026
Browse files
[Sparse24] [Deprecation] Remove Sparse24 CT integration and kernels (#36799)
Signed-off-by:
Kyle Sayers
<
kylesayrs@gmail.com
>
parent
fafe76b4
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
9 additions
and
2674 deletions
+9
-2674
.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml
...l-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml
+0
-12
CMakeLists.txt
CMakeLists.txt
+0
-26
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
+0
-517
benchmarks/cutlass_benchmarks/utils.py
benchmarks/cutlass_benchmarks/utils.py
+0
-48
csrc/ops.h
csrc/ops.h
+0
-10
csrc/sparse/cutlass/sparse_compressor_c3x.cuh
csrc/sparse/cutlass/sparse_compressor_c3x.cuh
+0
-90
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+0
-307
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
+0
-570
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
+0
-104
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-20
tests/kernels/quantization/test_cutlass_2of4_sparse.py
tests/kernels/quantization/test_cutlass_2of4_sparse.py
+0
-238
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+0
-281
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+0
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-92
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
...ation/compressed_tensors/schemes/compressed_tensors_24.py
+5
-343
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
...n/compressed_tensors/schemes/compressed_tensors_scheme.py
+4
-4
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+0
-10
No files found.
.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml
deleted
100644 → 0
View file @
fafe76b4
# For vllm script, with -t option (tensor parallel size).
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name
:
"
nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.6353
-
name
:
"
exact_match,flexible-extract"
value
:
0.637
limit
:
null
num_fewshot
:
null
CMakeLists.txt
View file @
38364a7e
...
@@ -343,7 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -343,7 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/cutlass_extensions/common.cpp"
"csrc/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/quantization/w8a8/int8/per_token_group_quant.cu"
)
"csrc/quantization/w8a8/int8/per_token_group_quant.cu"
)
...
@@ -619,31 +618,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -619,31 +618,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif
()
endif
()
endif
()
endif
()
#
# 2:4 Sparse Kernels
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper).
cuda_archs_loose_intersection
(
SCALED_MM_ARCHS
"9.0a;"
"
${
CUDA_ARCHS
}
"
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS
)
set
(
SRCS
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
SCALED_MM_ARCHS
}
"
)
list
(
APPEND VLLM_EXT_SRC
"
${
SRCS
}
"
)
list
(
APPEND VLLM_GPU_FLAGS
"-DENABLE_SPARSE_SCALED_MM_C3X=1"
)
message
(
STATUS
"Building sparse_scaled_mm_c3x for archs:
${
SCALED_MM_ARCHS
}
"
)
else
()
if
(
NOT
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS
)
message
(
STATUS
"Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
"if you intend on running FP8 sparse quantized models on Hopper."
)
else
()
message
(
STATUS
"Not building sparse_scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures"
)
endif
()
endif
()
# The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require
# The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require
# CUDA 12.8 or later
# CUDA 12.8 or later
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER_EQUAL 13.0
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER_EQUAL 13.0
)
...
...
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
deleted
100644 → 0
View file @
fafe76b4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
itertools
import
pickle
as
pkl
import
time
from
collections.abc
import
Callable
,
Iterable
import
torch
import
torch.utils.benchmark
as
TBenchmark
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
from
utils
import
make_rand_sparse_tensors
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
DEFAULT_TP_SIZES
=
[
1
]
# bench
def
bench_fn
(
label
:
str
,
sub_label
:
str
,
description
:
str
,
fn
:
Callable
,
*
args
,
**
kwargs
)
->
TMeasurement
:
min_run_time
=
1
globals
=
{
"args"
:
args
,
"kwargs"
:
kwargs
,
"fn"
:
fn
,
}
return
TBenchmark
.
Timer
(
stmt
=
"fn(*args, **kwargs)"
,
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
description
=
description
,
).
blocked_autorange
(
min_run_time
=
min_run_time
)
def
bench_int8
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
)
->
Iterable
[
TMeasurement
]:
assert
dtype
==
torch
.
int8
b_compressed
,
e
,
a
,
b
=
make_rand_sparse_tensors
(
torch
.
int8
,
m
,
n
,
k
)
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
bias
=
torch
.
zeros
((
n
,),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
,
b_compressed
,
e
,
scale_a
,
scale_b
,
torch
.
bfloat16
)
out_ref
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
)
if
not
torch
.
allclose
(
out
,
out_ref
):
print
(
"Incorrect results"
)
print
(
out
)
print
(
out_ref
)
else
:
print
(
"Correct results"
)
timers
=
[]
# pytorch impl - bfloat16
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
bfloat16
),
b
.
to
(
dtype
=
torch
.
bfloat16
),
)
)
# pytorch impl - float16
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp16_fp16_fp16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
float16
),
b
.
to
(
dtype
=
torch
.
float16
),
)
)
# cutlass impl
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
)
)
# cutlass with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
,
)
)
# cutlass sparse impl
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_sparse_mm"
,
ops
.
cutlass_scaled_sparse_mm
,
a
,
b_compressed
,
e
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
)
)
# cutlass sparse with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_sparse_mm_bias"
,
ops
.
cutlass_scaled_sparse_mm
,
a
,
b_compressed
,
e
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
,
)
)
return
timers
def
bench_fp8
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
)
->
Iterable
[
TMeasurement
]:
assert
dtype
==
torch
.
float8_e4m3fn
b_compressed
,
e
,
a
,
b
=
make_rand_sparse_tensors
(
torch
.
float8_e4m3fn
,
m
,
n
,
k
)
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
bias
=
torch
.
zeros
((
n
,),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
,
b_compressed
,
e
,
scale_a
,
scale_b
,
torch
.
bfloat16
)
out_ref
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
)
if
not
torch
.
allclose
(
out
,
out_ref
):
print
(
"Incorrect results"
)
print
(
out
)
print
(
out_ref
)
else
:
print
(
"Correct results"
)
timers
=
[]
# pytorch impl w. bf16
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
)
)
# pytorch impl: bf16 output, without fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_bf16_scaled_mm"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
bfloat16
,
)
)
# pytorch impl: bf16 output, with fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
bfloat16
,
use_fast_accum
=
True
,
)
)
# pytorch impl: fp16 output, without fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_fp16_scaled_mm"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
float16
,
)
)
# pytorch impl: fp16 output, with fp8 fast accum
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
float16
,
use_fast_accum
=
True
,
)
)
# cutlass impl: bf16 output
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_bf16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
)
)
# cutlass impl: bf16 output
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm"
,
ops
.
cutlass_scaled_sparse_mm
,
a
,
b_compressed
,
e
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
)
)
# cutlass impl: fp16 output
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm"
,
ops
.
cutlass_scaled_sparse_mm
,
a
,
b_compressed
,
e
,
scale_a
,
scale_b
,
torch
.
float16
,
)
)
# cutlass impl: bf16 output, with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias"
,
ops
.
cutlass_scaled_sparse_mm
,
a
,
b_compressed
,
e
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
,
)
)
# cutlass impl: fp16 output, with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias"
,
ops
.
cutlass_scaled_sparse_mm
,
a
,
b_compressed
,
e
,
scale_a
,
scale_b
,
torch
.
float16
,
bias
.
to
(
dtype
=
torch
.
float16
),
)
)
return
timers
def
bench
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
)
->
Iterable
[
TMeasurement
]:
if
dtype
==
torch
.
int8
:
return
bench_int8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
)
if
dtype
==
torch
.
float8_e4m3fn
:
return
bench_fp8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
)
raise
ValueError
(
f
"Unsupported dtype
{
dtype
}
: should be one of torch.int8, torch.float8_e4m3fn."
)
# runner
def
print_timers
(
timers
:
Iterable
[
TMeasurement
]):
compare
=
TBenchmark
.
Compare
(
timers
)
compare
.
print
()
def
run
(
dtype
:
torch
.
dtype
,
MKNs
:
Iterable
[
tuple
[
int
,
int
,
int
]]
)
->
Iterable
[
TMeasurement
]:
results
=
[]
for
m
,
k
,
n
in
MKNs
:
timers
=
bench
(
dtype
,
m
,
k
,
n
,
f
"scaled-
{
dtype
}
-gemm"
,
f
"MKN=(
{
m
}
x
{
k
}
x
{
n
}
)"
)
print_timers
(
timers
)
results
.
extend
(
timers
)
return
results
# output makers
def
make_output
(
data
:
Iterable
[
TMeasurement
],
MKNs
:
Iterable
[
tuple
[
int
,
int
,
int
]],
base_description
:
str
,
timestamp
=
None
,
):
print
(
f
"== All Results
{
base_description
}
===="
)
print_timers
(
data
)
# pickle all the results
timestamp
=
int
(
time
.
time
())
if
timestamp
is
None
else
timestamp
with
open
(
f
"
{
base_description
}
-
{
timestamp
}
.pkl"
,
"wb"
)
as
f
:
pkl
.
dump
(
data
,
f
)
# argparse runners
def
run_square_bench
(
args
):
dim_sizes
=
list
(
range
(
args
.
dim_start
,
args
.
dim_end
+
1
,
args
.
dim_increment
))
MKNs
=
list
(
zip
(
dim_sizes
,
dim_sizes
,
dim_sizes
))
data
=
run
(
args
.
dtype
,
MKNs
)
make_output
(
data
,
MKNs
,
f
"square_bench-
{
args
.
dtype
}
"
)
def
run_range_bench
(
args
):
dim_sizes
=
list
(
range
(
args
.
dim_start
,
args
.
dim_end
,
args
.
dim_increment
))
n
=
len
(
dim_sizes
)
Ms
=
[
args
.
m_constant
]
*
n
if
args
.
m_constant
is
not
None
else
dim_sizes
Ks
=
[
args
.
k_constant
]
*
n
if
args
.
k_constant
is
not
None
else
dim_sizes
Ns
=
[
args
.
n_constant
]
*
n
if
args
.
n_constant
is
not
None
else
dim_sizes
MKNs
=
list
(
zip
(
Ms
,
Ks
,
Ns
))
data
=
run
(
args
.
dtype
,
MKNs
)
make_output
(
data
,
MKNs
,
f
"range_bench-
{
args
.
dtype
}
"
)
def
run_model_bench
(
args
):
print
(
"Benchmarking models:"
)
for
i
,
model
in
enumerate
(
args
.
models
):
print
(
f
"[
{
i
}
]
{
model
}
"
)
def
model_shapes
(
model_name
:
str
,
tp_size
:
int
)
->
list
[
tuple
[
int
,
int
]]:
KNs
=
[]
for
KN
,
tp_split_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model_name
]):
KN
[
tp_split_dim
]
=
KN
[
tp_split_dim
]
//
tp_size
KNs
.
append
(
KN
)
return
KNs
model_bench_data
=
[]
models_tps
=
list
(
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
))
for
model
,
tp_size
in
models_tps
:
Ms
=
args
.
batch_sizes
KNs
=
model_shapes
(
model
,
tp_size
)
MKNs
=
[]
for
m
in
Ms
:
for
k
,
n
in
KNs
:
MKNs
.
append
((
m
,
k
,
n
))
data
=
run
(
args
.
dtype
,
MKNs
)
model_bench_data
.
append
(
data
)
# Print all results
for
data
,
model_tp
in
zip
(
model_bench_data
,
models_tps
):
model
,
tp_size
=
model_tp
print
(
f
"== Results
{
args
.
dtype
}
{
model
}
-TP
{
tp_size
}
===="
)
print_timers
(
data
)
timestamp
=
int
(
time
.
time
())
all_data
=
[]
for
d
in
model_bench_data
:
all_data
.
extend
(
d
)
# pickle all data
with
open
(
f
"model_bench-
{
args
.
dtype
}
-
{
timestamp
}
.pkl"
,
"wb"
)
as
f
:
pkl
.
dump
(
all_data
,
f
)
if
__name__
==
"__main__"
:
def
to_torch_dtype
(
dt
):
if
dt
==
"int8"
:
return
torch
.
int8
if
dt
==
"fp8"
:
return
torch
.
float8_e4m3fn
raise
ValueError
(
"unsupported dtype"
)
parser
=
FlexibleArgumentParser
(
description
=
"""
Benchmark Cutlass GEMM.
To run square GEMMs:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
"""
,
# noqa: E501
formatter_class
=
argparse
.
RawTextHelpFormatter
,
)
parser
.
add_argument
(
"--dtype"
,
type
=
to_torch_dtype
,
required
=
True
,
help
=
"Available options are ['int8', 'fp8']"
,
)
subparsers
=
parser
.
add_subparsers
(
dest
=
"cmd"
)
square_parser
=
subparsers
.
add_parser
(
"square_bench"
)
square_parser
.
add_argument
(
"--dim-start"
,
type
=
int
,
required
=
True
)
square_parser
.
add_argument
(
"--dim-end"
,
type
=
int
,
required
=
True
)
square_parser
.
add_argument
(
"--dim-increment"
,
type
=
int
,
required
=
True
)
square_parser
.
set_defaults
(
func
=
run_square_bench
)
range_parser
=
subparsers
.
add_parser
(
"range_bench"
)
range_parser
.
add_argument
(
"--dim-start"
,
type
=
int
,
required
=
True
)
range_parser
.
add_argument
(
"--dim-end"
,
type
=
int
,
required
=
True
)
range_parser
.
add_argument
(
"--dim-increment"
,
type
=
int
,
required
=
True
)
range_parser
.
add_argument
(
"--m-constant"
,
type
=
int
,
default
=
None
)
range_parser
.
add_argument
(
"--n-constant"
,
type
=
int
,
default
=
None
)
range_parser
.
add_argument
(
"--k-constant"
,
type
=
int
,
default
=
None
)
range_parser
.
set_defaults
(
func
=
run_range_bench
)
model_parser
=
subparsers
.
add_parser
(
"model_bench"
)
model_parser
.
add_argument
(
"--models"
,
nargs
=
"+"
,
type
=
str
,
default
=
DEFAULT_MODELS
,
choices
=
WEIGHT_SHAPES
.
keys
(),
)
model_parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_TP_SIZES
)
model_parser
.
add_argument
(
"--batch-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_BATCH_SIZES
)
model_parser
.
set_defaults
(
func
=
run_model_bench
)
args
=
parser
.
parse_args
()
args
.
func
(
args
)
benchmarks/cutlass_benchmarks/utils.py
View file @
38364a7e
...
@@ -5,8 +5,6 @@
...
@@ -5,8 +5,6 @@
import
torch
import
torch
import
vllm._custom_ops
as
ops
def
to_fp8
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
to_fp8
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
...
@@ -39,49 +37,3 @@ def make_rand_tensors(
...
@@ -39,49 +37,3 @@ def make_rand_tensors(
return
to_fp8
(
a
),
to_fp8
(
b
)
return
to_fp8
(
a
),
to_fp8
(
b
)
raise
ValueError
(
"unsupported dtype"
)
raise
ValueError
(
"unsupported dtype"
)
def
prune_to_2_4
(
tensor
):
# Reshape tensor to [N, 4] where N is number of groups of 4
original_shape
=
tensor
.
shape
reshaped
=
tensor
.
reshape
(
-
1
,
4
)
# Get indices of top 2 absolute values in each group of 4
_
,
indices
=
torch
.
topk
(
torch
.
abs
(
reshaped
),
k
=
2
,
dim
=
1
)
# Create binary mask
mask
=
torch
.
zeros_like
(
reshaped
)
mask
.
scatter_
(
dim
=
1
,
index
=
indices
,
src
=
torch
.
ones_like
(
indices
,
dtype
=
mask
.
dtype
))
# Apply mask and reshape back
pruned
=
reshaped
*
mask
# Turn all -0.0 to 0.0
pruned
[
pruned
==
-
0.0
]
=
0.0
return
pruned
.
reshape
(
original_shape
)
def
make_rand_sparse_tensors
(
dtype
:
torch
.
dtype
,
m
:
int
,
n
:
int
,
k
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
)
*
5
b
=
torch
.
randn
((
n
,
k
),
device
=
"cuda"
).
t
()
*
5
b
=
prune_to_2_4
(
b
.
t
()).
t
()
if
dtype
==
torch
.
int8
:
a
,
b
=
to_int8
(
a
),
to_int8
(
b
)
elif
dtype
==
torch
.
float8_e4m3fn
:
a
,
b
=
to_fp8
(
a
),
to_fp8
(
b
)
elif
dtype
==
torch
.
float16
:
a
,
b
=
to_fp16
(
a
),
to_fp16
(
b
)
elif
dtype
==
torch
.
bfloat16
:
a
,
b
=
to_bf16
(
a
),
to_bf16
(
b
)
else
:
raise
ValueError
(
"unsupported dtype"
)
b_compressed
,
e
=
ops
.
cutlass_sparse_compress
(
b
.
t
())
# Compressed B, Metadata, Original A, B
return
b_compressed
,
e
,
a
,
b
csrc/ops.h
View file @
38364a7e
...
@@ -285,16 +285,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
...
@@ -285,16 +285,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
bool
cutlass_sparse_scaled_mm_supported
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_sparse_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
e
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
vector
<
torch
::
Tensor
>
cutlass_sparse_compress
(
torch
::
Tensor
const
&
a
);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
scaled_fp4_quant_func
(
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
scaled_fp4_quant_func
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_scale
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_scale
,
bool
is_sf_swizzled_layout
);
bool
is_sf_swizzled_layout
);
...
...
csrc/sparse/cutlass/sparse_compressor_c3x.cuh
deleted
100644 → 0
View file @
fafe76b4
#pragma once
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
// clang-format on
using
namespace
cute
;
using
namespace
vllm
;
using
CompressorResult
=
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
;
/// Make A structured sparse by replacing elements with 0 and compress it
template
<
typename
Gemm
>
CompressorResult
cutlass_sparse_compress
(
torch
::
Tensor
const
&
a
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
||
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
||
a
.
dtype
()
==
torch
::
kFloat16
||
a
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
a
.
dim
()
==
2
)
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
0
)
%
4
==
0
)
// Required for semi-structured sparsity
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
)
using
GemmKernel
=
typename
Gemm
::
KernelType
;
using
ElementA
=
typename
Gemm
::
ElementAB
;
using
ElementE
=
typename
GemmKernel
::
CollectiveMainloop
::
ElementE
;
int
m
=
a
.
size
(
0
);
int
k
=
a
.
size
(
1
);
using
ProblemShape
=
typename
GemmKernel
::
ProblemShape
;
ProblemShape
prob_shape
{
m
,
1
,
k
,
1
};
int64_t
lda
=
a
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
using
CompressorUtility
=
typename
Gemm
::
CompressorUtility
;
CompressorUtility
compressor_utility
(
prob_shape
,
a_stride
);
// Allocate buffers for the metadata E and the compressed matrix A
int
ME
=
compressor_utility
.
get_metadata_m_physical
();
int
KE
=
compressor_utility
.
get_metadata_k_physical
();
int
MC
=
compressor_utility
.
get_tensorA_m_physical
();
int
KC
=
compressor_utility
.
get_tensorA_k_physical
();
auto
const
a_meta_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
const
a_nzs_options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
auto
a_meta
=
torch
::
zeros
({
ME
,
KE
},
a_meta_options
);
auto
a_nzs
=
torch
::
zeros
({
MC
,
KC
},
a_nzs_options
);
auto
a_ptr
=
static_cast
<
ElementA
*>
(
a
.
data_ptr
());
auto
a_nzs_ptr
=
static_cast
<
ElementA
*>
(
a_nzs
.
data_ptr
());
auto
a_meta_ptr
=
static_cast
<
ElementE
*>
(
a_meta
.
data_ptr
());
cutlass
::
KernelHardwareInfo
hw_info
;
hw_info
.
device_id
=
a
.
device
().
index
();
hw_info
.
sm_count
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
using
Compressor
=
typename
Gemm
::
Compressor
;
typename
Compressor
::
Arguments
arguments
{
prob_shape
,
{
a_ptr
,
a_stride
,
a_nzs_ptr
,
a_meta_ptr
},
{
hw_info
}};
Compressor
compressor_op
;
size_t
workspace_size
=
Compressor
::
get_workspace_size
(
arguments
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
CUTLASS_CHECK
(
compressor_op
.
can_implement
(
arguments
));
CUTLASS_CHECK
(
compressor_op
.
initialize
(
arguments
,
workspace
.
data_ptr
()));
CUTLASS_CHECK
(
compressor_op
.
run
());
CUDA_CHECK
(
cudaDeviceSynchronize
());
return
{
a_meta
,
a_nzs
};
}
#endif
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
deleted
100644 → 0
View file @
fafe76b4
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on
using
namespace
cute
;
using
namespace
vllm
;
struct
GemmCallerTraits
{
using
return_type
=
void
;
template
<
typename
GemmConfig
,
typename
...
Args
>
static
return_type
invoke
(
Args
&&
...
args
)
{
return
cutlass_sparse_gemm_caller
<
GemmConfig
>
(
std
::
forward
<
Args
>
(
args
)...);
}
};
struct
GemmCompressorTraits
{
using
return_type
=
CompressorResult
;
template
<
typename
GemmConfig
,
typename
...
Args
>
static
return_type
invoke
(
Args
&&
...
args
)
{
return
cutlass_sparse_compress
<
GemmConfig
>
(
std
::
forward
<
Args
>
(
args
)...);
}
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
DispatchFunc
,
typename
...
Args
>
typename
DispatchFunc
::
return_type
cutlass_gemm_sm90_fp8_dispatch
(
uint32_t
m
,
uint32_t
n
,
Args
&&
...
args
)
{
static_assert
(
std
::
is_same_v
<
InType
,
cutlass
::
float_e4m3_t
>
);
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_fp8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM256
=
typename
sm90_fp8_config_M256
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM512
=
typename
sm90_fp8_config_M512
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm1
=
typename
sm90_fp8_config_1
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm2
=
typename
sm90_fp8_config_2
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm3
=
typename
sm90_fp8_config_3
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm4
=
typename
sm90_fp8_config_4
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm5
=
typename
sm90_fp8_config_5
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm6
=
typename
sm90_fp8_config_6
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm7
=
typename
sm90_fp8_config_7
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemm8
=
typename
sm90_fp8_config_8
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
64
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
64
)
{
if
(
n
==
28672
)
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemm2
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
if
(
n
==
4096
||
n
==
6144
)
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemm1
>(
std
::
forward
<
Args
>
(
args
)...);
}
}
else
if
(
mp2
<=
128
)
{
if
(
n
==
4096
)
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemm3
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
if
(
n
==
28672
)
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemm5
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
if
(
n
==
6144
)
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemm4
>(
std
::
forward
<
Args
>
(
args
)...);
}
}
else
if
(
mp2
<=
256
)
{
if
(
n
==
4096
)
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemm6
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
if
(
n
==
28672
)
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemm8
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
if
(
n
==
6144
)
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemm7
>(
std
::
forward
<
Args
>
(
args
)...);
}
}
else
{
if
(
n
==
6144
||
n
==
28672
)
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemm8
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
if
(
n
==
4096
)
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemm7
>(
std
::
forward
<
Args
>
(
args
)...);
}
}
// Otherwise the default heuristic
if
(
mp2
<=
64
)
{
// n in [1, 64]
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemmM64
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// n in (64, 128]
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemmM128
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// n in (128, 256]
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemmM256
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
{
// n in (256, inf)
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemmM512
>(
std
::
forward
<
Args
>
(
args
)...);
}
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
DispatchFunc
,
typename
...
Args
>
typename
DispatchFunc
::
return_type
cutlass_gemm_sm90_16bit_dispatch
(
uint32_t
m
,
uint32_t
n
,
Args
&&
...
args
)
{
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemmDefault
>(
std
::
forward
<
Args
>
(
args
)...);
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
DispatchFunc
,
typename
...
Args
>
typename
DispatchFunc
::
return_type
cutlass_gemm_sm90_int8_dispatch
(
uint32_t
m
,
uint32_t
n
,
Args
&&
...
args
)
{
static_assert
(
std
::
is_same_v
<
InType
,
int8_t
>
);
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
typename
sm90_int8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM64
=
typename
sm90_int8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NBig
=
typename
sm90_int8_config_M32_NBig
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM32NSmall
=
typename
sm90_int8_config_M32_NSmall
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
bool
const
is_small_n
=
n
<
8192
;
uint32_t
const
mp2
=
std
::
max
(
static_cast
<
uint32_t
>
(
32
),
next_pow_2
(
m
));
// next power of 2
if
(
mp2
<=
32
)
{
// m in [1, 32]
if
(
is_small_n
)
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemmM32NSmall
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
{
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemmM32NBig
>(
std
::
forward
<
Args
>
(
args
)...);
}
}
else
if
(
mp2
<=
64
)
{
// m in (32, 64]
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemmM64
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
if
(
mp2
<=
128
)
{
// m in (64, 128]
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemmM128
>(
std
::
forward
<
Args
>
(
args
)...);
}
else
{
// m in (128, inf)
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemmDefault
>(
std
::
forward
<
Args
>
(
args
)...);
}
}
// Dispatch to GEMM implementations based on element types
template
<
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
void
cutlass_scaled_sparse_mm_sm90_epilogue
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
epilogue_args
)
{
uint32_t
const
m
=
out
.
size
(
0
);
uint32_t
const
n
=
out
.
size
(
1
);
// TODO: add dispatch functions to all of these
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
Epilogue
,
GemmCallerTraits
>
(
m
,
n
,
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
half_t
,
Epilogue
,
GemmCallerTraits
>
(
m
,
n
,
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
Epilogue
,
GemmCallerTraits
>
(
m
,
n
,
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
Epilogue
,
GemmCallerTraits
>
(
m
,
n
,
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat16
)
{
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_gemm_sm90_16bit_dispatch
<
cutlass
::
half_t
,
cutlass
::
half_t
,
Epilogue
,
GemmCallerTraits
>
(
m
,
n
,
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
else
{
// a.dtype() == torch::kBFloat16
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
bt_nzs
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kBFloat16
);
return
cutlass_gemm_sm90_16bit_dispatch
<
cutlass
::
bfloat16_t
,
cutlass
::
bfloat16_t
,
Epilogue
,
GemmCallerTraits
>
(
m
,
n
,
out
,
a
,
bt_nzs
,
bt_meta
,
std
::
forward
<
EpilogueArgs
>
(
epilogue_args
)...);
}
}
void
cutlass_scaled_sparse_mm_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
bt_meta
.
dtype
()
==
torch
::
kUInt8
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
bias
)
{
TORCH_CHECK
(
bias
->
dtype
()
==
out
.
dtype
(),
"CUTLASS scaled_mm bias dtype must match output dtype "
,
out
.
dtype
());
return
cutlass_scaled_sparse_mm_sm90_epilogue
<
c3x
::
ScaledEpilogueColumnBias
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
b_scales
,
a_scales
,
*
bias
);
}
else
{
return
cutlass_scaled_sparse_mm_sm90_epilogue
<
c3x
::
ScaledEpilogue
>
(
out
,
a
,
bt_nzs
,
bt_meta
,
b_scales
,
a_scales
);
}
}
CompressorResult
cutlass_sparse_compress_sm90
(
torch
::
Tensor
const
&
a
)
{
// These m and n variables are fordispatching to different GEMM algorithms.
uint32_t
const
m
=
1
;
// Set M to 1 for compression
uint32_t
const
n
=
a
.
size
(
1
);
// Note: For correctness, the compressed format must be invariant in:
// - M, the flattened number of tokens
// - Whether output dtype is fp16 or bf16
// - CUTLASS epilogues
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
return
cutlass_gemm_sm90_int8_dispatch
<
int8_t
,
cutlass
::
bfloat16_t
,
c3x
::
TrivialEpilogue
,
GemmCompressorTraits
>
(
m
,
n
,
a
);
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
return
cutlass_gemm_sm90_fp8_dispatch
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
c3x
::
TrivialEpilogue
,
GemmCompressorTraits
>
(
m
,
n
,
a
);
}
else
if
(
a
.
dtype
()
==
torch
::
kFloat16
)
{
return
cutlass_gemm_sm90_16bit_dispatch
<
cutlass
::
bfloat16_t
,
cutlass
::
bfloat16_t
,
c3x
::
TrivialEpilogue
,
GemmCompressorTraits
>
(
m
,
n
,
a
);
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kBFloat16
,
"cutlass_sparse_compress only supports int8, fp8_e4m3, fp16, "
"and bf16 datatypes"
);
return
cutlass_gemm_sm90_16bit_dispatch
<
cutlass
::
half_t
,
cutlass
::
half_t
,
c3x
::
TrivialEpilogue
,
GemmCompressorTraits
>
(
m
,
n
,
a
);
}
}
#endif
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
deleted
100644 → 0
View file @
fafe76b4
#pragma once
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/torch_utils.hpp"
// clang-format on
using
namespace
cute
;
/*
This file defines 2:4 sparse GEMM operations using the CUTLASS 3.x API,
for NVIDIA GPUs with sm90a (Hopper) or later.
*/
namespace
{
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template
<
typename
Kernel
>
struct
enable_sm90_or_later
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 900
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
using
GemmUniversalMode
=
cutlass
::
gemm
::
GemmUniversalMode
;
/*
* cutlass_sparse_3x_gemm defines a 2:4 sparse GEMM kernel via CUTLASS
* for SM90 Hopper systems.
*/
template
<
typename
ElementAB_
,
typename
ElementD_
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
EpilogueSchedule
>
struct
cutlass_sparse_3x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
Epilogue
=
Epilogue_
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
ElementC
=
void
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
using
LayoutC_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutC
>::
type
;
using
EVTCompute
=
typename
Epilogue
::
EVTCompute
;
// These are the minimum alignments needed for the kernels to compile
static
constexpr
int
AlignmentAB
=
128
/
cutlass
::
sizeof_bits
<
ElementAB
>::
value
;
static
constexpr
int
AlignmentCD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAcc
,
float
,
ElementC
,
LayoutC_Transpose
,
AlignmentCD
,
ElementD
,
LayoutC_Transpose
,
AlignmentCD
,
EpilogueSchedule
,
EVTCompute
>::
CollectiveOp
;
static
constexpr
size_t
CEStorageSize
=
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
);
using
Stages
=
typename
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
CEStorageSize
)
>
;
// clang-format off
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassSparseTensorOp
,
ElementAB
,
cutlass
::
layout
::
RowMajor
,
AlignmentAB
,
ElementAB
,
cutlass
::
layout
::
ColumnMajor
,
AlignmentAB
,
ElementAcc
,
TileShape
,
ClusterShape
,
Stages
,
KernelSchedule
>::
CollectiveOp
;
// clang-format on
using
KernelType
=
enable_sm90_or_later
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
cute
::
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>>
;
struct
GemmKernel
:
public
KernelType
{};
// Sparse compressor definitions
using
SparseConfig
=
typename
GemmKernel
::
CollectiveMainloop
::
SparseConfig
;
using
LayoutTagA
=
cutlass
::
layout
::
RowMajor
;
using
CompressorUtility
=
cutlass
::
transform
::
kernel
::
StructuredSparseCompressorUtility
<
typename
GemmKernel
::
ProblemShape
,
ElementAB
,
LayoutTagA
,
SparseConfig
>
;
using
CompressorKernel
=
cutlass
::
transform
::
kernel
::
StructuredSparseCompressor
<
typename
GemmKernel
::
ProblemShape
,
ElementAB
,
LayoutTagA
,
SparseConfig
,
cutlass
::
arch
::
Sm90
>
;
using
Compressor
=
cutlass
::
transform
::
device
::
TransformUniversalAdapter
<
CompressorKernel
>
;
};
/*
* This class defines kernel to compress a 2:4 sparse matrix.
* The particular format is defined by the Gemm template parameter,
* which is a cutlass_sparse_3x_gemm.
*/
using
CompressorResult
=
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
;
/// Make A structured sparse by replacing elements with 0 and compress it
template
<
typename
Gemm
>
CompressorResult
cutlass_sparse_compress
(
torch
::
Tensor
const
&
a
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
||
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
||
a
.
dtype
()
==
torch
::
kFloat16
||
a
.
dtype
()
==
torch
::
kBFloat16
);
TORCH_CHECK
(
a
.
dim
()
==
2
)
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
0
)
%
4
==
0
)
// Required for semi-structured sparsity
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
)
using
GemmKernel
=
typename
Gemm
::
KernelType
;
using
ElementA
=
typename
Gemm
::
ElementAB
;
using
ElementE
=
typename
GemmKernel
::
CollectiveMainloop
::
ElementE
;
int
m
=
a
.
size
(
0
);
int
k
=
a
.
size
(
1
);
using
ProblemShape
=
typename
GemmKernel
::
ProblemShape
;
ProblemShape
prob_shape
{
m
,
1
,
k
,
1
};
int64_t
lda
=
a
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
0
};
using
CompressorUtility
=
typename
Gemm
::
CompressorUtility
;
CompressorUtility
compressor_utility
(
prob_shape
,
a_stride
);
// Allocate buffers for the metadata E and the compressed matrix A
int
ME
=
compressor_utility
.
get_metadata_m_physical
();
int
KE
=
compressor_utility
.
get_metadata_k_physical
();
int
MC
=
compressor_utility
.
get_tensorA_m_physical
();
int
KC
=
compressor_utility
.
get_tensorA_k_physical
();
auto
const
a_meta_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
const
a_nzs_options
=
torch
::
TensorOptions
().
dtype
(
a
.
dtype
()).
device
(
a
.
device
());
auto
a_meta
=
torch
::
zeros
({
ME
,
KE
},
a_meta_options
);
auto
a_nzs
=
torch
::
zeros
({
MC
,
KC
},
a_nzs_options
);
auto
a_ptr
=
static_cast
<
ElementA
*>
(
a
.
data_ptr
());
auto
a_nzs_ptr
=
static_cast
<
ElementA
*>
(
a_nzs
.
data_ptr
());
auto
a_meta_ptr
=
static_cast
<
ElementE
*>
(
a_meta
.
data_ptr
());
cutlass
::
KernelHardwareInfo
hw_info
;
hw_info
.
device_id
=
a
.
device
().
index
();
hw_info
.
sm_count
=
cutlass
::
KernelHardwareInfo
::
query_device_multiprocessor_count
(
hw_info
.
device_id
);
using
Compressor
=
typename
Gemm
::
Compressor
;
typename
Compressor
::
Arguments
arguments
{
prob_shape
,
{
a_ptr
,
a_stride
,
a_nzs_ptr
,
a_meta_ptr
},
{
hw_info
}};
Compressor
compressor_op
;
size_t
workspace_size
=
Compressor
::
get_workspace_size
(
arguments
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
CUTLASS_CHECK
(
compressor_op
.
can_implement
(
arguments
));
CUTLASS_CHECK
(
compressor_op
.
initialize
(
arguments
,
workspace
.
data_ptr
()));
CUTLASS_CHECK
(
compressor_op
.
run
());
CUDA_CHECK
(
cudaDeviceSynchronize
());
return
{
a_meta
,
a_nzs
};
}
template
<
typename
Gemm
,
typename
...
EpilogueArgs
>
void
cutlass_sparse_gemm_caller
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
// Interface stride expected from the argument a (will get transposed)
// We compute C^T = B^T * A^T, but we assume B is transposed before
// compression and hence the bt_* naming
using
LayoutB
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
LayoutA
;
using
LayoutE
=
typename
Gemm
::
GemmKernel
::
CollectiveMainloop
::
LayoutE
;
// M, N, K after transposition
int32_t
m
=
out
.
size
(
1
);
int32_t
n
=
out
.
size
(
0
);
int32_t
k
=
a
.
size
(
1
);
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
Stride
<
Int
<
1
>
,
int64_t
,
int64_t
>
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
Int
<
0
>
{}};
StrideC
c_stride
{
Int
<
1
>
{},
ldc
,
Int
<
0
>
{}};
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
typename
GemmKernel
::
ProblemShape
prob_shape
{
m
,
n
,
k
,
1
};
using
ElementE
=
typename
GemmKernel
::
CollectiveMainloop
::
ElementE
;
using
SparseConfig
=
typename
GemmKernel
::
CollectiveMainloop
::
SparseConfig
;
LayoutB
b_layout
=
SparseConfig
::
fill_layoutA
(
prob_shape
);
LayoutE
e_layout
=
SparseConfig
::
fill_layoutE
(
prob_shape
);
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
bt_nzs
.
data_ptr
());
auto
e_ptr
=
static_cast
<
ElementE
*>
(
bt_meta
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
b_ptr
,
b_layout
,
a_ptr
,
a_stride
,
e_ptr
,
e_layout
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
Gemm
::
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...),
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
prob_shape
,
mainloop_args
,
epilogue_args
};
// Launch the CUTLASS GEMM kernel.
using
GemmOp
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
GemmOp
gemm_op
;
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
a
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
a
.
get_device
());
cutlass
::
Status
status
=
gemm_op
.
run
(
args
,
workspace
.
data_ptr
(),
stream
);
CUTLASS_CHECK
(
status
);
}
//////////////////////////////////////////////////
// Gemm Configs are defined below
//////////////////////////////////////////////////
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_config_default
{};
template
<
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_config_default
<
half_t
,
OutType
,
Epilogue
>
{
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
half_t
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_config_default
<
cutlass
::
bfloat16_t
,
OutType
,
Epilogue
>
{
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
cutlass
::
bfloat16_t
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
//////////////////////// Cherry-Picking Kernels ////////////////////////
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_1
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_8
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_2
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_128
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_8
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_3
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_4
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_8
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_5
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_8
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_6
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_7
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_8
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_128
,
_256
,
_128
>
;
using
ClusterShape
=
Shape
<
_8
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
////////////////////////////////////////////////////////////////////////
template
<
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_config_default
<
cutlass
::
float_e4m3_t
,
OutType
,
Epilogue
>
{
// M in (128, inf)
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
cutlass
::
float_e4m3_t
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M64
{
// M in [1, 64]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M128
{
// M in (64, 128]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M256
{
// M in (128, 256]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M512
{
// M in (256, ]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_config_default
<
int8_t
,
OutType
,
Epilogue
>
{
// For M > 128 and any N
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
int8_t
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M128
{
// For M in (64, 128] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M64
{
// For M in (32, 64] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NBig
{
// For M in [1, 32] and N >= 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_4
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NSmall
{
// For M in [1, 32] and N < 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
}
// namespace
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
deleted
100644 → 0
View file @
fafe76b4
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
bool
cutlass_sparse_scaled_mm_supported
(
int64_t
cuda_device_capability
)
{
// sparse CUTLASS kernels need exactly hopper and are not forward compatible
// CUDA 12.2 and SM90 (Hopper)
#if defined CUDA_VERSION
return
CUDA_VERSION
>=
12020
&&
cuda_device_capability
==
90
;
#endif
return
false
;
}
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
void
cutlass_scaled_sparse_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
e
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
using
CompressorResult
=
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
;
CompressorResult
cutlass_sparse_compress_sm90
(
torch
::
Tensor
const
&
a
);
#endif
void
cutlass_scaled_sparse_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
bt_nzs
,
torch
::
Tensor
const
&
bt_meta
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
bt_nzs
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
1
)
==
bt_nzs
.
size
(
0
)
&&
bt_nzs
.
size
(
1
)
*
2
==
a
.
size
(
1
)
&&
a
.
size
(
0
)
==
c
.
size
(
0
));
TORCH_CHECK
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
));
TORCH_CHECK
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
bt_nzs
.
size
(
0
));
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
bt_nzs
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
bt_nzs
.
stride
(
0
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
bias
)
{
TORCH_CHECK
(
bias
->
numel
()
==
bt_nzs
.
size
(
0
)
&&
bias
->
is_contiguous
()
&&
bias
->
dim
()
==
1
);
}
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
// We build for 9.0a which is not forward compatible, so restrict this to
// Hopper only
if
(
version_num
==
90
)
{
cutlass_scaled_sparse_mm_sm90
(
c
,
a
,
bt_nzs
,
bt_meta
,
a_scales
,
b_scales
,
bias
);
return
;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
std
::
vector
<
torch
::
Tensor
>
cutlass_sparse_compress
(
torch
::
Tensor
const
&
a
)
{
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
a
.
stride
(
0
)
%
8
==
0
);
// 8 Byte Alignment for Compression
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
// We build for 9.0a which is not forward compatible, so restrict this to
// Hopper only
if
(
version_num
==
90
)
{
std
::
vector
<
torch
::
Tensor
>
result_tensors
;
auto
[
a_meta
,
a_nzs
]
=
cutlass_sparse_compress_sm90
(
a
);
result_tensors
.
push_back
(
std
::
move
(
a_nzs
));
result_tensors
.
push_back
(
std
::
move
(
a_meta
));
return
result_tensors
;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_sparse_compress for a compute capability equal to "
"CUDA device capability: "
,
version_num
);
}
csrc/torch_bindings.cpp
View file @
38364a7e
...
@@ -523,26 +523,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -523,26 +523,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"cutlass_scaled_mm_supports_block_fp8"
,
ops
.
impl
(
"cutlass_scaled_mm_supports_block_fp8"
,
&
cutlass_scaled_mm_supports_block_fp8
);
&
cutlass_scaled_mm_supports_block_fp8
);
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
// given capability
ops
.
def
(
"cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool"
);
ops
.
impl
(
"cutlass_sparse_scaled_mm_supported"
,
&
cutlass_sparse_scaled_mm_supported
);
// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops
.
def
(
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
" Tensor bt_nzs,"
" Tensor bt_meta, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_sparse_mm"
,
torch
::
kCUDA
,
&
cutlass_scaled_sparse_mm
);
// CUTLASS sparse matrix compressor
ops
.
def
(
"cutlass_sparse_compress(Tensor a) -> Tensor[]"
);
ops
.
impl
(
"cutlass_sparse_compress"
,
&
cutlass_sparse_compress
);
// SM100 CUTLASS MLA decode
// SM100 CUTLASS MLA decode
ops
.
def
(
ops
.
def
(
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
...
...
tests/kernels/quantization/test_cutlass_2of4_sparse.py
deleted
100644 → 0
View file @
fafe76b4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for sparse cutlass kernels
Run `pytest tests/kernels/quantization/test_cutlass_2of4_sparse.py`.
"""
import
pytest
import
torch
from
tests.kernels.utils
import
baseline_scaled_mm
,
to_fp8
,
to_int8
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
sparse_cutlass_supported
,
)
from
vllm.platforms
import
current_platform
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
accelerator
.
device_count
()
==
1
else
2
)
]
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
def
to_bf16
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
tensor
.
to
(
dtype
=
torch
.
bfloat16
)
def
to_fp16
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
tensor
.
to
(
dtype
=
torch
.
float16
)
def
prune_to_2_4
(
tensor
):
# Reshape tensor to [N, 4] where N is number of groups of 4
original_shape
=
tensor
.
shape
reshaped
=
tensor
.
reshape
(
-
1
,
4
)
# Get indices of top 2 absolute values in each group of 4
_
,
indices
=
torch
.
topk
(
torch
.
abs
(
reshaped
),
k
=
2
,
dim
=
1
)
# Create binary mask
mask
=
torch
.
zeros_like
(
reshaped
)
mask
.
scatter_
(
dim
=
1
,
index
=
indices
,
src
=
torch
.
ones_like
(
indices
,
dtype
=
mask
.
dtype
))
# Apply mask and reshape back
pruned
=
reshaped
*
mask
# Turn all -0.0 to 0.0
pruned
[
pruned
==
-
0.0
]
=
0.0
return
pruned
.
reshape
(
original_shape
)
# This function checks that applying an identity matrix multiplication
# to the compressed weights yields the original uncompressed weights.
def
check_compress_decompress_invariance
(
dtype
:
torch
.
dtype
,
b
:
torch
.
Tensor
,
b_compressed
:
torch
.
Tensor
,
b_metadata
:
torch
.
Tensor
,
):
# For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
# same dtype as its inputs. This line addresses that constraint while
# arbitrarily using bfloat16 for the int8/fp8 cases.
out_dtype
=
torch
.
float16
if
dtype
is
torch
.
float16
else
torch
.
bfloat16
eye
=
torch
.
eye
(
b
.
shape
[
0
],
device
=
"cuda"
,
dtype
=
dtype
)
eye_scale
=
torch
.
ones
(
1
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
b_decomp
=
ops
.
cutlass_scaled_sparse_mm
(
eye
,
b_compressed
,
b_metadata
,
eye_scale
,
eye_scale
,
out_dtype
=
out_dtype
)
torch
.
testing
.
assert_close
(
b
.
to
(
dtype
=
out_dtype
),
b_decomp
)
def
make_rand_sparse_tensors
(
dtype
:
torch
.
dtype
,
m
:
int
,
n
:
int
,
k
:
int
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
)
b
=
torch
.
randn
((
n
,
k
),
device
=
"cuda"
).
t
()
if
dtype
==
torch
.
int8
:
# ensure A and B aren't all zeros after rounding
a
=
a
*
5.0
b
=
b
*
5.0
b
=
prune_to_2_4
(
b
.
t
()).
t
()
if
dtype
==
torch
.
int8
:
a
,
b
=
to_int8
(
a
),
to_int8
(
b
)
elif
dtype
==
torch
.
float8_e4m3fn
:
a
,
b
=
to_fp8
(
a
),
to_fp8
(
b
)
elif
dtype
==
torch
.
float16
:
a
,
b
=
to_fp16
(
a
),
to_fp16
(
b
)
elif
dtype
==
torch
.
bfloat16
:
a
,
b
=
to_bf16
(
a
),
to_bf16
(
b
)
else
:
raise
ValueError
(
"unsupported dtype"
)
b_compressed
,
e
=
ops
.
cutlass_sparse_compress
(
b
.
t
())
check_compress_decompress_invariance
(
dtype
,
b
,
b_compressed
,
e
)
# Compressed B, Metadata, Original A, B
return
b_compressed
,
e
,
a
,
b
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse CUTLASS is not supported on this GPU type."
,
)
# Test working with a subset of A and B for sparse matmul
def
test_cutlass_sparse_subset
():
big_m
=
1024
m
,
n
,
k
=
512
,
512
,
512
# Create tensors
b_comp
,
e
,
whole_a
,
b
=
make_rand_sparse_tensors
(
torch
.
float8_e4m3fn
,
big_m
,
n
,
k
)
a
=
whole_a
[
0
:
m
,
0
:
k
]
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
,
b_comp
,
e
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
bfloat16
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
bfloat16
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
MNK_FACTORS
=
[
(
1
,
256
,
128
),
(
1
,
16384
,
1024
),
(
1
,
24576
,
512
),
(
16
,
256
,
512
),
(
16
,
16384
,
128
),
(
16
,
24576
,
4096
),
(
32
,
8192
,
4096
),
(
32
,
16384
,
4096
),
(
33
,
1024
,
1024
),
(
33
,
8192
,
128
),
(
64
,
2048
,
512
),
(
64
,
16384
,
1024
),
(
100
,
8192
,
512
),
(
128
,
32768
,
4096
),
(
256
,
4096
,
4096
),
(
512
,
256
,
1024
),
(
512
,
8192
,
4096
),
(
512
,
16384
,
128
),
(
512
,
24576
,
128
),
]
# Test working with a subset of A and B for sparse matmul
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse CUTLASS is not supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"m, n, k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_sparse_gemm
(
m
:
int
,
k
:
int
,
n
:
int
,
dtype
:
type
[
torch
.
dtype
],
use_bias
:
bool
):
# Create tensors
b_comp
,
e
,
a
,
b
=
make_rand_sparse_tensors
(
dtype
,
m
,
n
,
k
)
scale_a
=
torch
.
ones
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
ones
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
bias
=
torch
.
rand
((
n
,),
device
=
"cuda"
,
dtype
=
dtype
)
if
use_bias
else
None
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
,
b_comp
,
e
,
scale_a
,
scale_b
,
out_dtype
=
dtype
,
bias
=
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
dtype
,
bias
=
bias
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
3e-1
)
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse CUTLASS is not supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"m, k, n"
,
MNK_FACTORS
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_sparse_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
use_bias
:
bool
):
# Create tensors
b_comp
,
e
,
a
,
b
=
make_rand_sparse_tensors
(
torch
.
float8_e4m3fn
,
m
,
n
,
k
)
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
out_dtype
=
torch
.
bfloat16
bias
=
torch
.
rand
((
n
,),
device
=
"cuda"
,
dtype
=
out_dtype
)
*
10
if
use_bias
else
None
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
,
b_comp
,
e
,
scale_a
,
scale_b
,
out_dtype
=
out_dtype
,
bias
=
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
out_dtype
,
bias
=
bias
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
3e-1
)
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse CUTLASS is not supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"m,k,n"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_sparse_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
# Create tensors
b_comp
,
e
,
a
,
b
=
make_rand_sparse_tensors
(
torch
.
int8
,
m
,
n
,
k
)
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
out_dtype
=
torch
.
bfloat16
bias
=
torch
.
rand
((
n
,),
device
=
"cuda"
,
dtype
=
out_dtype
)
*
10
if
use_bias
else
None
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
,
b_comp
,
e
,
scale_a
,
scale_b
,
out_dtype
=
out_dtype
,
bias
=
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
out_dtype
,
bias
=
bias
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e0
,
atol
=
2e0
)
tests/quantization/test_compressed_tensors.py
View file @
38364a7e
...
@@ -12,7 +12,6 @@ from compressed_tensors.quantization import QuantizationType
...
@@ -12,7 +12,6 @@ from compressed_tensors.quantization import QuantizationType
from
tests.models.utils
import
check_logprobs_close
from
tests.models.utils
import
check_logprobs_close
from
vllm.model_executor.layers.fused_moe
import
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.fused_moe
import
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensors24
,
CompressedTensorsLinearMethod
,
CompressedTensorsLinearMethod
,
CompressedTensorsW4A4Fp4
,
CompressedTensorsW4A4Fp4
,
CompressedTensorsW4A8Fp8
,
CompressedTensorsW4A8Fp8
,
...
@@ -27,9 +26,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8
...
@@ -27,9 +26,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
cutlass_fp4_supported
,
cutlass_fp4_supported
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
sparse_cutlass_supported
,
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.fa_utils
import
get_flash_attn_version
from
vllm.v1.attention.backends.fa_utils
import
get_flash_attn_version
...
@@ -362,283 +358,6 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
...
@@ -362,283 +358,6 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
assert
output
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
)
def
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
,
format
=
"dense"
):
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensors24
)
assert
qkv_proj
.
scheme
.
weight_quant
.
strategy
==
weight_strategy
assert
qkv_proj
.
scheme
.
input_quant
.
strategy
==
input_strategy
assert
qkv_proj
.
scheme
.
quantized
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
sparsity_map
=
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
# noqa: E501
assert
sparsity_map
.
get
(
"Linear"
).
format
==
format
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
()
or
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing"
,
"channel"
,
"token"
,
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing"
,
"channel"
,
"tensor"
,
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing"
,
"tensor"
,
"tensor"
,
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing"
,
"tensor"
,
"token"
,
),
],
)
def
test_compressed_tensors_2of4_quant_fp8
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
,
enforce_eager
=
True
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
float8_e4m3fn
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
4
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
()
or
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
,
"channel"
,
"token"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM"
,
"channel"
,
"tensor"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM"
,
"tensor"
,
"token"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM"
,
"tensor"
,
"tensor"
,
),
],
)
def
test_compressed_tensors_2of4_quant_fp8_compressed
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
,
enforce_eager
=
True
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
float8_e4m3fn
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
,
format
=
"sparse-24-bitmask"
,
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
4
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"cutlass is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM"
,
"channel"
,
"token"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM"
,
"channel"
,
"tensor"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM"
,
"tensor"
,
"token"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM"
,
"tensor"
,
"tensor"
,
),
],
)
def
test_compressed_tensors_2of4_quant_int8_compressed
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
,
enforce_eager
=
True
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
int8
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
,
format
=
"sparse-24-bitmask"
,
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
4
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Sparse FP8 is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing"
,
"channel"
,
"token"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing"
,
"tensor"
,
"tensor"
,
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing"
,
"tensor"
,
"token"
,
),
],
)
def
test_compressed_tensors_2of4_quant_int8
(
vllm_runner
,
args_2of4
):
model
,
weight_strategy
,
input_strategy
=
args_2of4
with
vllm_runner
(
model
,
enforce_eager
=
True
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
qkv_proj
.
scheme
.
weights_dtype
==
torch
.
int8
_test_2of4_quant_models
(
qkv_proj
,
weight_strategy
,
input_strategy
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
4
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"2of4 Sparse is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor"
)],
)
def
test_compressed_tensors_2of4_sparse
(
vllm_runner
,
args_2of4
):
model
=
args_2of4
with
vllm_runner
(
model
,
enforce_eager
=
True
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensors24
)
assert
qkv_proj
.
scheme
.
weight_quant
is
None
assert
qkv_proj
.
scheme
.
input_quant
is
None
assert
not
qkv_proj
.
scheme
.
quantized
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
sparsity_map
=
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
# noqa: E501
assert
sparsity_map
.
get
(
"Linear"
).
format
==
"dense"
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
4
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
not
sparse_cutlass_supported
(),
reason
=
"Cutlass is not yet supported on this GPU type."
,
)
@
pytest
.
mark
.
parametrize
(
"args_2of4"
,
[(
"nm-testing/llama2.c-stories42M-pruned2.4-compressed"
)]
)
def
test_compressed_tensors_2of4_sparse_compressed
(
vllm_runner
,
args_2of4
):
model
=
args_2of4
with
vllm_runner
(
model
,
enforce_eager
=
True
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
CompressedTensorsLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
CompressedTensors24
)
assert
qkv_proj
.
scheme
.
weight_quant
is
None
assert
qkv_proj
.
scheme
.
input_quant
is
None
assert
not
qkv_proj
.
scheme
.
quantized
assert
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
sparsity_map
=
qkv_proj
.
quant_method
.
quantization_config
.
sparsity_scheme_map
# noqa: E501
assert
sparsity_map
.
get
(
"Linear"
).
format
==
"sparse-24-bitmask"
assert
sparsity_map
.
get
(
"Linear"
).
sparsity_structure
==
"2:4"
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
4
)
print
(
output
)
assert
output
@
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test is skipped on non-CUDA platform."
not
current_platform
.
is_cuda
(),
reason
=
"This test is skipped on non-CUDA platform."
)
)
...
...
tests/weight_loading/models.txt
View file @
38364a7e
...
@@ -20,8 +20,6 @@ compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
...
@@ -20,8 +20,6 @@ compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
#compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
#compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
awq, casperhansen/mixtral-instruct-awq, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
...
...
vllm/_custom_ops.py
View file @
38364a7e
...
@@ -876,10 +876,6 @@ def cutlass_scaled_mm_azp(
...
@@ -876,10 +876,6 @@ def cutlass_scaled_mm_azp(
return
out
.
view
(
*
target_shape
)
return
out
.
view
(
*
target_shape
)
def
cutlass_sparse_scaled_mm_supported
(
cuda_device_capability
:
int
)
->
bool
:
return
torch
.
ops
.
_C
.
cutlass_sparse_scaled_mm_supported
(
cuda_device_capability
)
def
cutlass_group_gemm_supported
(
cuda_device_capability
:
int
)
->
bool
:
def
cutlass_group_gemm_supported
(
cuda_device_capability
:
int
)
->
bool
:
if
cuda_device_capability
<
90
or
cuda_device_capability
>=
110
:
if
cuda_device_capability
<
90
or
cuda_device_capability
>=
110
:
return
False
return
False
...
@@ -890,94 +886,6 @@ def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
...
@@ -890,94 +886,6 @@ def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
return
False
return
False
def
cutlass_sparse_compress
(
a
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Compresses a sparse matrix for use with Cutlass sparse operations.
This function takes a dense tensor and compresses it into two components:
non-zero elements and metadata. The compressed representation is compatible
with Cutlass sparse kernels.
Args:
a (torch.Tensor):
The input tensor to be compressed. Must have one of the following data types:
- `torch.int8`
- `torch.float8_e4m3fn`
- `torch.bfloat16`
- `torch.float16`
Returns:
tuple[torch.Tensor, torch.Tensor]:
A tuple containing:
- `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
- `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
Raises:
ValueError: If the compression operation fails.
Notes:
- The `a_meta` tensor has a data type of `torch.uint8`.
- Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`).
- The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor.
- The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`.
"""
assert
a
.
dtype
in
[
torch
.
int8
,
torch
.
float8_e4m3fn
,
torch
.
bfloat16
,
torch
.
float16
]
assert
a
.
is_contiguous
()
# a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
elemsPerMetaElem
=
4
assert
a
.
shape
[
1
]
%
(
2
*
elemsPerMetaElem
)
==
0
return
torch
.
ops
.
_C
.
cutlass_sparse_compress
(
a
)
def
cutlass_scaled_sparse_mm
(
a
:
torch
.
Tensor
,
bt_nzs
:
torch
.
Tensor
,
bt_meta
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""
Performs a scaled sparse matrix multiplication using Cutlass.
Steps:
1. Create a dense matrix `a` of shape (m, k) on the CUDA device:
`a = torch.randn((m, k), device='cuda')`.
2. Create a dense matrix `b` of shape (k, n) on the CUDA device:
`b = torch.randn((k, n), device='cuda')`.
3. Prune matrix `b` to 2:4 sparsity along the specified dimension:
`b = prune_to_2_4(b, dim=0)`.
4. Compress the transposed sparse matrix `b.t()`:
`bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.
5. Perform sparse matrix multiplication using the compressed matrix,
applying scaling factors for `a` and `b`, and the output data type:
`out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.
Returns:
- The result of the scaled sparse matrix multiplication.
"""
assert
bt_nzs
.
shape
[
0
]
%
16
==
0
and
bt_nzs
.
shape
[
1
]
%
16
==
0
assert
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
assert
bias
is
None
or
bias
.
shape
[
0
]
==
bt_nzs
.
shape
[
0
]
and
bias
.
dtype
==
out_dtype
m
=
a
.
shape
[
0
]
n
=
bt_nzs
.
shape
[
0
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
torch
.
ops
.
_C
.
cutlass_scaled_sparse_mm
(
out
,
a
,
bt_nzs
,
bt_meta
,
scale_a
,
scale_b
,
bias
)
return
out
def
get_cutlass_moe_mm_data
(
def
get_cutlass_moe_mm_data
(
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
expert_offsets
:
torch
.
Tensor
,
expert_offsets
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
View file @
38364a7e
...
@@ -5,39 +5,16 @@ from collections.abc import Callable
...
@@ -5,39 +5,16 @@ from collections.abc import Callable
from
typing
import
Any
from
typing
import
Any
import
torch
import
torch
from
compressed_tensors
import
CompressionFormat
,
ModelCompressor
from
compressed_tensors.quantization
import
(
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
,
)
)
from
compressed_tensors.utils
import
combine_shards
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
CompressedTensorsScheme
,
)
)
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
convert_to_channelwise
,
sparse_cutlass_supported
,
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
,
)
__all__
=
[
"CompressedTensors24"
]
__all__
=
[
"CompressedTensors24"
]
from
vllm.platforms
import
current_platform
class
CompressedTensors24
(
CompressedTensorsScheme
):
class
CompressedTensors24
(
CompressedTensorsScheme
):
def
__init__
(
def
__init__
(
...
@@ -47,33 +24,11 @@ class CompressedTensors24(CompressedTensorsScheme):
...
@@ -47,33 +24,11 @@ class CompressedTensors24(CompressedTensorsScheme):
input_quant
:
QuantizationArgs
|
None
=
None
,
input_quant
:
QuantizationArgs
|
None
=
None
,
model_compression_config
:
dict
[
str
,
Any
]
|
None
=
None
,
model_compression_config
:
dict
[
str
,
Any
]
|
None
=
None
,
):
):
self
.
quantized
=
quantized
raise
NotImplementedError
(
"Sparse24 models are no longer supported by vLLM"
)
self
.
weight_quant
=
weight_quant
self
.
input_quant
=
input_quant
model_compressor
=
ModelCompressor
.
from_compression_config
(
model_compression_config
)
self
.
do_sparse_decompress
=
(
model_compressor
is
not
None
and
model_compressor
.
sparsity_config
.
format
==
CompressionFormat
.
sparse_24_bitmask
.
value
)
if
self
.
do_sparse_decompress
:
self
.
model_compressor
=
model_compressor
if
(
quantized
and
input_quant
is
not
None
and
self
.
_get_quant_dtype
()
==
current_platform
.
fp8_dtype
()
):
static
=
not
input_quant
.
dynamic
g_shape
=
GroupShape
.
PER_TENSOR
if
static
else
GroupShape
.
PER_TOKEN
self
.
quant_fp8
=
QuantFP8
(
static
,
g_shape
)
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
# Only cutlass 3.x kernels are implemented so far
raise
NotImplementedError
(
"Sparse24 models are no longer supported by vLLM"
)
return
90
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -85,164 +40,10 @@ class CompressedTensors24(CompressedTensorsScheme):
...
@@ -85,164 +40,10 @@ class CompressedTensors24(CompressedTensorsScheme):
weight_loader
:
Callable
,
weight_loader
:
Callable
,
**
kwargs
,
**
kwargs
,
):
):
if
not
sparse_cutlass_supported
():
raise
NotImplementedError
(
"Sparse24 models are no longer supported by vLLM"
)
raise
ValueError
(
"Sparse CUTLASS not supported. vLLM must be built with "
"CUDA 12.2 or later to use this feature"
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size
=
input_size
layer
.
input_size_per_partition
=
input_size_per_partition
self
.
weights_dtype
:
torch
.
dtype
=
self
.
_get_params_dtype
(
params_dtype
)
# parameter to store uncompressed weight
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
self
.
weights_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
if
self
.
do_sparse_decompress
:
assert
all
(
partition_size
%
8
==
0
for
partition_size
in
output_partition_sizes
),
"All partitions must be divisible by 8 for "
"2:4 sparse compressed models"
shape
=
BasevLLMParameter
(
data
=
torch
.
empty
(
2
,
1
,
dtype
=
torch
.
int64
),
weight_loader
=
weight_loader
,
)
compressed_weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
//
2
,
dtype
=
self
.
weights_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
bitmask
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
//
8
,
dtype
=
torch
.
uint8
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"shape"
,
shape
)
layer
.
register_parameter
(
"compressed"
,
compressed_weight
)
layer
.
register_parameter
(
"bitmask"
,
bitmask
)
# Check if quantized, not just 2:4 Sparse
if
self
.
quantized
:
if
(
self
.
weight_quant
and
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
):
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
(
(
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
else
:
assert
(
self
.
weight_quant
and
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
)
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# input quant will be non-none
if
self
.
input_quant
and
not
self
.
input_quant
.
dynamic
:
# register input quant scale
assert
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
input_scale
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
else
:
# for sparse-only, pass in 1 for weight/input scales
weight_scale
=
torch
.
nn
.
Parameter
(
data
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
input_scale
=
torch
.
nn
.
Parameter
(
data
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight"
,
weight
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""
raise
NotImplementedError
(
"Sparse24 models are no longer supported by vLLM"
)
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
if
self
.
do_sparse_decompress
:
layer
.
weight
.
data
=
self
.
_decompress_bitmask_compressed_weight
(
compressed
=
layer
.
compressed
,
bitmask
=
layer
.
bitmask
,
layer
=
layer
,
)
# compressed and bitmask tensors
# are no longer needed after decompression
del
layer
.
compressed
del
layer
.
bitmask
# torch.compile workaround
if
hasattr
(
layer
,
"input_scale"
):
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
if
self
.
weight_quant
:
if
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
:
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
convert_to_channelwise
(
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
),
requires_grad
=
False
,
)
else
:
# torch.compile workaround
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
# Set all negative zero values to 0 prior to compression
if
layer
.
weight
.
dtype
.
is_floating_point
and
layer
.
weight
.
dtype
.
itemsize
>=
2
:
layer
.
weight
.
data
[
layer
.
weight
.
data
==
-
0.0
]
=
0.0
w_compressed
,
meta
=
ops
.
cutlass_sparse_compress
(
layer
.
weight
.
data
)
layer
.
weight
=
torch
.
nn
.
Parameter
(
w_compressed
,
requires_grad
=
False
)
layer
.
meta
=
torch
.
nn
.
Parameter
(
meta
,
requires_grad
=
False
)
def
apply_weights
(
def
apply_weights
(
self
,
self
,
...
@@ -250,143 +51,4 @@ class CompressedTensors24(CompressedTensorsScheme):
...
@@ -250,143 +51,4 @@ class CompressedTensors24(CompressedTensorsScheme):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
raise
NotImplementedError
(
"Sparse24 models are no longer supported by vLLM"
)
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
"""
if
self
.
quantized
:
scale
=
getattr
(
layer
,
"input_scale"
,
None
)
if
self
.
weights_dtype
==
torch
.
int8
:
ops_output
=
ops
.
scaled_int8_quant
(
x
,
scale
=
scale
)
q_input
=
ops_output
[
0
]
input_scale
=
ops_output
[
1
]
else
:
assert
self
.
weights_dtype
==
torch
.
float8_e4m3fn
q_input
,
input_scale
=
self
.
quant_fp8
(
x
,
scale
=
scale
)
else
:
# Not quantized, nothing to do with the input_scales, use as is
input_scale
=
layer
.
input_scale
q_input
=
x
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
=
q_input
,
bt_nzs
=
layer
.
weight
,
bt_meta
=
layer
.
meta
,
scale_a
=
input_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
,
)
assert
out
.
is_contiguous
()
return
out
def
_get_params_dtype
(
self
,
params_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
if
not
self
.
quantized
:
return
params_dtype
return
self
.
_get_quant_dtype
()
def
_get_quant_dtype
(
self
)
->
torch
.
dtype
:
assert
self
.
quantized
assert
self
.
weight_quant
is
not
None
assert
self
.
input_quant
is
not
None
is_8_bits
=
self
.
weight_quant
.
num_bits
==
self
.
input_quant
.
num_bits
==
8
if
not
is_8_bits
:
raise
ValueError
(
"Cutlass only supports 8-bit quantization"
)
if
(
self
.
weight_quant
.
type
==
QuantizationType
.
FLOAT
and
self
.
input_quant
.
type
==
QuantizationType
.
FLOAT
):
return
torch
.
float8_e4m3fn
if
(
self
.
weight_quant
.
type
==
QuantizationType
.
INT
and
self
.
input_quant
.
type
==
QuantizationType
.
INT
):
return
torch
.
int8
raise
ValueError
(
"Quantization type not supported by Cutlass"
)
def
_decompress_bitmask_compressed_weight
(
self
,
compressed
:
torch
.
Tensor
,
bitmask
:
torch
.
Tensor
,
layer
:
torch
.
nn
.
Module
,
)
->
torch
.
Tensor
:
"""
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
return the result.
This function also supports sharded decompression.
:param compressed: The 2:4 sparse weight tensor compressed using the
sparse-24-bitmask compressor. This is different from
`cutlass_sparse_compress` which uses a different scheme (2 bits for
every nonzero element that represent the coordinate within the block
of 4). The bitmask compression here uses a bitmask to indicate the
positions of non-zero elements.
:param bitmask: The 2:4 bitmask associated with the compressed weights,
representing the positions of non-zero elements in the compressed
tensor.
:param layer: The layer whose weights need to be processed after
loading.
:return: The decompressed 2:4 sparse weight tensor.
"""
sparsity_compressor
=
self
.
model_compressor
.
sparsity_compressor
def
_process_split
(
bitmask_compressed_weight
:
torch
.
Tensor
,
shape
,
bitmask
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
weight_data
=
dict
(
compressed
=
bitmask_compressed_weight
,
shape
=
shape
,
bitmask
=
bitmask
,
)
return
sparsity_compressor
.
decompress_weight
(
weight_data
)
split_weights
:
list
[
torch
.
Tensor
]
=
[]
split_bitmask
:
list
[
torch
.
Tensor
]
=
[]
split_shape
:
list
[
tuple
[
int
,
int
]]
=
[]
if
isinstance
(
layer
,
(
QKVParallelLinear
,
MergedColumnParallelLinear
)):
split_weights
=
torch
.
split
(
compressed
,
layer
.
logical_widths
)
split_bitmask
=
torch
.
split
(
bitmask
,
layer
.
logical_widths
)
split_shape
=
[
(
out
,
layer
.
input_size_per_partition
)
for
out
in
layer
.
logical_widths
]
if
split_weights
:
decompressed_shards
=
[
_process_split
(
compressed_weight
,
shape
,
bitmask
)
for
compressed_weight
,
shape
,
bitmask
in
zip
(
split_weights
,
split_shape
,
split_bitmask
)
]
decompressed
=
combine_shards
(
decompressed_shards
)
else
:
decompressed
=
sparsity_compressor
.
decompress_weight
(
dict
(
compressed
=
compressed
,
shape
=
(
layer
.
logical_widths
[
0
],
layer
.
input_size_per_partition
,
),
bitmask
=
bitmask
,
)
)
return
decompressed
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
View file @
38364a7e
...
@@ -20,7 +20,7 @@ class CompressedTensorsScheme(ABC):
...
@@ -20,7 +20,7 @@ class CompressedTensorsScheme(ABC):
"""
"""
Get minimum device capability.
Get minimum device capability.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
()
@
abstractmethod
@
abstractmethod
def
create_weights
(
self
,
*
args
,
**
kwargs
):
def
create_weights
(
self
,
*
args
,
**
kwargs
):
...
@@ -28,7 +28,7 @@ class CompressedTensorsScheme(ABC):
...
@@ -28,7 +28,7 @@ class CompressedTensorsScheme(ABC):
Weight creation for the particular scheme. Inputs to this function
Weight creation for the particular scheme. Inputs to this function
"""
"""
raise
NotImplementedError
raise
NotImplementedError
()
@
abstractmethod
@
abstractmethod
def
apply_weights
(
def
apply_weights
(
...
@@ -44,7 +44,7 @@ class CompressedTensorsScheme(ABC):
...
@@ -44,7 +44,7 @@ class CompressedTensorsScheme(ABC):
:param bias: bias parameter
:param bias: bias parameter
"""
"""
raise
NotImplementedError
raise
NotImplementedError
()
@
abstractmethod
@
abstractmethod
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
...
@@ -52,4 +52,4 @@ class CompressedTensorsScheme(ABC):
...
@@ -52,4 +52,4 @@ class CompressedTensorsScheme(ABC):
Called after weight loading is complete for any cleanup that
Called after weight loading is complete for any cleanup that
needs to occur.
needs to occur.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
()
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
38364a7e
...
@@ -8,16 +8,6 @@ from vllm import _custom_ops as ops
...
@@ -8,16 +8,6 @@ from vllm import _custom_ops as ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
def
sparse_cutlass_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
return
False
capability_tuple
=
current_platform
.
get_device_capability
()
capability
=
-
1
if
capability_tuple
is
None
else
capability_tuple
.
to_int
()
return
ops
.
cutlass_sparse_scaled_mm_supported
(
capability
)
def
cutlass_fp8_supported
()
->
bool
:
def
cutlass_fp8_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
if
not
current_platform
.
is_cuda
():
return
False
return
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment