Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
38364a7e
Unverified
Commit
38364a7e
authored
Mar 23, 2026
by
Kyle Sayers
Committed by
GitHub
Mar 23, 2026
Browse files
[Sparse24] [Deprecation] Remove Sparse24 CT integration and kernels (#36799)
Signed-off-by:
Kyle Sayers
<
kylesayrs@gmail.com
>
parent
fafe76b4
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
9 additions
and
2674 deletions
+9
-2674
.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml
...l-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml
+0
-12
CMakeLists.txt
CMakeLists.txt
+0
-26
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
+0
-517
benchmarks/cutlass_benchmarks/utils.py
benchmarks/cutlass_benchmarks/utils.py
+0
-48
csrc/ops.h
csrc/ops.h
+0
-10
csrc/sparse/cutlass/sparse_compressor_c3x.cuh
csrc/sparse/cutlass/sparse_compressor_c3x.cuh
+0
-90
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
+0
-307
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
+0
-570
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
+0
-104
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-20
tests/kernels/quantization/test_cutlass_2of4_sparse.py
tests/kernels/quantization/test_cutlass_2of4_sparse.py
+0
-238
tests/quantization/test_compressed_tensors.py
tests/quantization/test_compressed_tensors.py
+0
-281
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+0
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-92
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
...ation/compressed_tensors/schemes/compressed_tensors_24.py
+5
-343
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
...n/compressed_tensors/schemes/compressed_tensors_scheme.py
+4
-4
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+0
-10
No files found.
.buildkite/lm-eval-harness/configs/SparseLlama3.1_2of4_fp8_compressed.yaml
deleted
100644 → 0
View file @
fafe76b4
# For vllm script, with -t option (tensor parallel size).
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name
:
"
nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.6353
-
name
:
"
exact_match,flexible-extract"
value
:
0.637
limit
:
null
num_fewshot
:
null
CMakeLists.txt
View file @
38364a7e
...
...
@@ -343,7 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp"
"csrc/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/quantization/w8a8/int8/per_token_group_quant.cu"
)
...
...
@@ -619,31 +618,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif
()
endif
()
#
# 2:4 Sparse Kernels
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper).
cuda_archs_loose_intersection
(
SCALED_MM_ARCHS
"9.0a;"
"
${
CUDA_ARCHS
}
"
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS
)
set
(
SRCS
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
SCALED_MM_ARCHS
}
"
)
list
(
APPEND VLLM_EXT_SRC
"
${
SRCS
}
"
)
list
(
APPEND VLLM_GPU_FLAGS
"-DENABLE_SPARSE_SCALED_MM_C3X=1"
)
message
(
STATUS
"Building sparse_scaled_mm_c3x for archs:
${
SCALED_MM_ARCHS
}
"
)
else
()
if
(
NOT
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS
)
message
(
STATUS
"Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
"if you intend on running FP8 sparse quantized models on Hopper."
)
else
()
message
(
STATUS
"Not building sparse_scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures"
)
endif
()
endif
()
# The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require
# CUDA 12.8 or later
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER_EQUAL 13.0
)
...
...
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
deleted
100644 → 0
View file @
fafe76b4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
itertools
import
pickle
as
pkl
import
time
from
collections.abc
import
Callable
,
Iterable
import
torch
import
torch.utils.benchmark
as
TBenchmark
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
from
utils
import
make_rand_sparse_tensors
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm.utils.argparse_utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
list
(
WEIGHT_SHAPES
.
keys
())
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
]
DEFAULT_TP_SIZES
=
[
1
]
# bench
def bench_fn(
    label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
    """Time ``fn(*args, **kwargs)`` and return the resulting measurement.

    The call is wrapped in a ``torch.utils.benchmark.Timer`` and measured with
    ``blocked_autorange`` using a minimum run time of one second. ``label``,
    ``sub_label`` and ``description`` are forwarded to the timer unchanged.
    """
    # Names made visible to the timed statement. Kept under a name that does
    # not shadow the ``globals`` builtin.
    timer_namespace = {"args": args, "kwargs": kwargs, "fn": fn}
    timer = TBenchmark.Timer(
        stmt="fn(*args, **kwargs)",
        globals=timer_namespace,
        label=label,
        sub_label=sub_label,
        description=description,
    )
    return timer.blocked_autorange(min_run_time=1)
def bench_int8(
    dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
    """Benchmark int8 dense and 2:4-sparse GEMMs for one (m, k, n) problem.

    Prints whether the sparse CUTLASS result matches the dense CUTLASS
    reference, then returns one measurement per benchmarked implementation,
    in the order they were run.
    """
    assert dtype == torch.int8
    b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)

    # Sanity check: the sparse kernel against the dense kernel.
    out = ops.cutlass_scaled_sparse_mm(
        a, b_compressed, e, scale_a, scale_b, torch.bfloat16
    )
    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)

    if torch.allclose(out, out_ref):
        print("Correct results")
    else:
        print("Incorrect results")
        print(out)
        print(out_ref)

    timers = []

    def record(description, fn, *fn_args):
        # Time one implementation and keep its measurement.
        timers.append(bench_fn(label, sub_label, description, fn, *fn_args))

    # pytorch impl - bfloat16
    record(
        "pytorch_bf16_bf16_bf16_matmul-no-scales",
        torch.mm,
        a.to(dtype=torch.bfloat16),
        b.to(dtype=torch.bfloat16),
    )

    # pytorch impl - float16
    record(
        "pytorch_fp16_fp16_fp16_matmul-no-scales",
        torch.mm,
        a.to(dtype=torch.float16),
        b.to(dtype=torch.float16),
    )

    # cutlass dense impl
    record(
        "cutlass_i8_i8_bf16_scaled_mm",
        ops.cutlass_scaled_mm,
        a,
        b,
        scale_a,
        scale_b,
        torch.bfloat16,
    )

    # cutlass dense impl with bias
    record(
        "cutlass_i8_i8_bf16_scaled_mm_bias",
        ops.cutlass_scaled_mm,
        a,
        b,
        scale_a,
        scale_b,
        torch.bfloat16,
        bias,
    )

    # cutlass sparse impl
    record(
        "cutlass_i8_i8_bf16_scaled_sparse_mm",
        ops.cutlass_scaled_sparse_mm,
        a,
        b_compressed,
        e,
        scale_a,
        scale_b,
        torch.bfloat16,
    )

    # cutlass sparse impl with bias
    record(
        "cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
        ops.cutlass_scaled_sparse_mm,
        a,
        b_compressed,
        e,
        scale_a,
        scale_b,
        torch.bfloat16,
        bias,
    )

    return timers
def bench_fp8(
    dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
    """Benchmark fp8 dense and 2:4-sparse GEMMs for one (m, k, n) problem.

    Prints whether the sparse CUTLASS result matches the dense CUTLASS
    reference, then returns one measurement per benchmarked implementation,
    in the order they were run.
    """
    assert dtype == torch.float8_e4m3fn
    b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)

    # Sanity check: the sparse kernel against the dense kernel.
    out = ops.cutlass_scaled_sparse_mm(
        a, b_compressed, e, scale_a, scale_b, torch.bfloat16
    )
    out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)

    if torch.allclose(out, out_ref):
        print("Correct results")
    else:
        print("Incorrect results")
        print(out)
        print(out_ref)

    timers = []

    def record(description, fn, *fn_args, **fn_kwargs):
        # Time one implementation and keep its measurement.
        timers.append(
            bench_fn(label, sub_label, description, fn, *fn_args, **fn_kwargs)
        )

    # pytorch impl w. bf16
    record(
        "pytorch_bf16_bf16_bf16_matmul-no-scales",
        torch.mm,
        a.to(dtype=torch.bfloat16, device="cuda"),
        b.to(dtype=torch.bfloat16, device="cuda"),
    )

    # pytorch impl: bf16 output, without fp8 fast accum
    record(
        "pytorch_fp8_fp8_bf16_scaled_mm",
        torch._scaled_mm,
        a,
        b,
        scale_a=scale_a,
        scale_b=scale_b,
        out_dtype=torch.bfloat16,
    )

    # pytorch impl: bf16 output, with fp8 fast accum
    record(
        "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
        torch._scaled_mm,
        a,
        b,
        scale_a=scale_a,
        scale_b=scale_b,
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )

    # pytorch impl: fp16 output, without fp8 fast accum
    record(
        "pytorch_fp8_fp8_fp16_scaled_mm",
        torch._scaled_mm,
        a,
        b,
        scale_a=scale_a,
        scale_b=scale_b,
        out_dtype=torch.float16,
    )

    # pytorch impl: fp16 output, with fp8 fast accum
    record(
        "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
        torch._scaled_mm,
        a,
        b,
        scale_a=scale_a,
        scale_b=scale_b,
        out_dtype=torch.float16,
        use_fast_accum=True,
    )

    # cutlass dense impl: bf16 output
    record(
        "cutlass_fp8_fp8_bf16_scaled_mm",
        ops.cutlass_scaled_mm,
        a,
        b,
        scale_a,
        scale_b,
        torch.bfloat16,
    )

    # cutlass sparse impl: bf16 output
    record(
        "cutlass_fp8_fp8_bf16_scaled_sparse_mm",
        ops.cutlass_scaled_sparse_mm,
        a,
        b_compressed,
        e,
        scale_a,
        scale_b,
        torch.bfloat16,
    )

    # cutlass sparse impl: fp16 output
    record(
        "cutlass_fp8_fp8_fp16_scaled_sparse_mm",
        ops.cutlass_scaled_sparse_mm,
        a,
        b_compressed,
        e,
        scale_a,
        scale_b,
        torch.float16,
    )

    # cutlass sparse impl: bf16 output, with bias
    record(
        "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
        ops.cutlass_scaled_sparse_mm,
        a,
        b_compressed,
        e,
        scale_a,
        scale_b,
        torch.bfloat16,
        bias,
    )

    # cutlass sparse impl: fp16 output, with bias (bias cast to match output)
    record(
        "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
        ops.cutlass_scaled_sparse_mm,
        a,
        b_compressed,
        e,
        scale_a,
        scale_b,
        torch.float16,
        bias.to(dtype=torch.float16),
    )

    return timers
def
bench
(
dtype
:
torch
.
dtype
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
)
->
Iterable
[
TMeasurement
]:
if
dtype
==
torch
.
int8
:
return
bench_int8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
)
if
dtype
==
torch
.
float8_e4m3fn
:
return
bench_fp8
(
dtype
,
m
,
k
,
n
,
label
,
sub_label
)
raise
ValueError
(
f
"Unsupported dtype
{
dtype
}
: should be one of torch.int8, torch.float8_e4m3fn."
)
# runner
def print_timers(timers: Iterable[TMeasurement]):
    """Print a ``torch.utils.benchmark.Compare`` table for ``timers``."""
    TBenchmark.Compare(timers).print()
def
run
(
dtype
:
torch
.
dtype
,
MKNs
:
Iterable
[
tuple
[
int
,
int
,
int
]]
)
->
Iterable
[
TMeasurement
]:
results
=
[]
for
m
,
k
,
n
in
MKNs
:
timers
=
bench
(
dtype
,
m
,
k
,
n
,
f
"scaled-
{
dtype
}
-gemm"
,
f
"MKN=(
{
m
}
x
{
k
}
x
{
n
}
)"
)
print_timers
(
timers
)
results
.
extend
(
timers
)
return
results
# output makers
def make_output(
    data: Iterable[TMeasurement],
    MKNs: Iterable[tuple[int, int, int]],
    base_description: str,
    timestamp=None,
):
    """Print all measurements and pickle them to the working directory.

    The pickle is written to ``{base_description}-{timestamp}.pkl``. When
    ``timestamp`` is None the current Unix time (whole seconds) is used.
    ``MKNs`` is accepted but not used by this function.
    """
    print(f"== All Results {base_description} ====")
    print_timers(data)

    # pickle all the results
    if timestamp is None:
        timestamp = int(time.time())
    pickle_path = f"{base_description}-{timestamp}.pkl"
    with open(pickle_path, "wb") as out_file:
        pkl.dump(data, out_file)
# argparse runners
def run_square_bench(args):
    """Benchmark square GEMMs (M == K == N) over the requested size sweep.

    Sizes run from ``args.dim_start`` to ``args.dim_end`` inclusive, in steps
    of ``args.dim_increment``.
    """
    sweep = range(args.dim_start, args.dim_end + 1, args.dim_increment)
    MKNs = [(dim, dim, dim) for dim in sweep]
    measurements = run(args.dtype, MKNs)
    make_output(measurements, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
    """Benchmark GEMMs where each of M, K, N is either swept or held constant.

    A dimension whose ``--{m,k,n}-constant`` option was given keeps that value
    for every problem; the others follow the size sweep.
    """
    # NOTE(review): dim_end is exclusive here, whereas run_square_bench sweeps
    # up to and including dim_end. Confirm whether that difference is intended.
    dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))

    def axis(constant):
        # Repeat a fixed value for every sweep step, or use the sweep itself.
        if constant is None:
            return dim_sizes
        return [constant] * len(dim_sizes)

    MKNs = list(
        zip(axis(args.m_constant), axis(args.k_constant), axis(args.n_constant))
    )
    measurements = run(args.dtype, MKNs)
    make_output(measurements, MKNs, f"range_bench-{args.dtype}")
def run_model_bench(args):
    """Benchmark the GEMM shapes of each requested model and TP size.

    For every (model, tp_size) pair the model's weight shapes are divided
    along their tensor-parallel split dimension, combined with each batch
    size in ``args.batch_sizes``, and benchmarked. All measurements are
    printed per pair and finally pickled to
    ``model_bench-{dtype}-{timestamp}.pkl``.
    """
    print("Benchmarking models:")
    for i, model in enumerate(args.models):
        print(f"[{i}]  {model}")

    def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
        # Work on a deep copy so the shared WEIGHT_SHAPES table is not mutated.
        shapes = []
        for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
            KN[tp_split_dim] = KN[tp_split_dim] // tp_size
            shapes.append(KN)
        return shapes

    models_tps = list(itertools.product(args.models, args.tp_sizes))
    model_bench_data = []
    for model, tp_size in models_tps:
        KNs = model_shapes(model, tp_size)
        MKNs = [(m, k, n) for m in args.batch_sizes for k, n in KNs]
        model_bench_data.append(run(args.dtype, MKNs))

    # Print all results
    for (model, tp_size), data in zip(models_tps, model_bench_data):
        print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
        print_timers(data)

    timestamp = int(time.time())

    # pickle all data
    all_data = list(itertools.chain.from_iterable(model_bench_data))
    with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as out_file:
        pkl.dump(all_data, out_file)
if __name__ == "__main__":
    # Script entry point: build the command-line interface, then hand the
    # parsed arguments to the runner registered by the chosen sub-command.

    def to_torch_dtype(dt):
        # Convert the --dtype string into the torch dtype used by the benches.
        if dt == "int8":
            return torch.int8
        if dt == "fp8":
            return torch.float8_e4m3fn
        raise ValueError("unsupported dtype")

    parser = FlexibleArgumentParser(
        description="""
Benchmark Cutlass GEMM.
To run square GEMMs:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""",  # noqa: E501
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # --dtype is global and must precede the sub-command on the command line.
    parser.add_argument(
        "--dtype",
        type=to_torch_dtype,
        required=True,
        help="Available options are ['int8', 'fp8']",
    )
    # NOTE(review): the sub-command is not marked required; if it is omitted,
    # no `func` default is set and `args.func(args)` below would presumably
    # fail with AttributeError -- confirm against FlexibleArgumentParser.
    subparsers = parser.add_subparsers(dest="cmd")

    # square_bench: M == K == N swept over a size range.
    square_parser = subparsers.add_parser("square_bench")
    square_parser.add_argument("--dim-start", type=int, required=True)
    square_parser.add_argument("--dim-end", type=int, required=True)
    square_parser.add_argument("--dim-increment", type=int, required=True)
    square_parser.set_defaults(func=run_square_bench)

    # range_bench: same sweep, but any of M/N/K may be pinned to a constant.
    range_parser = subparsers.add_parser("range_bench")
    range_parser.add_argument("--dim-start", type=int, required=True)
    range_parser.add_argument("--dim-end", type=int, required=True)
    range_parser.add_argument("--dim-increment", type=int, required=True)
    range_parser.add_argument("--m-constant", type=int, default=None)
    range_parser.add_argument("--n-constant", type=int, default=None)
    range_parser.add_argument("--k-constant", type=int, default=None)
    range_parser.set_defaults(func=run_range_bench)

    # model_bench: shapes taken from the WEIGHT_SHAPES table.
    model_parser = subparsers.add_parser("model_bench")
    model_parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=DEFAULT_MODELS,
        choices=WEIGHT_SHAPES.keys(),
    )
    model_parser.add_argument(
        "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
    )
    model_parser.add_argument(
        "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
    )
    model_parser.set_defaults(func=run_model_bench)

    args = parser.parse_args()
    # Each sub-parser stored its runner under `func`.
    args.func(args)
benchmarks/cutlass_benchmarks/utils.py
View file @
38364a7e
...
...
@@ -5,8 +5,6 @@
import
torch
import
vllm._custom_ops
as
ops
def
to_fp8
(
tensor
:
torch
.
Tensor
)
->
torch
.
Tensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
...
...
@@ -39,49 +37,3 @@ def make_rand_tensors(
return
to_fp8
(
a
),
to_fp8
(
b
)
raise
ValueError
(
"unsupported dtype"
)
def prune_to_2_4(tensor):
    """Zero all but the two largest-magnitude values in each group of four.

    The tensor is viewed as consecutive groups of four elements; within each
    group the two entries with the largest absolute value are kept and the
    rest become ``0.0``. The result has the same shape as the input.
    """
    groups = tensor.reshape(-1, 4)

    # Positions of the two largest magnitudes inside every group.
    keep_idx = groups.abs().topk(2, dim=1).indices

    # Binary mask with ones at the kept positions.
    keep_mask = torch.zeros_like(groups)
    keep_mask.scatter_(
        dim=1, index=keep_idx, src=torch.ones_like(keep_idx, dtype=keep_mask.dtype)
    )

    pruned = groups * keep_mask
    # Multiplying a negative value by zero gives -0.0; normalise to +0.0.
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(tensor.shape)
def make_rand_sparse_tensors(
    dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create random GEMM operands where B is 2:4-pruned and compressed.

    ``a`` is drawn as an (m, k) tensor and ``b`` as the transpose of an (n, k)
    tensor, both on the CUDA device and scaled by 5. ``b`` is pruned with
    ``prune_to_2_4`` before both operands are converted to ``dtype``.

    Returns:
        ``(b_compressed, e, a, b)``: the compressed B and its metadata as
        produced by ``ops.cutlass_sparse_compress``, followed by the original
        A and the pruned B.

    Raises:
        ValueError: if ``dtype`` is not int8, float8_e4m3fn, float16 or
            bfloat16.
    """
    a = torch.randn((m, k), device="cuda") * 5
    b = torch.randn((n, k), device="cuda").t() * 5

    # Prune on the un-transposed view, then transpose back.
    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_sparse_compress(b.t())

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b
csrc/ops.h
View file @
38364a7e
...
...
@@ -285,16 +285,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
std
::
optional
<
torch
::
Tensor
>
const
&
azp
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
bool
cutlass_sparse_scaled_mm_supported
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_sparse_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
e
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
vector
<
torch
::
Tensor
>
cutlass_sparse_compress
(
torch
::
Tensor
const
&
a
);
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
scaled_fp4_quant_func
(
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
input_scale
,
bool
is_sf_swizzled_layout
);
...
...
csrc/sparse/cutlass/sparse_compressor_c3x.cuh
deleted
100644 → 0
View file @
fafe76b4
#pragma once
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
// clang-format on
using
namespace
cute
;
using
namespace
vllm
;
using
CompressorResult
=
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
;
/// Make A structured sparse by replacing elements with 0 and compress it
///
/// `Gemm` supplies the kernel type, element types, compressor and compressor
/// utility used for the compression. Returns `{a_meta, a_nzs}`: the metadata
/// tensor (uint8) first, then the compressed tensor in `a`'s dtype, both
/// allocated on `a`'s device.
template <typename Gemm>
CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
              a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
  TORCH_CHECK(a.dim() == 2)
  // Check for strides and alignment
  TORCH_CHECK(a.stride(0) % 4 == 0)
  // Required for semi-structured sparsity
  TORCH_CHECK(a.stride(1) == 1)

  using GemmKernel = typename Gemm::KernelType;
  using ElementA = typename Gemm::ElementAB;
  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;

  int m = a.size(0);
  int k = a.size(1);
  using ProblemShape = typename GemmKernel::ProblemShape;
  // Problem shape is built as {m, 1, k, 1}; only m and k come from `a`.
  ProblemShape prob_shape{m, 1, k, 1};

  int64_t lda = a.stride(0);
  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  // Row stride from the tensor, unit inner stride, last stride set to 0.
  StrideA a_stride{lda, Int<1>{}, 0};

  using CompressorUtility = typename Gemm::CompressorUtility;
  CompressorUtility compressor_utility(prob_shape, a_stride);

  // Allocate buffers for the metadata E and the compressed matrix A
  int ME = compressor_utility.get_metadata_m_physical();
  int KE = compressor_utility.get_metadata_k_physical();
  int MC = compressor_utility.get_tensorA_m_physical();
  int KC = compressor_utility.get_tensorA_k_physical();

  auto const a_meta_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto const a_nzs_options =
      torch::TensorOptions().dtype(a.dtype()).device(a.device());

  // Both outputs start zero-filled.
  auto a_meta = torch::zeros({ME, KE}, a_meta_options);
  auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
  auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = a.device().index();
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);

  using Compressor = typename Gemm::Compressor;
  typename Compressor::Arguments arguments{
      prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};

  Compressor compressor_op;
  // Scratch space is sized by the compressor and allocated as raw bytes.
  size_t workspace_size = Compressor::get_workspace_size(arguments);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  CUTLASS_CHECK(compressor_op.can_implement(arguments));
  CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
  CUTLASS_CHECK(compressor_op.run());
  // Block until the device has finished before handing the tensors back.
  CUDA_CHECK(cudaDeviceSynchronize());

  return {a_meta, a_nzs};
}
#endif
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu
deleted
100644 → 0
View file @
fafe76b4
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on
using
namespace
cute
;
using
namespace
vllm
;
// Dispatch target that runs the sparse GEMM for the selected configuration.
struct GemmCallerTraits {
  typedef void return_type;

  // Forwards all arguments to cutlass_sparse_gemm_caller<GemmConfig>.
  template <typename GemmConfig, typename... Ts>
  static return_type invoke(Ts&&... forwarded) {
    return cutlass_sparse_gemm_caller<GemmConfig>(
        std::forward<Ts>(forwarded)...);
  }
};
// Dispatch target that runs the compressor for the selected configuration.
struct GemmCompressorTraits {
  typedef CompressorResult return_type;

  // Forwards all arguments to cutlass_sparse_compress<GemmConfig>.
  template <typename GemmConfig, typename... Ts>
  static return_type invoke(Ts&&... forwarded) {
    return cutlass_sparse_compress<GemmConfig>(std::forward<Ts>(forwarded)...);
  }
};
// Select an fp8 kernel configuration from (m, n) and invoke `DispatchFunc`
// with it, forwarding `args` unchanged.
//
// m is rounded up to a power of two (at least 64) to pick a bucket. Within a
// bucket, a few specific values of n (4096, 6144, 28672) map to dedicated
// configurations; every other n falls through to the per-bucket default.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type cutlass_gemm_sm90_fp8_dispatch(
    uint32_t m, uint32_t n, Args&&... args) {
  static_assert(std::is_same_v<InType, cutlass::float_e4m3_t>);

  // NOTE(review): Cutlass3xGemmDefault is declared but not used in this
  // function.
  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM256 =
      typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM512 =
      typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;

  using Cutlass3xGemm1 =
      typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm2 =
      typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm3 =
      typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm4 =
      typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm5 =
      typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm6 =
      typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm7 =
      typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemm8 =
      typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;

  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(64), next_pow_2(m));  // next power of 2

  // Shape-specific overrides: only the listed n values return from here.
  if (mp2 <= 64) {
    if (n == 28672) {
      return DispatchFunc::template invoke<Cutlass3xGemm2>(
          std::forward<Args>(args)...);
    } else if (n == 4096 || n == 6144) {
      return DispatchFunc::template invoke<Cutlass3xGemm1>(
          std::forward<Args>(args)...);
    }
  } else if (mp2 <= 128) {
    if (n == 4096) {
      return DispatchFunc::template invoke<Cutlass3xGemm3>(
          std::forward<Args>(args)...);
    } else if (n == 28672) {
      return DispatchFunc::template invoke<Cutlass3xGemm5>(
          std::forward<Args>(args)...);
    } else if (n == 6144) {
      return DispatchFunc::template invoke<Cutlass3xGemm4>(
          std::forward<Args>(args)...);
    }
  } else if (mp2 <= 256) {
    if (n == 4096) {
      return DispatchFunc::template invoke<Cutlass3xGemm6>(
          std::forward<Args>(args)...);
    } else if (n == 28672) {
      return DispatchFunc::template invoke<Cutlass3xGemm8>(
          std::forward<Args>(args)...);
    } else if (n == 6144) {
      return DispatchFunc::template invoke<Cutlass3xGemm7>(
          std::forward<Args>(args)...);
    }
  } else {
    if (n == 6144 || n == 28672) {
      return DispatchFunc::template invoke<Cutlass3xGemm8>(
          std::forward<Args>(args)...);
    } else if (n == 4096) {
      return DispatchFunc::template invoke<Cutlass3xGemm7>(
          std::forward<Args>(args)...);
    }
  }

  // Otherwise the default heuristic
  if (mp2 <= 64) {
    // m in [1, 64]
    return DispatchFunc::template invoke<Cutlass3xGemmM64>(
        std::forward<Args>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return DispatchFunc::template invoke<Cutlass3xGemmM128>(
        std::forward<Args>(args)...);
  } else if (mp2 <= 256) {
    // m in (128, 256]
    return DispatchFunc::template invoke<Cutlass3xGemmM256>(
        std::forward<Args>(args)...);
  } else {
    // m in (256, inf)
    return DispatchFunc::template invoke<Cutlass3xGemmM512>(
        std::forward<Args>(args)...);
  }
}
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
DispatchFunc
,
typename
...
Args
>
typename
DispatchFunc
::
return_type
cutlass_gemm_sm90_16bit_dispatch
(
uint32_t
m
,
uint32_t
n
,
Args
&&
...
args
)
{
using
Cutlass3xGemmDefault
=
typename
sm90_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
return
DispatchFunc
::
template
invoke
<
Cutlass3xGemmDefault
>(
std
::
forward
<
Args
>
(
args
)...);
}
// Select an int8 kernel configuration from (m, n) and invoke `DispatchFunc`
// with it, forwarding `args` unchanged.
//
// m is rounded up to a power of two (at least 32) to pick a bucket; n only
// matters in the smallest bucket, where it chooses between the "small N"
// (n < 8192) and "big N" configurations.
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type cutlass_gemm_sm90_int8_dispatch(
    uint32_t m, uint32_t n, Args&&... args) {
  static_assert(std::is_same_v<InType, int8_t>);

  using Cutlass3xGemmDefault =
      typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM128 =
      typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM64 =
      typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NBig =
      typename sm90_int8_config_M32_NBig<InType, OutType,
                                         Epilogue>::Cutlass3xGemm;
  using Cutlass3xGemmM32NSmall =
      typename sm90_int8_config_M32_NSmall<InType, OutType,
                                           Epilogue>::Cutlass3xGemm;

  bool const is_small_n = n < 8192;
  uint32_t const mp2 =
      std::max(static_cast<uint32_t>(32), next_pow_2(m));  // next power of 2

  if (mp2 <= 32) {
    // m in [1, 32]
    if (is_small_n) {
      return DispatchFunc::template invoke<Cutlass3xGemmM32NSmall>(
          std::forward<Args>(args)...);
    } else {
      return DispatchFunc::template invoke<Cutlass3xGemmM32NBig>(
          std::forward<Args>(args)...);
    }
  } else if (mp2 <= 64) {
    // m in (32, 64]
    return DispatchFunc::template invoke<Cutlass3xGemmM64>(
        std::forward<Args>(args)...);
  } else if (mp2 <= 128) {
    // m in (64, 128]
    return DispatchFunc::template invoke<Cutlass3xGemmM128>(
        std::forward<Args>(args)...);
  } else {
    // m in (128, inf)
    return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
        std::forward<Args>(args)...);
  }
}
// Dispatch to GEMM implementations based on element types
//
// The dtype of `a` selects the int8, fp8 or 16-bit dispatcher; the dtype of
// `out` selects the output element type. `bt_nzs` must share `a`'s dtype and
// `bt_meta` must be uint8. For fp16/bf16 inputs the output dtype must equal
// the input dtype. `epilogue_args` are forwarded untouched.
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& bt_nzs,
                                            torch::Tensor const& bt_meta,
                                            EpilogueArgs&&... epilogue_args) {
  // The dispatchers bucket on the output's dimensions.
  uint32_t const m = out.size(0);
  uint32_t const n = out.size(1);

  // TODO: add dispatch functions to all of these
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  if (a.dtype() == torch::kInt8) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                             Epilogue, GemmCallerTraits>(
          m, n, out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      // Any int8 output other than bf16 must be fp16.
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue,
                                             GemmCallerTraits>(
          m, n, out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else if (a.dtype() == torch::kFloat8_e4m3fn) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);

    if (out.dtype() == torch::kBFloat16) {
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::bfloat16_t, Epilogue,
                                            GemmCallerTraits>(
          m, n, out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    } else {
      // Any fp8 output other than bf16 must be fp16.
      TORCH_CHECK(out.dtype() == torch::kFloat16);
      return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
                                            cutlass::half_t, Epilogue,
                                            GemmCallerTraits>(
          m, n, out, a, bt_nzs, bt_meta,
          std::forward<EpilogueArgs>(epilogue_args)...);
    }
  } else if (a.dtype() == torch::kFloat16) {
    TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
    TORCH_CHECK(out.dtype() == torch::kFloat16);

    return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
                                            Epilogue, GemmCallerTraits>(
        m, n, out, a, bt_nzs, bt_meta,
        std::forward<EpilogueArgs>(epilogue_args)...);
  } else {
    // a.dtype() == torch::kBFloat16
    // The check below rejects every dtype not handled by an earlier branch.
    TORCH_CHECK(a.dtype() == torch::kBFloat16);
    TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
    TORCH_CHECK(out.dtype() == torch::kBFloat16);

    return cutlass_gemm_sm90_16bit_dispatch<cutlass::bfloat16_t,
                                            cutlass::bfloat16_t, Epilogue,
                                            GemmCallerTraits>(
        m, n, out, a, bt_nzs, bt_meta,
        std::forward<EpilogueArgs>(epilogue_args)...);
  }
}
// SM90 entry point for the scaled 2:4 sparse matmul.
//
// Validates the dtypes of the metadata and scale tensors, then forwards to
// cutlass_scaled_sparse_mm_sm90_epilogue with either the plain scaled
// epilogue (no bias) or the column-bias epilogue (bias present).
//
// Note that the scales are handed to the epilogue as (b_scales, a_scales),
// i.e. in the opposite order to this function's parameters.
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
                                   torch::Tensor const& bt_nzs,
                                   torch::Tensor const& bt_meta,
                                   torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // No bias: plain scaled epilogue.
  if (!bias) {
    cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
        out, a, bt_nzs, bt_meta, b_scales, a_scales);
    return;
  }

  // Bias present: it must have the same dtype as the output tensor.
  TORCH_CHECK(bias->dtype() == out.dtype(),
              "CUTLASS scaled_mm bias dtype must match output dtype ",
              out.dtype());
  cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogueColumnBias>(
      out, a, bt_nzs, bt_meta, b_scales, a_scales, *bias);
}
// Compress a 2:4 sparse matrix `a` on SM90.
//
// Selects the dispatch family from a.dtype() (int8, fp8_e4m3, fp16 or bf16)
// and instantiates it with GemmCompressorTraits and the trivial epilogue.
// Any other dtype fails the TORCH_CHECK in the final branch.
//
// Returns the CompressorResult produced by the selected dispatch function.
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a) {
  // These m and n variables are for dispatching to different GEMM algorithms.
  uint32_t const m = 1;  // Set M to 1 for compression
  uint32_t const n = a.size(1);

  // Note: For correctness, the compressed format must be invariant in:
  // - M, the flattened number of tokens
  // - Whether output dtype is fp16 or bf16
  // - CUTLASS epilogues

  if (a.dtype() == torch::kInt8) {
    return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
                                           c3x::TrivialEpilogue,
                                           GemmCompressorTraits>(m, n, a);
  } else if (a.dtype() == torch::kFloat8_e4m3fn) {
    return cutlass_gemm_sm90_fp8_dispatch<
        cutlass::float_e4m3_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
        GemmCompressorTraits>(m, n, a);
  } else if (a.dtype() == torch::kFloat16) {
    // fp16 input -> half_t element type, so the kernel's element type agrees
    // with the tensor's actual dtype.
    return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
                                            c3x::TrivialEpilogue,
                                            GemmCompressorTraits>(m, n, a);
  } else {
    TORCH_CHECK(a.dtype() == torch::kBFloat16,
                "cutlass_sparse_compress only supports int8, fp8_e4m3, fp16, "
                "and bf16 datatypes");
    // bf16 input -> bfloat16_t element type.
    return cutlass_gemm_sm90_16bit_dispatch<
        cutlass::bfloat16_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
        GemmCompressorTraits>(m, n, a);
  }
}
#endif
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
deleted
100644 → 0
View file @
fafe76b4
#pragma once
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/torch_utils.hpp"
// clang-format on
using
namespace
cute
;
/*
This file defines 2:4 sparse GEMM operations using the CUTLASS 3.x API,
for NVIDIA GPUs with sm90a (Hopper) or later.
*/
namespace
{
// Wraps a GEMM kernel so that its body is only compiled for device
// architectures that can actually run it, which keeps the binary smaller.
// __CUDA_ARCH__ is undefined in host code; placing the preprocessor check
// inside the device-side call operator lets the guard take effect where the
// macro is defined. On architectures below 900 the operator is an empty body.
template <typename Kernel>
struct enable_sm90_or_later : Kernel {
  template <typename... Params>
  CUTLASS_DEVICE void operator()(Params&&... params) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
    Kernel::operator()(std::forward<Params>(params)...);
#endif
  }
};
// Short alias for the CUTLASS GEMM universal-mode enumeration.
using GemmUniversalMode = cutlass::gemm::GemmUniversalMode;
/*
 * cutlass_sparse_3x_gemm bundles the CUTLASS 3.x types that make up a 2:4
 * sparse GEMM kernel for SM90 (Hopper): element types, the epilogue and
 * mainloop collectives, the kernel itself, and the matching sparse
 * compressor types.
 */
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_sparse_3x_gemm {
  // --- Element types ------------------------------------------------------
  using ElementAB = ElementAB_;
  using ElementD = ElementD_;
  // int8 inputs accumulate in int32; every other input type accumulates in
  // float.
  using ElementAcc =
      std::conditional_t<std::is_same_v<ElementAB, int8_t>, int32_t, float>;

  // --- Epilogue description -----------------------------------------------
  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
  using EVTCompute = typename Epilogue::EVTCompute;

  // There is no C operand (void); only D is written.
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutC_Transpose =
      typename cutlass::layout::LayoutTranspose<LayoutC>::type;

  // These are the minimum alignments needed for the kernels to compile:
  // the number of elements that fit in 128 bits.
  static constexpr int AlignmentAB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentCD =
      128 / cutlass::sizeof_bits<ElementD>::value;

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
          ElementAcc, float, ElementC, LayoutC_Transpose, AlignmentCD,
          ElementD, LayoutC_Transpose, AlignmentCD, EpilogueSchedule,
          EVTCompute>::CollectiveOp;

  // --- Mainloop description -----------------------------------------------
  // The stage count is chosen automatically after carving out the shared
  // storage the epilogue needs.
  static constexpr size_t CEStorageSize =
      sizeof(typename CollectiveEpilogue::SharedStorage);
  using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(CEStorageSize)>;

  // clang-format off
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp,
          ElementAB, cutlass::layout::RowMajor, AlignmentAB,
          ElementAB, cutlass::layout::ColumnMajor, AlignmentAB,
          ElementAcc,
          TileShape, ClusterShape,
          Stages,
          KernelSchedule>::CollectiveOp;
  // clang-format on

  // --- Kernel ---------------------------------------------------------------
  using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      cutlass::gemm::PersistentScheduler>>;

  struct GemmKernel : public KernelType {};

  // --- Sparse compressor definitions ----------------------------------------
  using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;
  using LayoutTagA = cutlass::layout::RowMajor;

  using CompressorUtility =
      cutlass::transform::kernel::StructuredSparseCompressorUtility<
          typename GemmKernel::ProblemShape, ElementAB, LayoutTagA,
          SparseConfig>;

  using CompressorKernel =
      cutlass::transform::kernel::StructuredSparseCompressor<
          typename GemmKernel::ProblemShape, ElementAB, LayoutTagA,
          SparseConfig, cutlass::arch::Sm90>;

  using Compressor =
      cutlass::transform::device::TransformUniversalAdapter<CompressorKernel>;
};
/*
 * Result type of the 2:4 sparse compressor: a pair of tensors.
 * cutlass_sparse_compress below returns them in the order
 * {metadata, compressed non-zeros}.
 * The particular compressed format is defined by the Gemm template parameter
 * of cutlass_sparse_compress, which is a cutlass_sparse_3x_gemm.
 */
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;

/// Make A structured sparse by replacing elements with 0 and compress it.
///
/// `a` must be a 2-D int8 / fp8_e4m3 / fp16 / bf16 tensor whose inner stride
/// is 1 and whose row stride is a multiple of 4; each requirement is enforced
/// with TORCH_CHECK. Returns {a_meta, a_nzs}: a uint8 metadata tensor and a
/// tensor of the same dtype as `a`, both sized by the compressor utility and
/// allocated on `a`'s device.
template <typename Gemm>
CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
  // Checks for conformality
  TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
              a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
  TORCH_CHECK(a.dim() == 2);
  // Check for strides and alignment
  TORCH_CHECK(a.stride(0) % 4 == 0);  // Required for semi-structured sparsity
  TORCH_CHECK(a.stride(1) == 1);

  using GemmKernel = typename Gemm::KernelType;
  using ElementA = typename Gemm::ElementAB;
  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;

  int m = a.size(0);
  int k = a.size(1);
  using ProblemShape = typename GemmKernel::ProblemShape;
  // Problem shape is (m, 1, k, 1): only the dimensions of `a` vary.
  ProblemShape prob_shape{m, 1, k, 1};

  int64_t lda = a.stride(0);
  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  StrideA a_stride{lda, Int<1>{}, 0};

  using CompressorUtility = typename Gemm::CompressorUtility;
  CompressorUtility compressor_utility(prob_shape, a_stride);

  // Allocate buffers for the metadata E and the compressed matrix A
  int ME = compressor_utility.get_metadata_m_physical();
  int KE = compressor_utility.get_metadata_k_physical();
  int MC = compressor_utility.get_tensorA_m_physical();
  int KC = compressor_utility.get_tensorA_k_physical();

  auto const a_meta_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto const a_nzs_options =
      torch::TensorOptions().dtype(a.dtype()).device(a.device());

  // Both outputs start zero-filled.
  auto a_meta = torch::zeros({ME, KE}, a_meta_options);
  auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);

  auto a_ptr = static_cast<ElementA*>(a.data_ptr());
  auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
  auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());

  cutlass::KernelHardwareInfo hw_info;
  hw_info.device_id = a.device().index();
  hw_info.sm_count =
      cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
          hw_info.device_id);

  using Compressor = typename Gemm::Compressor;
  typename Compressor::Arguments arguments{
      prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};

  Compressor compressor_op;
  size_t workspace_size = Compressor::get_workspace_size(arguments);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_size, workspace_options);

  CUTLASS_CHECK(compressor_op.can_implement(arguments));
  CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
  // run() is called without a stream argument; the device-wide synchronize
  // below makes sure the compressor has finished before the results are
  // returned.
  CUTLASS_CHECK(compressor_op.run());
  CUDA_CHECK(cudaDeviceSynchronize());

  return {a_meta, a_nzs};
}
// Launches the 2:4 sparse GEMM described by `Gemm` on the current CUDA
// stream of `a`'s device.
//
//   out      - output tensor, written through its data pointer
//   a        - dense operand
//   bt_nzs   - compressed non-zero values of the sparse operand
//   bt_meta  - sparsity metadata of the sparse operand
//   epilogue_params - forwarded to Gemm::Epilogue::prepare_args
//
// The kernel computes C^T = B^T * A^T; B is assumed to have been transposed
// before compression, hence the bt_* naming. The sparse operand therefore
// occupies the kernel's "A" slot and the dense operand its "B" slot.
template <typename Gemm, typename... EpilogueArgs>
void cutlass_sparse_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
                                torch::Tensor const& bt_nzs,
                                torch::Tensor const& bt_meta,
                                EpilogueArgs&&... epilogue_params) {
  using GemmKernel = typename Gemm::GemmKernel;
  using ElementAB = typename Gemm::ElementAB;
  using ElementD = typename Gemm::ElementD;
  using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
  using SparseConfig = typename GemmKernel::CollectiveMainloop::SparseConfig;

  // Interface layouts of the sparse operand and its metadata, taken from the
  // kernel's mainloop (its "A" operand, see the function comment).
  using LayoutB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA;
  using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE;

  using StrideA = Stride<int64_t, Int<1>, int64_t>;
  using StrideC = Stride<Int<1>, int64_t, int64_t>;

  // M, N, K after transposition
  int32_t gemm_m = out.size(1);
  int32_t gemm_n = out.size(0);
  int32_t gemm_k = a.size(1);
  typename GemmKernel::ProblemShape prob_shape{gemm_m, gemm_n, gemm_k, 1};

  // Row strides of the dense operand and of the output.
  int64_t const dense_row_stride = a.stride(0);
  int64_t const out_row_stride = out.stride(0);
  StrideA a_stride{dense_row_stride, Int<1>{}, Int<0>{}};
  StrideC c_stride{Int<1>{}, out_row_stride, Int<0>{}};

  LayoutB b_layout = SparseConfig::fill_layoutA(prob_shape);
  LayoutE e_layout = SparseConfig::fill_layoutE(prob_shape);

  auto dense_ptr = static_cast<ElementAB*>(a.data_ptr());
  auto sparse_ptr = static_cast<ElementAB*>(bt_nzs.data_ptr());
  auto meta_ptr = static_cast<ElementE*>(bt_meta.data_ptr());
  auto out_ptr = static_cast<ElementD*>(out.data_ptr());

  // Sparse operand first, then the dense operand, then the metadata.
  typename GemmKernel::MainloopArguments mainloop_args{
      sparse_ptr, b_layout, dense_ptr, a_stride, meta_ptr, e_layout};

  // The output pointer/stride pair is supplied for both pointer slots.
  typename GemmKernel::EpilogueArguments epilogue_args{
      Gemm::Epilogue::prepare_args(
          std::forward<EpilogueArgs>(epilogue_params)...),
      out_ptr, c_stride, out_ptr, c_stride};

  typename GemmKernel::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm,
                                      prob_shape, mainloop_args, epilogue_args};

  // Launch the CUTLASS GEMM kernel.
  using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  GemmOp gemm_op;
  CUTLASS_CHECK(gemm_op.can_implement(args));

  // Scratch space requested by the kernel, allocated on `a`'s device.
  size_t const workspace_bytes = gemm_op.get_workspace_size(args);
  auto const workspace_options =
      torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
  auto workspace = torch::empty(workspace_bytes, workspace_options);

  auto stream = at::cuda::getCurrentCUDAStream(a.get_device());

  CUTLASS_CHECK(gemm_op.run(args, workspace.data_ptr(), stream));
}
//////////////////////////////////////////////////
// Gemm Configs are defined below
//////////////////////////////////////////////////
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_config_default
{};
// Default SM90 configuration for half_t inputs:
// 128x128x128 tile, 1x1x1 cluster, TMA warp-specialized schedules.
template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<half_t, OutType, Epilogue> {
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _1, _1>;
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<half_t, OutType, Epilogue, TileShape, ClusterShape,
                             KernelSchedule, EpilogueSchedule>;
};
// Default SM90 configuration for bfloat16 inputs:
// 128x128x128 tile, 1x1x1 cluster, TMA warp-specialized schedules.
template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<cutlass::bfloat16_t, OutType, Epilogue> {
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _1, _1>;
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<cutlass::bfloat16_t, OutType, Epilogue, TileShape,
                             ClusterShape, KernelSchedule, EpilogueSchedule>;
};
//////////////////////// Cherry-Picking Kernels ////////////////////////
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_1
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_8
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_2
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_128
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_8
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_3
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_4
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_8
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_5
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_8
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_6
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_7
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_8
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_128
,
_256
,
_128
>
;
using
ClusterShape
=
Shape
<
_8
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
////////////////////////////////////////////////////////////////////////
// Default SM90 configuration for fp8 (e4m3) inputs:
// 128x128x128 tile, 1x2x1 cluster, warp-specialized FP8 fast-accum mainloop.
// Annotated in the original as the choice for M in (128, inf); the actual
// selection logic lives in the dispatch code, not here.
template <typename OutType,
          template <typename, typename, typename> typename Epilogue>
struct sm90_config_default<cutlass::float_e4m3_t, OutType, Epilogue> {
  using TileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_1, _2, _1>;
  using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
  using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;

  using Cutlass3xGemm =
      cutlass_sparse_3x_gemm<cutlass::float_e4m3_t, OutType, Epilogue,
                             TileShape, ClusterShape, KernelSchedule,
                             EpilogueSchedule>;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M64
{
// M in [1, 64]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M128
{
// M in (64, 128]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpongFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M256
{
// M in (128, 256]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_fp8_config_M512
{
// M in (256, ]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_config_default
<
int8_t
,
OutType
,
Epilogue
>
{
// For M > 128 and any N
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
int8_t
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M128
{
// For M in (64, 128] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M64
{
// For M in (32, 64] and any N
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NBig
{
// For M in [1, 32] and N >= 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_4
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm90_int8_config_M32_NSmall
{
// For M in [1, 32] and N < 8192
static_assert
(
std
::
is_same
<
InType
,
int8_t
>
());
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecialized
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
using
TileShape
=
Shape
<
_64
,
_64
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_8
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_sparse_3x_gemm
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
}
// namespace
csrc/sparse/cutlass/sparse_scaled_mm_entry.cu
deleted
100644 → 0
View file @
fafe76b4
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
// Reports whether the sparse CUTLASS scaled_mm kernels can be used for a
// device with the given capability (encoded as a single integer, e.g. 90).
//
// Sparse CUTLASS kernels need exactly Hopper and are not forward compatible:
// the answer is true only when CUDA_VERSION is defined and at least 12020
// (CUDA 12.2) and the capability equals 90. Without CUDA_VERSION the answer
// is always false.
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
  bool const toolkit_new_enough = CUDA_VERSION >= 12020;
  bool const is_hopper = cuda_device_capability == 90;
  return toolkit_new_enough && is_hopper;
#else
  return false;
#endif
}
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
void
cutlass_scaled_sparse_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
e
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
using
CompressorResult
=
std
::
tuple
<
torch
::
Tensor
,
torch
::
Tensor
>
;
CompressorResult
cutlass_sparse_compress_sm90
(
torch
::
Tensor
const
&
a
);
#endif
// Public entry point for the scaled 2:4 sparse matmul.
//
//   c        - 2-D output tensor
//   a        - 2-D dense operand
//   bt_nzs   - 2-D compressed non-zeros of the sparse operand
//   bt_meta  - sparsity metadata of the sparse operand
//   a_scales - one scale, or one per row of `a`
//   b_scales - one scale, or one per row of `bt_nzs`
//   bias     - optional 1-D contiguous bias with bt_nzs.size(0) elements
//
// Validates shapes and strides, then dispatches to the SM90 implementation
// when the sparse C3X kernels are compiled in and the device capability is
// exactly 90. In every other case a not-implemented error is raised.
void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& bt_nzs,
                              torch::Tensor const& bt_meta,
                              torch::Tensor const& a_scales,
                              torch::Tensor const& b_scales,
                              std::optional<torch::Tensor> const& bias) {
  // Checks for conformality
  TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2);
  // bt_nzs holds half of the reduction dimension of `a`.
  TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) &&
              a.size(0) == c.size(0));
  TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
  TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == bt_nzs.size(0));

  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1 && bt_nzs.stride(1) == 1 &&
              c.stride(1) == 1);  // Row-major
  // Row strides must be multiples of 16 (strides are element counts).
  TORCH_CHECK(c.stride(0) % 16 == 0);
  TORCH_CHECK(bt_nzs.stride(0) % 16 == 0);
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  if (bias) {
    TORCH_CHECK(bias->numel() == bt_nzs.size(0) && bias->is_contiguous() &&
                bias->dim() == 1);
  }

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t version_num = get_sm_version_num();

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
  // We build for 9.0a which is not forward compatible, so restrict this to
  // Hopper only
  if (version_num == 90) {
    cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
                                  bias);
    return;
  }
#endif

  // The kernel is only usable when the device capability matches exactly, so
  // the message says "equal to" (consistent with cutlass_sparse_compress).
  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_scaled_sparse_mm for a compute capability equal to "
      "CUDA device capability: ",
      version_num);
}
// Public entry point for 2:4 sparse compression.
//
// Requires `a` to have an inner stride of 1 and a row stride that is a
// multiple of 8. When the sparse C3X kernels are compiled in and the device
// capability is exactly 90, returns a two-element vector
// {compressed non-zeros, metadata}; otherwise raises a not-implemented error.
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a) {
  // Check for strides and alignment
  TORCH_CHECK(a.stride(1) == 1);      // Row-major
  TORCH_CHECK(a.stride(0) % 8 == 0);  // 8 Byte Alignment for Compression

  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t version_num = get_sm_version_num();

  // Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
  // We build for 9.0a which is not forward compatible, so restrict this to
  // Hopper only
  if (version_num == 90) {
    // The SM90 compressor yields (metadata, non-zeros); callers of this
    // function receive them in the opposite order.
    auto [meta, nzs] = cutlass_sparse_compress_sm90(a);

    std::vector<torch::Tensor> outputs;
    outputs.reserve(2);
    outputs.emplace_back(std::move(nzs));
    outputs.emplace_back(std::move(meta));
    return outputs;
  }
#endif

  TORCH_CHECK_NOT_IMPLEMENTED(
      false,
      "No compiled cutlass_sparse_compress for a compute capability equal to "
      "CUDA device capability: ",
      version_num);
}
csrc/torch_bindings.cpp
View file @
38364a7e
...
...
@@ -523,26 +523,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"cutlass_scaled_mm_supports_block_fp8"
,
&
cutlass_scaled_mm_supports_block_fp8
);
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
// given capability
ops
.
def
(
"cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool"
);
ops
.
impl
(
"cutlass_sparse_scaled_mm_supported"
,
&
cutlass_sparse_scaled_mm_supported
);
// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops
.
def
(
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
" Tensor bt_nzs,"
" Tensor bt_meta, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()"
);
ops
.
impl
(
"cutlass_scaled_sparse_mm"
,
torch
::
kCUDA
,
&
cutlass_scaled_sparse_mm
);
// CUTLASS sparse matrix compressor
ops
.
def
(
"cutlass_sparse_compress(Tensor a) -> Tensor[]"
);
ops
.
impl
(
"cutlass_sparse_compress"
,
&
cutlass_sparse_compress
);
// SM100 CUTLASS MLA decode
ops
.
def
(
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
...
...
tests/kernels/quantization/test_cutlass_2of4_sparse.py
deleted
100644 → 0
View file @
fafe76b4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for sparse cutlass kernels
Run `pytest tests/kernels/quantization/test_cutlass_2of4_sparse.py`.
"""
import
pytest
import
torch
from
tests.kernels.utils
import
baseline_scaled_mm
,
to_fp8
,
to_int8
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
sparse_cutlass_supported
,
)
from
vllm.platforms
import
current_platform
# Device strings for the tests: one entry when exactly one accelerator is
# reported, otherwise two.
_device_total = 1 if torch.accelerator.device_count() == 1 else 2
CUDA_DEVICES = [f"cuda:{i}" for i in range(_device_total)]

# Collapse the capability pair into a single integer: first * 10 + second.
_capability_pair = current_platform.get_device_capability()
capability = _capability_pair[0] * 10 + _capability_pair[1]
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` converted to ``torch.bfloat16``."""
    target_dtype = torch.bfloat16
    return tensor.to(dtype=target_dtype)
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` converted to ``torch.float16``."""
    target_dtype = torch.float16
    return tensor.to(dtype=target_dtype)
def prune_to_2_4(tensor):
    """Zero out all but the two largest-magnitude values in every group of 4.

    The tensor is viewed as rows of four consecutive elements; within each
    row the two entries with the largest absolute value are kept and the
    other two are set to zero. The result has the same shape as the input.
    """
    groups = tensor.reshape(-1, 4)

    # Positions of the two largest |values| inside each group of four.
    keep_idx = torch.abs(groups).topk(k=2, dim=1).indices

    # 0/1 mask with ones at the kept positions.
    keep_mask = torch.zeros_like(groups)
    keep_mask.scatter_(
        dim=1,
        index=keep_idx,
        src=torch.ones_like(keep_idx, dtype=keep_mask.dtype),
    )

    sparse = groups * keep_mask
    # Multiplying a negative value by 0 yields -0.0; normalise it to 0.0.
    sparse[sparse == -0.0] = 0.0
    return sparse.reshape(tensor.shape)
# Sanity check: multiplying an identity matrix with the compressed weights
# must reproduce the original, uncompressed weights.
def check_compress_decompress_invariance(
    dtype: torch.dtype,
    b: torch.Tensor,
    b_compressed: torch.Tensor,
    b_metadata: torch.Tensor,
):
    """Assert that ``eye @ compressed(b)`` matches ``b``.

    :param dtype: dtype used to build the identity operand.
    :param b: the uncompressed reference matrix.
    :param b_compressed: compressed non-zero values passed to the kernel.
    :param b_metadata: sparsity metadata passed to the kernel.
    """
    # For float16 and bfloat16, cutlass_scaled_sparse_mm's output must have
    # the same dtype as its inputs; int8/fp8 arbitrarily use bfloat16.
    if dtype is torch.float16:
        out_dtype = torch.float16
    else:
        out_dtype = torch.bfloat16

    identity = torch.eye(b.shape[0], device="cuda", dtype=dtype)
    unit_scale = torch.ones(1, device="cuda", dtype=torch.float32)

    b_decomp = ops.cutlass_scaled_sparse_mm(
        identity,
        b_compressed,
        b_metadata,
        unit_scale,
        unit_scale,
        out_dtype=out_dtype,
    )

    reference = b.to(dtype=out_dtype)
    torch.testing.assert_close(reference, b_decomp)
def make_rand_sparse_tensors(
    dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build random operands for a 2:4 sparse matmul test.

    ``a`` is created as ``(m, k)`` and ``b`` as the transpose of an ``(n, k)``
    tensor. ``b`` is pruned with ``prune_to_2_4`` (applied to ``b.t()``), both
    operands are cast to ``dtype``, and ``b.t()`` is compressed with
    ``ops.cutlass_sparse_compress``. The compressed result is verified with
    ``check_compress_decompress_invariance`` before returning.

    :return: ``(b_compressed, metadata, a, b)``
    :raises ValueError: if ``dtype`` is not int8, float8_e4m3fn, float16 or
        bfloat16.
    """
    a = torch.randn((m, k), device="cuda")
    b = torch.randn((n, k), device="cuda").t()

    if dtype == torch.int8:
        # ensure A and B aren't all zeros after rounding
        a, b = a * 5.0, b * 5.0

    b = prune_to_2_4(b.t()).t()

    # Pick the cast helper matching the requested dtype.
    converters = {
        torch.int8: to_int8,
        torch.float8_e4m3fn: to_fp8,
        torch.float16: to_fp16,
        torch.bfloat16: to_bf16,
    }
    convert = converters.get(dtype)
    if convert is None:
        raise ValueError("unsupported dtype")
    a, b = convert(a), convert(b)

    b_compressed, e = ops.cutlass_sparse_compress(b.t())
    check_compress_decompress_invariance(dtype, b, b_compressed, e)

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
    """Run the sparse fp8 GEMM on a slice of a larger ``A`` matrix.

    ``A`` is allocated with 1024 rows but only its leading 512x512 window is
    passed to the kernel; the output is compared against
    ``baseline_scaled_mm`` on the same window.
    """
    big_m = 1024
    m = n = k = 512

    # Create tensors
    b_comp, e, whole_a, b = make_rand_sparse_tensors(
        torch.float8_e4m3fn, big_m, n, k
    )
    a = whole_a[:m, :k]

    scale_kwargs = dict(device="cuda", dtype=torch.float32)
    scale_a = torch.randn((1, 1), **scale_kwargs) / 10
    scale_b = torch.randn((1, 1), **scale_kwargs) / 10

    out = ops.cutlass_scaled_sparse_mm(
        a, b_comp, e, scale_a, scale_b, out_dtype=torch.bfloat16
    )
    baseline = baseline_scaled_mm(
        a, b, scale_a, scale_b, out_dtype=torch.bfloat16
    )

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
# Problem sizes fed to the parametrized GEMM tests, grouped by first entry.
MNK_FACTORS = [
    # first entry = 1
    (1, 256, 128),
    (1, 16384, 1024),
    (1, 24576, 512),
    # first entry = 16
    (16, 256, 512),
    (16, 16384, 128),
    (16, 24576, 4096),
    # first entry = 32 / 33
    (32, 8192, 4096),
    (32, 16384, 4096),
    (33, 1024, 1024),
    (33, 8192, 128),
    # first entry = 64 .. 256
    (64, 2048, 512),
    (64, 16384, 1024),
    (100, 8192, 512),
    (128, 32768, 4096),
    (256, 4096, 4096),
    # first entry = 512
    (512, 256, 1024),
    (512, 8192, 4096),
    (512, 16384, 128),
    (512, 24576, 128),
]
# Test the 2:4 sparse GEMM for fp16/bf16 operands, with and without bias
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(
    m: int, k: int, n: int, dtype: torch.dtype, use_bias: bool
):
    """Compare ``ops.cutlass_scaled_sparse_mm`` with ``baseline_scaled_mm``.

    Operands come from ``make_rand_sparse_tensors`` in the requested
    half-precision ``dtype``; both scales are 1.0 so only the sparse matmul
    (plus optional bias) is exercised.

    :param m: first entry of the ``MNK_FACTORS`` tuple.
    :param k: third entry of the ``MNK_FACTORS`` tuple.
    :param n: second entry of the ``MNK_FACTORS`` tuple.
    :param dtype: operand/output dtype (``torch.bfloat16`` or
        ``torch.float16``).
    :param use_bias: whether a random bias of length ``n`` is added.
    """
    # NOTE: pytest binds parametrized values by name, so the (m, k, n) order
    # of this signature does not affect which value each argument receives.

    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
    scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)

    bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None

    out = ops.cutlass_scaled_sparse_mm(
        a, b_comp, e, scale_a, scale_b, out_dtype=dtype, bias=bias
    )
    baseline = baseline_scaled_mm(
        a, b, scale_a, scale_b, out_dtype=dtype, bias=bias
    )

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(
    not current_platform.has_device_capability(89),
    reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
    """Compare the sparse fp8 GEMM against ``baseline_scaled_mm``.

    Operands are fp8 (``torch.float8_e4m3fn``), scales are random ``(1, 1)``
    float32 tensors and the output dtype is bfloat16.
    """
    # NOTE: the parametrize ids are "m, k, n", so the second entry of each
    # MNK_FACTORS tuple is bound to ``k`` and the third to ``n`` here.

    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)

    scale_kwargs = dict(device="cuda", dtype=torch.float32)
    scale_a = torch.randn((1, 1), **scale_kwargs)
    scale_b = torch.randn((1, 1), **scale_kwargs)

    out_dtype = torch.bfloat16

    if use_bias:
        bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10
    else:
        bias = None

    out = ops.cutlass_scaled_sparse_mm(
        a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )
    baseline = baseline_scaled_mm(
        a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(
    m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
):
    """Compare the sparse int8 GEMM against ``baseline_scaled_mm``.

    :param m: first entry of the ``MNK_FACTORS`` tuple.
    :param n: third entry of the ``MNK_FACTORS`` tuple (ids are "m,k,n").
    :param k: second entry of the ``MNK_FACTORS`` tuple.
    :param per_act_token: use one activation scale per row of ``a`` instead
        of a single per-tensor scale.
    :param per_out_ch: use one weight scale per output channel instead of a
        single per-tensor scale.
    :param use_bias: whether a random bias of length ``n`` is added.
    """
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)

    # Previously both flags were ignored and every case ran with (1, 1)
    # scales. Size the scales from the flags so the per-row / per-column
    # paths are actually exercised.
    # NOTE(review): assumes baseline_scaled_mm broadcasts (m, 1) / (1, n)
    # scales against a / b -- confirm against tests.kernels.utils.
    m_a_scales = m if per_act_token else 1
    n_b_scales = n if per_out_ch else 1
    scale_a = torch.randn((m_a_scales, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((1, n_b_scales), device="cuda", dtype=torch.float32)

    out_dtype = torch.bfloat16

    bias = (
        torch.rand((n,), device="cuda", dtype=out_dtype) * 10
        if use_bias
        else None
    )

    out = ops.cutlass_scaled_sparse_mm(
        a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )
    baseline = baseline_scaled_mm(
        a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
    )

    torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
tests/quantization/test_compressed_tensors.py
View file @
38364a7e
...
...
@@ -12,7 +12,6 @@ from compressed_tensors.quantization import QuantizationType
from
tests.models.utils
import
check_logprobs_close
from
vllm.model_executor.layers.fused_moe
import
UnquantizedFusedMoEMethod
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensors24
,
CompressedTensorsLinearMethod
,
CompressedTensorsW4A4Fp4
,
CompressedTensorsW4A8Fp8
,
...
...
@@ -27,9 +26,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8
from
vllm.model_executor.layers.quantization.utils.nvfp4_utils
import
(
cutlass_fp4_supported
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
sparse_cutlass_supported
,
)
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.fa_utils
import
get_flash_attn_version
...
...
@@ -362,283 +358,6 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
assert
output
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy, format="dense"):
    """Shared assertions for quantized 2:4 sparse ``qkv_proj`` layers.

    Checks the quant method and scheme classes, the weight/input quantization
    strategies, that the scheme is flagged as quantized, and that the
    ``"Linear"`` sparsity entry has the expected ``format`` and a ``"2:4"``
    structure.
    """
    scheme = qkv_proj.scheme
    quant_config = qkv_proj.quant_method.quantization_config

    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(scheme, CompressedTensors24)

    assert scheme.weight_quant.strategy == weight_strategy
    assert scheme.input_quant.strategy == input_strategy
    assert scheme.quantized

    assert quant_config.sparsity_scheme_map
    sparsity_map = quant_config.sparsity_scheme_map  # noqa: E501
    linear_sparsity = sparsity_map.get("Linear")
    assert linear_sparsity.format == format
    assert linear_sparsity.sparsity_structure == "2:4"
@pytest.mark.skipif(
    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    """Load an FP8-quantized 2:4 checkpoint, validate layer 0 and generate."""
    model, weight_strategy, input_strategy = args_2of4

    def check_model(model):
        # Inspect the fused QKV projection of the first decoder layer.
        qkv_proj = model.model.layers[0].self_attn.qkv_proj
        assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
        _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

    with vllm_runner(model, enforce_eager=True) as llm:
        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        print(output)
        assert output
@pytest.mark.skipif(
    not (current_platform.is_cuda() and current_platform.has_device_capability(90)),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
    """Load a bitmask-compressed FP8 2:4 checkpoint, validate and generate."""
    model, weight_strategy, input_strategy = args_2of4

    def check_model(model):
        # Inspect the fused QKV projection of the first decoder layer.
        qkv_proj = model.model.layers[0].self_attn.qkv_proj
        assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
        _test_2of4_quant_models(
            qkv_proj,
            weight_strategy,
            input_strategy,
            format="sparse-24-bitmask",
        )

    with vllm_runner(model, enforce_eager=True) as llm:
        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        print(output)
        assert output
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
    """Load a bitmask-compressed INT8 2:4 checkpoint, validate and generate."""
    model, weight_strategy, input_strategy = args_2of4

    def check_model(model):
        # Inspect the fused QKV projection of the first decoder layer.
        qkv_proj = model.model.layers[0].self_attn.qkv_proj
        assert qkv_proj.scheme.weights_dtype == torch.int8
        _test_2of4_quant_models(
            qkv_proj,
            weight_strategy,
            input_strategy,
            format="sparse-24-bitmask",
        )

    with vllm_runner(model, enforce_eager=True) as llm:
        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        print(output)
        assert output
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    # The previous reason said "Sparse FP8", which was misleading for a test
    # that only loads INT8 checkpoints.
    reason="Sparse INT8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    """Load an INT8-quantized 2:4 checkpoint, validate layer 0 and generate.

    :param vllm_runner: fixture used as a context manager to build the model.
    :param args_2of4: ``(model name, weight strategy, input strategy)``.
    """
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            # Inspect the fused QKV projection of the first decoder layer.
            layer = model.model.layers[0]
            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        print(output)
        assert output
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="2of4 Sparse is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")],
)
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    """Load an unquantized, dense-format 2:4 checkpoint, validate, generate."""
    model = args_2of4

    def check_model(model):
        qkv_proj = model.model.layers[0].self_attn.qkv_proj
        scheme = qkv_proj.scheme
        quant_config = qkv_proj.quant_method.quantization_config

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(scheme, CompressedTensors24)

        # Sparse-only model: no quantization args at all.
        assert scheme.weight_quant is None
        assert scheme.input_quant is None
        assert not scheme.quantized

        assert quant_config.sparsity_scheme_map
        sparsity_map = quant_config.sparsity_scheme_map  # noqa: E501
        linear_sparsity = sparsity_map.get("Linear")
        assert linear_sparsity.format == "dense"
        assert linear_sparsity.sparsity_structure == "2:4"

    with vllm_runner(model, enforce_eager=True) as llm:
        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        print(output)
        assert output
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")]
)
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
    """Load an unquantized bitmask-compressed 2:4 checkpoint and generate."""
    model = args_2of4

    def check_model(model):
        qkv_proj = model.model.layers[0].self_attn.qkv_proj
        scheme = qkv_proj.scheme
        quant_config = qkv_proj.quant_method.quantization_config

        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
        assert isinstance(scheme, CompressedTensors24)

        # Sparse-only model: no quantization args at all.
        assert scheme.weight_quant is None
        assert scheme.input_quant is None
        assert not scheme.quantized

        assert quant_config.sparsity_scheme_map
        sparsity_map = quant_config.sparsity_scheme_map  # noqa: E501
        linear_sparsity = sparsity_map.get("Linear")
        assert linear_sparsity.format == "sparse-24-bitmask"
        assert linear_sparsity.sparsity_structure == "2:4"

    with vllm_runner(model, enforce_eager=True) as llm:
        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        print(output)
        assert output
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"This test is skipped on non-CUDA platform."
)
...
...
tests/weight_loading/models.txt
View file @
38364a7e
...
...
@@ -20,8 +20,6 @@ compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
#compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
...
...
vllm/_custom_ops.py
View file @
38364a7e
...
...
@@ -876,10 +876,6 @@ def cutlass_scaled_mm_azp(
return
out
.
view
(
*
target_shape
)
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
    """Ask the compiled ``_C`` extension whether sparse scaled-mm is available.

    :param cuda_device_capability: capability value forwarded unchanged to
        ``torch.ops._C.cutlass_sparse_scaled_mm_supported``.
    :return: whatever the underlying op reports.
    """
    supported_op = torch.ops._C.cutlass_sparse_scaled_mm_supported
    return supported_op(cuda_device_capability)
def
cutlass_group_gemm_supported
(
cuda_device_capability
:
int
)
->
bool
:
if
cuda_device_capability
<
90
or
cuda_device_capability
>=
110
:
return
False
...
...
@@ -890,94 +886,6 @@ def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
return
False
def
cutlass_sparse_compress
(
a
:
torch
.
Tensor
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Compresses a sparse matrix for use with Cutlass sparse operations.
This function takes a dense tensor and compresses it into two components:
non-zero elements and metadata. The compressed representation is compatible
with Cutlass sparse kernels.
Args:
a (torch.Tensor):
The input tensor to be compressed. Must have one of the following data types:
- `torch.int8`
- `torch.float8_e4m3fn`
- `torch.bfloat16`
- `torch.float16`
Returns:
tuple[torch.Tensor, torch.Tensor]:
A tuple containing:
- `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
- `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
Raises:
ValueError: If the compression operation fails.
Notes:
- The `a_meta` tensor has a data type of `torch.uint8`.
- Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`).
- The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor.
- The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`.
"""
assert
a
.
dtype
in
[
torch
.
int8
,
torch
.
float8_e4m3fn
,
torch
.
bfloat16
,
torch
.
float16
]
assert
a
.
is_contiguous
()
# a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
elemsPerMetaElem
=
4
assert
a
.
shape
[
1
]
%
(
2
*
elemsPerMetaElem
)
==
0
return
torch
.
ops
.
_C
.
cutlass_sparse_compress
(
a
)
def
cutlass_scaled_sparse_mm
(
a
:
torch
.
Tensor
,
bt_nzs
:
torch
.
Tensor
,
bt_meta
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""
Performs a scaled sparse matrix multiplication using Cutlass.
Steps:
1. Create a dense matrix `a` of shape (m, k) on the CUDA device:
`a = torch.randn((m, k), device='cuda')`.
2. Create a dense matrix `b` of shape (k, n) on the CUDA device:
`b = torch.randn((k, n), device='cuda')`.
3. Prune matrix `b` to 2:4 sparsity along the specified dimension:
`b = prune_to_2_4(b, dim=0)`.
4. Compress the transposed sparse matrix `b.t()`:
`bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.
5. Perform sparse matrix multiplication using the compressed matrix,
applying scaling factors for `a` and `b`, and the output data type:
`out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.
Returns:
- The result of the scaled sparse matrix multiplication.
"""
assert
bt_nzs
.
shape
[
0
]
%
16
==
0
and
bt_nzs
.
shape
[
1
]
%
16
==
0
assert
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
assert
bias
is
None
or
bias
.
shape
[
0
]
==
bt_nzs
.
shape
[
0
]
and
bias
.
dtype
==
out_dtype
m
=
a
.
shape
[
0
]
n
=
bt_nzs
.
shape
[
0
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
torch
.
ops
.
_C
.
cutlass_scaled_sparse_mm
(
out
,
a
,
bt_nzs
,
bt_meta
,
scale_a
,
scale_b
,
bias
)
return
out
def
get_cutlass_moe_mm_data
(
topk_ids
:
torch
.
Tensor
,
expert_offsets
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py
View file @
38364a7e
...
...
@@ -5,39 +5,16 @@ from collections.abc import Callable
from
typing
import
Any
import
torch
from
compressed_tensors
import
CompressionFormat
,
ModelCompressor
from
compressed_tensors.quantization
import
(
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
,
)
from
compressed_tensors.utils
import
combine_shards
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
QKVParallelLinear
,
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
)
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
convert_to_channelwise
,
sparse_cutlass_supported
,
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
ModelWeightParameter
,
PerTensorScaleParameter
,
)
__all__
=
[
"CompressedTensors24"
]
from
vllm.platforms
import
current_platform
class
CompressedTensors24
(
CompressedTensorsScheme
):
def
__init__
(
...
...
@@ -47,33 +24,11 @@ class CompressedTensors24(CompressedTensorsScheme):
input_quant
:
QuantizationArgs
|
None
=
None
,
model_compression_config
:
dict
[
str
,
Any
]
|
None
=
None
,
):
self
.
quantized
=
quantized
self
.
weight_quant
=
weight_quant
self
.
input_quant
=
input_quant
model_compressor
=
ModelCompressor
.
from_compression_config
(
model_compression_config
)
self
.
do_sparse_decompress
=
(
model_compressor
is
not
None
and
model_compressor
.
sparsity_config
.
format
==
CompressionFormat
.
sparse_24_bitmask
.
value
)
if
self
.
do_sparse_decompress
:
self
.
model_compressor
=
model_compressor
if
(
quantized
and
input_quant
is
not
None
and
self
.
_get_quant_dtype
()
==
current_platform
.
fp8_dtype
()
):
static
=
not
input_quant
.
dynamic
g_shape
=
GroupShape
.
PER_TENSOR
if
static
else
GroupShape
.
PER_TOKEN
self
.
quant_fp8
=
QuantFP8
(
static
,
g_shape
)
raise
NotImplementedError
(
"Sparse24 models are no longer supported by vLLM"
)
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# Only cutlass 3.x kernels are implemented so far
return
90
raise
NotImplementedError
(
"Sparse24 models are no longer supported by vLLM"
)
def
create_weights
(
self
,
...
...
@@ -85,164 +40,10 @@ class CompressedTensors24(CompressedTensorsScheme):
weight_loader
:
Callable
,
**
kwargs
,
):
if
not
sparse_cutlass_supported
():
raise
ValueError
(
"Sparse CUTLASS not supported. vLLM must be built with "
"CUDA 12.2 or later to use this feature"
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size
=
input_size
layer
.
input_size_per_partition
=
input_size_per_partition
self
.
weights_dtype
:
torch
.
dtype
=
self
.
_get_params_dtype
(
params_dtype
)
# parameter to store uncompressed weight
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
self
.
weights_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
if
self
.
do_sparse_decompress
:
assert
all
(
partition_size
%
8
==
0
for
partition_size
in
output_partition_sizes
),
"All partitions must be divisible by 8 for "
"2:4 sparse compressed models"
shape
=
BasevLLMParameter
(
data
=
torch
.
empty
(
2
,
1
,
dtype
=
torch
.
int64
),
weight_loader
=
weight_loader
,
)
compressed_weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
//
2
,
dtype
=
self
.
weights_dtype
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
bitmask
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
//
8
,
dtype
=
torch
.
uint8
,
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"shape"
,
shape
)
layer
.
register_parameter
(
"compressed"
,
compressed_weight
)
layer
.
register_parameter
(
"bitmask"
,
bitmask
)
# Check if quantized, not just 2:4 Sparse
if
self
.
quantized
:
if
(
self
.
weight_quant
and
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
):
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
(
(
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
,
)
else
:
assert
(
self
.
weight_quant
and
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
)
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# input quant will be non-none
if
self
.
input_quant
and
not
self
.
input_quant
.
dynamic
:
# register input quant scale
assert
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
input_scale
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
else
:
# for sparse-only, pass in 1 for weight/input scales
weight_scale
=
torch
.
nn
.
Parameter
(
data
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
input_scale
=
torch
.
nn
.
Parameter
(
data
=
torch
.
ones
(
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight"
,
weight
)
raise
NotImplementedError
(
"Sparse24 models are no longer supported by vLLM"
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
if
self
.
do_sparse_decompress
:
layer
.
weight
.
data
=
self
.
_decompress_bitmask_compressed_weight
(
compressed
=
layer
.
compressed
,
bitmask
=
layer
.
bitmask
,
layer
=
layer
,
)
# compressed and bitmask tensors
# are no longer needed after decompression
del
layer
.
compressed
del
layer
.
bitmask
# torch.compile workaround
if
hasattr
(
layer
,
"input_scale"
):
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
if
self
.
weight_quant
:
if
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
:
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
convert_to_channelwise
(
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
),
requires_grad
=
False
,
)
else
:
# torch.compile workaround
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
# Set all negative zero values to 0 prior to compression
if
layer
.
weight
.
dtype
.
is_floating_point
and
layer
.
weight
.
dtype
.
itemsize
>=
2
:
layer
.
weight
.
data
[
layer
.
weight
.
data
==
-
0.0
]
=
0.0
w_compressed
,
meta
=
ops
.
cutlass_sparse_compress
(
layer
.
weight
.
data
)
layer
.
weight
=
torch
.
nn
.
Parameter
(
w_compressed
,
requires_grad
=
False
)
layer
.
meta
=
torch
.
nn
.
Parameter
(
meta
,
requires_grad
=
False
)
raise
NotImplementedError
(
"Sparse24 models are no longer supported by vLLM"
)
def
apply_weights
(
self
,
...
...
@@ -250,143 +51,4 @@ class CompressedTensors24(CompressedTensorsScheme):
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
"""
if
self
.
quantized
:
scale
=
getattr
(
layer
,
"input_scale"
,
None
)
if
self
.
weights_dtype
==
torch
.
int8
:
ops_output
=
ops
.
scaled_int8_quant
(
x
,
scale
=
scale
)
q_input
=
ops_output
[
0
]
input_scale
=
ops_output
[
1
]
else
:
assert
self
.
weights_dtype
==
torch
.
float8_e4m3fn
q_input
,
input_scale
=
self
.
quant_fp8
(
x
,
scale
=
scale
)
else
:
# Not quantized, nothing to do with the input_scales, use as is
input_scale
=
layer
.
input_scale
q_input
=
x
out
=
ops
.
cutlass_scaled_sparse_mm
(
a
=
q_input
,
bt_nzs
=
layer
.
weight
,
bt_meta
=
layer
.
meta
,
scale_a
=
input_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
,
)
assert
out
.
is_contiguous
()
return
out
def
_get_params_dtype
(
self
,
params_dtype
:
torch
.
dtype
)
->
torch
.
dtype
:
if
not
self
.
quantized
:
return
params_dtype
return
self
.
_get_quant_dtype
()
def
_get_quant_dtype
(
self
)
->
torch
.
dtype
:
assert
self
.
quantized
assert
self
.
weight_quant
is
not
None
assert
self
.
input_quant
is
not
None
is_8_bits
=
self
.
weight_quant
.
num_bits
==
self
.
input_quant
.
num_bits
==
8
if
not
is_8_bits
:
raise
ValueError
(
"Cutlass only supports 8-bit quantization"
)
if
(
self
.
weight_quant
.
type
==
QuantizationType
.
FLOAT
and
self
.
input_quant
.
type
==
QuantizationType
.
FLOAT
):
return
torch
.
float8_e4m3fn
if
(
self
.
weight_quant
.
type
==
QuantizationType
.
INT
and
self
.
input_quant
.
type
==
QuantizationType
.
INT
):
return
torch
.
int8
raise
ValueError
(
"Quantization type not supported by Cutlass"
)
def _decompress_bitmask_compressed_weight(
    self,
    compressed: torch.Tensor,
    bitmask: torch.Tensor,
    layer: torch.nn.Module,
) -> torch.Tensor:
    """
    Rebuild a dense 2:4 sparse weight from its bitmask-compressed form.

    For `QKVParallelLinear` / `MergedColumnParallelLinear` layers the
    compressed weight and bitmask are split along `layer.logical_widths`,
    each shard is decompressed on its own, and the shards are recombined
    with `combine_shards`. Any other layer is decompressed in one call.

    Note that this bitmask scheme is different from the one produced by
    `cutlass_sparse_compress`, which stores 2 bits per non-zero element giving
    its coordinate inside the block of 4.

    :param compressed: The 2:4 sparse weight tensor compressed using the
        sparse-24-bitmask compressor.
    :param bitmask: The 2:4 bitmask marking the positions of the non-zero
        elements of the compressed tensor.
    :param layer: The layer whose weights are being processed after loading.
    :return: The decompressed 2:4 sparse weight tensor.
    """
    compressor = self.model_compressor.sparsity_compressor

    weight_shards: list[torch.Tensor] = []
    bitmask_shards: list[torch.Tensor] = []
    shard_shapes: list[tuple[int, int]] = []

    if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
        widths = layer.logical_widths
        weight_shards = torch.split(compressed, widths)
        bitmask_shards = torch.split(bitmask, widths)
        shard_shapes = [(rows, layer.input_size_per_partition) for rows in widths]

    if not weight_shards:
        # Unsharded path: decompress the whole tensor in one go.
        return compressor.decompress_weight(
            {
                "compressed": compressed,
                "shape": (
                    layer.logical_widths[0],
                    layer.input_size_per_partition,
                ),
                "bitmask": bitmask,
            }
        )

    decompressed_shards = []
    for shard, shard_shape, shard_mask in zip(
        weight_shards, shard_shapes, bitmask_shards
    ):
        decompressed_shards.append(
            compressor.decompress_weight(
                {"compressed": shard, "shape": shard_shape, "bitmask": shard_mask}
            )
        )
    return combine_shards(decompressed_shards)
raise
NotImplementedError
(
"Sparse24 models are no longer supported by vLLM"
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
View file @
38364a7e
...
...
@@ -20,7 +20,7 @@ class CompressedTensorsScheme(ABC):
"""
Get minimum device capability.
"""
raise
NotImplementedError
raise
NotImplementedError
()
@
abstractmethod
def
create_weights
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -28,7 +28,7 @@ class CompressedTensorsScheme(ABC):
Weight creation for the particular scheme. Inputs to this function
"""
raise
NotImplementedError
raise
NotImplementedError
()
@
abstractmethod
def
apply_weights
(
...
...
@@ -44,7 +44,7 @@ class CompressedTensorsScheme(ABC):
:param bias: bias parameter
"""
raise
NotImplementedError
raise
NotImplementedError
()
@
abstractmethod
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
):
...
...
@@ -52,4 +52,4 @@ class CompressedTensorsScheme(ABC):
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise
NotImplementedError
raise
NotImplementedError
()
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
38364a7e
...
...
@@ -8,16 +8,6 @@ from vllm import _custom_ops as ops
from
vllm.platforms
import
current_platform
def sparse_cutlass_supported() -> bool:
    """Report whether sparse CUTLASS scaled-mm can be used on this platform.

    Returns ``False`` immediately on non-CUDA platforms. Otherwise the device
    capability (or ``-1`` when it is unavailable) is passed to
    ``ops.cutlass_sparse_scaled_mm_supported``.
    """
    if not current_platform.is_cuda():
        return False

    capability_tuple = current_platform.get_device_capability()
    if capability_tuple is None:
        capability = -1
    else:
        capability = capability_tuple.to_int()

    return ops.cutlass_sparse_scaled_mm_supported(capability)
def
cutlass_fp8_supported
()
->
bool
:
if
not
current_platform
.
is_cuda
():
return
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment