Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af7f4372
Commit
af7f4372
authored
Sep 03, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1
parents
5e19cdef
09c77926
Changes
465
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2300 additions
and
182 deletions
+2300
-182
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
+107
-78
benchmarks/kernels/benchmark_layernorm.py
benchmarks/kernels/benchmark_layernorm.py
+89
-0
benchmarks/kernels/benchmark_machete.py
benchmarks/kernels/benchmark_machete.py
+372
-0
benchmarks/kernels/benchmark_moe.py
benchmarks/kernels/benchmark_moe.py
+70
-38
benchmarks/kernels/benchmark_quant.py
benchmarks/kernels/benchmark_quant.py
+103
-0
benchmarks/kernels/graph_machete_bench.py
benchmarks/kernels/graph_machete_bench.py
+64
-0
benchmarks/kernels/weight_shapes.py
benchmarks/kernels/weight_shapes.py
+43
-0
collect_env.py
collect_env.py
+7
-2
csrc/attention/attention_utils.cuh
csrc/attention/attention_utils.cuh
+1
-1
csrc/core/scalar_type.hpp
csrc/core/scalar_type.hpp
+196
-31
csrc/cuda_utils.h
csrc/cuda_utils.h
+10
-0
csrc/cutlass_extensions/cute_utils.cuh
csrc/cutlass_extensions/cute_utils.cuh
+68
-0
csrc/cutlass_extensions/torch_utils.hpp
csrc/cutlass_extensions/torch_utils.hpp
+154
-0
csrc/cutlass_extensions/vllm_collective_builder.cuh
csrc/cutlass_extensions/vllm_collective_builder.cuh
+43
-0
csrc/cutlass_extensions/vllm_custom_types.cuh
csrc/cutlass_extensions/vllm_custom_types.cuh
+50
-0
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
+49
-0
csrc/cutlass_extensions/vllm_numeric_conversion.cuh
csrc/cutlass_extensions/vllm_numeric_conversion.cuh
+795
-0
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+19
-14
csrc/ops.h
csrc/ops.h
+40
-4
csrc/opt/layernorm_kernels_opt.cu
csrc/opt/layernorm_kernels_opt.cu
+20
-14
No files found.
Too many changes to show.
To preserve performance only
465 of 465+
files are displayed.
Plain diff
Email patch
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
View file @
af7f4372
...
...
@@ -32,7 +32,6 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor:
def
make_rand_tensors
(
dtype
:
torch
.
dtype
,
m
:
int
,
n
:
int
,
k
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
a
=
torch
.
randn
((
m
,
k
),
device
=
'cuda'
)
*
5
b
=
torch
.
randn
((
n
,
k
),
device
=
'cuda'
).
t
()
*
5
...
...
@@ -44,59 +43,18 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
raise
ValueError
(
"unsupported dtype"
)
# impl
def
pytorch_mm_impl
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
return
torch
.
mm
(
a
,
b
)
def
pytorch_fp8_impl
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
return
torch
.
_scaled_mm
(
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
out_dtype
)
def
pytorch_fp8_impl_fast_accum
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
return
torch
.
_scaled_mm
(
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
out_dtype
,
use_fast_accum
=
True
)
def
cutlass_impl
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
)
->
torch
.
Tensor
:
return
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
out_dtype
)
# bench
def
bench_fn
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
label
:
str
,
sub_label
:
str
,
fn
:
Callable
,
description
:
str
)
->
TMeasurement
:
def
bench_fn
(
label
:
str
,
sub_label
:
str
,
description
:
str
,
fn
:
Callable
,
*
args
,
**
kwargs
)
->
TMeasurement
:
min_run_time
=
1
globals
=
{
"a"
:
a
,
"b"
:
b
,
"scale_a"
:
scale_a
,
"scale_b"
:
scale_b
,
"out_dtype"
:
out_dtype
,
"args"
:
args
,
"kwargs"
:
kwargs
,
"fn"
:
fn
,
}
return
TBenchmark
.
Timer
(
stmt
=
"fn(
a, b, scale_a, scale_b, out_dtype
)"
,
stmt
=
"fn(
*args, **kwargs
)"
,
globals
=
globals
,
label
=
label
,
sub_label
=
sub_label
,
...
...
@@ -110,26 +68,58 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
a
,
b
=
make_rand_tensors
(
torch
.
int8
,
m
,
n
,
k
)
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
bias
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
azp
=
torch
.
zeros
((
m
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
azp_adj
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
timers
=
[]
# pytorch impl - bfloat16
timers
.
append
(
bench_fn
(
a
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
pytorch_mm_impl
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
))
bench_fn
(
label
,
sub_label
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
bfloat16
),
b
.
to
(
dtype
=
torch
.
bfloat16
)))
# pytorch impl - float16
timers
.
append
(
bench_fn
(
a
.
to
(
dtype
=
torch
.
float16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
float16
,
device
=
"cuda"
),
scale_a
,
scale_b
,
torch
.
float16
,
label
,
sub_label
,
pytorch_mm_impl
,
"pytorch_fp16_fp16_fp16_matmul-no-scales"
))
bench_fn
(
label
,
sub_label
,
"pytorch_fp16_fp16_fp16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
float16
),
b
.
to
(
dtype
=
torch
.
float16
)))
# cutlass impl
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
cutlass_impl
,
"cutlass_i8_i8_bf16_scaled_mm"
))
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
))
# cutlass with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
))
# cutlass with azp per-tensor
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
))
# cutlass with azp per-tensor + bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_bias"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
None
,
bias
))
# cutlass with azp per-token
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_pt"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
))
# cutlass with azp per-token + bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias"
,
ops
.
cutlass_scaled_mm_azp
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
azp_adj
,
azp
,
bias
))
return
timers
...
...
@@ -140,46 +130,88 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
a
,
b
=
make_rand_tensors
(
torch
.
float8_e4m3fn
,
m
,
n
,
k
)
scale_a
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
scale_b
=
torch
.
tensor
(
1.0
,
device
=
"cuda"
,
dtype
=
torch
.
float32
)
bias
=
torch
.
zeros
((
n
,
),
device
=
"cuda"
,
dtype
=
torch
.
bfloat16
)
timers
=
[]
# pytorch impl w. bf16
timers
.
append
(
bench_fn
(
a
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
pytorch_mm_impl
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
))
bench_fn
(
label
,
sub_label
,
"pytorch_bf16_bf16_bf16_matmul-no-scales"
,
torch
.
mm
,
a
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
),
b
.
to
(
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)))
# pytorch impl: bf16 output, without fp8 fast accum
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
pytorch_fp8_impl
,
"pytorch_fp8_fp8_bf16_scaled_mm"
))
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_bf16_scaled_mm"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
bfloat16
))
# pytorch impl: bf16 output, with fp8 fast accum
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
pytorch_fp8_impl_fast_accum
,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"
))
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
bfloat16
,
use_fast_accum
=
True
))
# pytorch impl: fp16 output, without fp8 fast accum
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
label
,
sub_label
,
pytorch_fp8_impl
,
"pytorch_fp8_fp8_fp16_scaled_mm"
))
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_fp16_scaled_mm"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
float16
))
# pytorch impl: fp16 output, with fp8 fast accum
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
label
,
sub_label
,
pytorch_fp8_impl_fast_accum
,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"
))
bench_fn
(
label
,
sub_label
,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"
,
torch
.
_scaled_mm
,
a
,
b
,
scale_a
=
scale_a
,
scale_b
=
scale_b
,
out_dtype
=
torch
.
float16
,
use_fast_accum
=
True
))
# cutlass impl: bf16 output
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
label
,
sub_label
,
cutlass_impl
,
"cutlass_fp8_fp8_bf16_scaled_mm"
))
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_bf16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
))
# cutlass impl: fp16 output
timers
.
append
(
bench_fn
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
label
,
sub_label
,
cutlass_impl
,
"cutlass_fp8_fp8_fp16_scaled_mm"
))
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_fp16_scaled_mm"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
))
# cutlass impl: bf16 output, with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_bf16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
bfloat16
,
bias
))
# cutlass impl: fp16 output, with bias
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"cutlass_fp8_fp8_fp16_scaled_mm_bias"
,
ops
.
cutlass_scaled_mm
,
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
,
bias
.
to
(
dtype
=
torch
.
float16
)))
return
timers
...
...
@@ -200,7 +232,6 @@ def print_timers(timers: Iterable[TMeasurement]):
def
run
(
dtype
:
torch
.
dtype
,
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]])
->
Iterable
[
TMeasurement
]:
results
=
[]
for
m
,
k
,
n
in
MKNs
:
timers
=
bench
(
dtype
,
m
,
k
,
n
,
f
"scaled-
{
dtype
}
-gemm"
,
...
...
@@ -216,7 +247,6 @@ def make_output(data: Iterable[TMeasurement],
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]],
base_description
:
str
,
timestamp
=
None
):
print
(
f
"== All Results
{
base_description
}
===="
)
print_timers
(
data
)
...
...
@@ -251,7 +281,6 @@ def run_range_bench(args):
def
run_model_bench
(
args
):
print
(
"Benchmarking models:"
)
for
i
,
model
in
enumerate
(
args
.
models
):
print
(
f
"[
{
i
}
]
{
model
}
"
)
...
...
benchmarks/kernels/benchmark_layernorm.py
0 → 100644
View file @
af7f4372
import
random
import
time
import
torch
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
@
torch
.
inference_mode
()
def
main
(
num_tokens
:
int
,
hidden_size
:
int
,
add_residual
:
bool
,
dtype
:
torch
.
dtype
,
seed
:
int
=
0
,
do_profile
:
bool
=
False
,
num_warmup_iters
:
int
=
5
,
num_iters
:
int
=
100
)
->
None
:
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
"cuda"
)
layer
=
RMSNorm
(
hidden_size
).
to
(
dtype
=
dtype
)
layer
.
weight
.
data
.
normal_
(
mean
=
1.0
,
std
=
0.1
)
scale
=
1
/
(
2
*
hidden_size
)
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
x
*=
scale
residual
=
torch
.
randn_like
(
x
)
*
scale
if
add_residual
else
None
def
run_cuda_benchmark
(
num_iters
:
int
,
profile
:
bool
=
False
)
->
float
:
torch
.
cuda
.
synchronize
()
if
profile
:
torch
.
cuda
.
cudart
().
cudaProfilerStart
()
start_time
=
time
.
perf_counter
()
for
_
in
range
(
num_iters
):
layer
(
x
,
residual
)
torch
.
cuda
.
synchronize
()
end_time
=
time
.
perf_counter
()
if
profile
:
torch
.
cuda
.
cudart
().
cudaProfilerStart
()
return
(
end_time
-
start_time
)
/
num_iters
# Warmup.
print
(
"Warming up..."
)
run_benchmark
=
run_cuda_benchmark
run_benchmark
(
num_iters
=
num_warmup_iters
,
profile
=
False
)
# Benchmark.
if
do_profile
:
latency
=
run_benchmark
(
num_iters
=
1
,
profile
=
True
)
else
:
latency
=
run_benchmark
(
num_iters
=
num_iters
,
profile
=
False
)
print
(
f
"Kernel running time:
{
latency
*
1000000
:.
3
f
}
us"
)
if
__name__
==
'__main__'
:
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the layernorm kernel."
)
parser
.
add_argument
(
"--num-tokens"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--hidden-size"
,
type
=
int
,
default
=
8192
)
parser
.
add_argument
(
"--add-residual"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
default
=
"half"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--num-warmup-iters"
,
type
=
int
,
default
=
5
)
parser
.
add_argument
(
"--num-iters"
,
type
=
int
,
default
=
100
,
help
=
"Number of benchmark iterations. "
"If --profile is set, this number is ignored"
)
args
=
parser
.
parse_args
()
print
(
args
)
main
(
num_tokens
=
args
.
num_tokens
,
hidden_size
=
args
.
hidden_size
,
add_residual
=
args
.
add_residual
,
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
args
.
dtype
],
seed
=
args
.
seed
,
do_profile
=
args
.
profile
,
num_warmup_iters
=
args
.
num_warmup_iters
,
num_iters
=
args
.
num_iters
)
benchmarks/kernels/benchmark_machete.py
0 → 100644
View file @
af7f4372
import
argparse
import
copy
import
itertools
import
math
import
pickle
as
pkl
import
time
from
typing
import
Callable
,
Iterable
,
List
,
Tuple
import
torch
import
torch.utils.benchmark
as
TBenchmark
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
marlin_permute_scales
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
gptq_pack
,
pack_rows
,
quantize_weights
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
FlexibleArgumentParser
DEFAULT_MODELS
=
[
"meta-llama/Llama-3-8b"
,
"meta-llama/Llama-2-70b-hf"
]
DEFAULT_BATCH_SIZES
=
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
]
DEFAULT_TP_SIZES
=
[
1
]
def
machete_pack_weights
(
w_q
:
torch
.
tensor
,
wtype
:
ScalarType
)
->
torch
.
tensor
:
w_q
=
pack_rows
(
w_q
,
wtype
.
size_bits
,
*
w_q
.
shape
)
w_q
=
w_q
.
t
().
contiguous
().
t
()
# make col major
return
ops
.
machete_prepack_B
(
w_q
,
wtype
)
def
make_bench_tensors
(
atype
:
torch
.
dtype
,
wtype
:
ScalarType
,
group_size
:
int
,
m
:
int
,
n
:
int
,
k
:
int
)
->
Tuple
[
torch
.
tensor
,
List
[
Tuple
[
torch
.
tensor
,
torch
.
tensor
,
torch
.
tensor
,
torch
.
tensor
]]]:
assert
wtype
.
is_integer
(),
"TODO: support floating point weights"
# we want to make sure that weights don't fit into L2 cache between runs so
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
# so we target total weight size > 2*50mb
num_weights
=
math
.
ceil
(
2
*
50
*
1024
**
2
*
8
/
(
k
*
n
*
wtype
.
size_bits
))
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
atype
)
*
5
weights
=
[
torch
.
randn
((
k
,
n
),
device
=
"cuda"
,
dtype
=
atype
)
for
_
in
range
(
num_weights
)
]
quanitized_weights
=
[
quantize_weights
(
w
,
wtype
,
group_size
)
for
w
in
weights
]
return
a
,
quanitized_weights
# impl
# bench
def
bench_fn
(
label
:
str
,
sub_label
:
str
,
description
:
str
,
fn
:
Callable
)
->
TMeasurement
:
min_run_time
=
1
return
TBenchmark
.
Timer
(
stmt
=
"fn()"
,
globals
=
{
"fn"
:
fn
},
label
=
label
,
sub_label
=
sub_label
,
description
=
description
,
).
blocked_autorange
(
min_run_time
=
min_run_time
)
def
loop_over_weights
(
a
:
torch
.
tensor
,
weights
:
List
[
Tuple
[
torch
.
tensor
,
torch
.
tensor
,
torch
.
tensor
,
torch
.
tensor
]],
fn
:
Callable
[[
torch
.
tensor
,
torch
.
tensor
,
torch
.
tensor
,
torch
.
tensor
],
None
]):
for
w_ref
,
w_q
,
w_s
,
_
in
weights
:
fn
(
a
,
w_ref
,
w_q
,
w_s
)
def
bench
(
atype
:
torch
.
dtype
,
wtype
:
ScalarType
,
group_size
:
int
,
m
:
int
,
k
:
int
,
n
:
int
,
label
:
str
,
sub_label
:
str
,
benchmark_marlinv1
:
bool
=
True
,
sweep_schedules
:
bool
=
True
)
->
Iterable
[
TMeasurement
]:
a
,
weights
=
make_bench_tensors
(
atype
,
wtype
,
group_size
,
m
,
n
,
k
)
sub_label
+=
f
", L=
{
len
(
weights
)
}
"
weights_machete
=
[(
w_ref
,
machete_pack_weights
(
w_q
,
wtype
),
w_s
,
w_zp
)
for
w_ref
,
w_q
,
w_s
,
w_zp
in
weights
]
timers
=
[]
# pytorch impl
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"torch.matmul"
,
lambda
:
loop_over_weights
(
a
,
weights
,
lambda
a
,
w_ref
,
w_q
,
w_s
:
torch
.
matmul
(
a
,
w_ref
),
)))
if
benchmark_marlinv1
:
w_ref
=
weights
[
0
][
0
]
w_zp_empty
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w_ref
.
device
)
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w_ref
.
device
)
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
w_ref
.
device
)
def
marlinv1_pack_weights
(
w_q
:
torch
.
tensor
)
->
torch
.
tensor
:
w_q_gptq
=
gptq_pack
(
w_q
,
wtype
.
size_bits
,
*
w_ref
.
shape
)
return
ops
.
gptq_marlin_repack
(
w_q_gptq
,
sort_indices
,
*
w_ref
.
shape
,
wtype
.
size_bits
)
def
marlinv1_permute_scales
(
w_s
:
torch
.
tensor
)
->
torch
.
tensor
:
return
marlin_permute_scales
(
w_s
,
*
w_ref
.
shape
,
group_size
)
weights_marlinv1
=
[(
w_ref
,
marlinv1_pack_weights
(
w_q
),
marlinv1_permute_scales
(
w_s
),
w_zp
)
for
w_ref
,
w_q
,
w_s
,
w_zp
in
weights
]
workspace
=
MarlinWorkspace
(
w_ref
.
shape
[
1
],
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
# marlinv1
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"marlin_orig"
,
lambda
:
loop_over_weights
(
a
,
weights_marlinv1
,
lambda
a
,
w_ref
,
w_q
,
w_s
:
ops
.
gptq_marlin_gemm
(
a
,
w_q
,
w_s
,
w_zp_empty
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
wtype
,
size_m
=
a
.
shape
[
0
],
size_n
=
w_ref
.
shape
[
1
],
size_k
=
w_ref
.
shape
[
0
],
is_k_full
=
True
))))
# machete
timers
.
append
(
bench_fn
(
label
,
sub_label
,
"machete_heuristic"
,
lambda
:
loop_over_weights
(
a
,
weights_machete
,
lambda
a
,
_
,
w_q
,
w_s
:
ops
.
machete_gemm
(
a
,
w_q
,
wtype
,
b_scales
=
w_s
,
b_group_size
=
group_size
))))
if
sweep_schedules
:
print
(
"Finding best schedule for machete"
)
best
=
None
best_schedule
=
None
schedules
=
ops
.
machete_supported_schedules
(
wtype
)
for
schedule
in
reversed
(
schedules
):
def
run
(
a
,
_
,
w_q
,
w_s
,
schedule
=
schedule
):
ops
.
machete_gemm
(
a
,
w_q
,
wtype
,
w_s
,
b_group_size
=
group_size
,
schedule
=
schedule
)
res
=
bench_fn
(
label
,
sub_label
,
"machete_best"
,
lambda
:
loop_over_weights
(
a
,
weights_machete
,
run
))
print
(
f
"
{
res
.
median
:
5.5
}
"
,
schedule
)
if
not
best
or
res
.
median
<
best
.
median
:
best
=
res
best_schedule
=
schedule
print
(
"Best schedule:"
,
best_schedule
)
timers
.
append
(
best
)
return
timers
# runner
def
print_timers
(
timers
:
Iterable
[
TMeasurement
]):
compare
=
TBenchmark
.
Compare
(
timers
)
compare
.
print
()
def
run
(
dtype
:
torch
.
dtype
,
sweep_schedules
:
bool
,
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]])
->
Iterable
[
TMeasurement
]:
results
=
[]
for
m
,
k
,
n
in
MKNs
:
timers
=
bench
(
dtype
,
scalar_types
.
uint4b8
,
128
,
m
,
k
,
n
,
f
"
{
dtype
}
-gemm"
,
f
"MKN=(
{
m
}
x
{
k
}
x
{
n
}
)"
,
sweep_schedules
=
sweep_schedules
)
print_timers
(
timers
)
results
.
extend
(
timers
)
return
results
# output makers
def
make_output
(
data
:
Iterable
[
TMeasurement
],
MKNs
:
Iterable
[
Tuple
[
int
,
int
,
int
]],
base_description
:
str
,
timestamp
=
None
,
):
print
(
f
"== All Results
{
base_description
}
===="
)
print_timers
(
data
)
# pickle all the results
timestamp
=
int
(
time
.
time
())
if
timestamp
is
None
else
timestamp
with
open
(
f
"
{
base_description
}
-
{
timestamp
}
.pkl"
,
"wb"
)
as
f
:
pkl
.
dump
(
data
,
f
)
# argparse runners
def
run_square_bench
(
args
):
dim_sizes
=
list
(
range
(
args
.
dim_start
,
args
.
dim_end
+
1
,
args
.
dim_increment
))
MKNs
=
list
(
zip
(
dim_sizes
,
dim_sizes
,
dim_sizes
))
data
=
run
(
args
.
dtype
,
args
.
sweep_schedules
,
MKNs
)
make_output
(
data
,
MKNs
,
f
"square_bench-
{
args
.
dtype
}
"
)
def
run_range_bench
(
args
):
dim_sizes
=
list
(
range
(
args
.
dim_start
,
args
.
dim_end
,
args
.
dim_increment
))
n
=
len
(
dim_sizes
)
Ms
=
[
args
.
m_constant
]
*
n
if
args
.
m_constant
is
not
None
else
dim_sizes
Ks
=
[
args
.
k_constant
]
*
n
if
args
.
k_constant
is
not
None
else
dim_sizes
Ns
=
[
args
.
n_constant
]
*
n
if
args
.
n_constant
is
not
None
else
dim_sizes
MKNs
=
list
(
zip
(
Ms
,
Ks
,
Ns
))
data
=
run
(
args
.
dtype
,
args
.
sweep_schedules
,
MKNs
)
make_output
(
data
,
MKNs
,
f
"range_bench-
{
args
.
dtype
}
"
)
def
run_model_bench
(
args
):
print
(
"Benchmarking models:"
)
for
i
,
model
in
enumerate
(
args
.
models
):
print
(
f
"[
{
i
}
]
{
model
}
"
)
def
model_shapes
(
model_name
:
str
,
tp_size
:
int
)
->
List
[
Tuple
[
int
,
int
]]:
KNs
=
[]
for
KN
,
tp_split_dim
in
copy
.
deepcopy
(
WEIGHT_SHAPES
[
model_name
]):
KN
[
tp_split_dim
]
=
KN
[
tp_split_dim
]
//
tp_size
KNs
.
append
(
KN
)
return
KNs
model_bench_data
=
[]
models_tps
=
list
(
itertools
.
product
(
args
.
models
,
args
.
tp_sizes
))
for
model
,
tp_size
in
models_tps
:
Ms
=
args
.
batch_sizes
KNs
=
model_shapes
(
model
,
tp_size
)
MKNs
=
[]
for
m
in
Ms
:
for
k
,
n
in
KNs
:
MKNs
.
append
((
m
,
k
,
n
))
data
=
run
(
args
.
dtype
,
args
.
sweep_schedules
,
MKNs
)
model_bench_data
.
append
(
data
)
# Print all results
for
data
,
model_tp
in
zip
(
model_bench_data
,
models_tps
):
model
,
tp_size
=
model_tp
print
(
f
"== Results
{
args
.
dtype
}
{
model
}
-TP
{
tp_size
}
===="
)
print_timers
(
data
)
timestamp
=
int
(
time
.
time
())
all_data
=
[]
for
d
in
model_bench_data
:
all_data
.
extend
(
d
)
# pickle all data
with
open
(
f
"model_bench-
{
args
.
dtype
}
-
{
timestamp
}
.pkl"
,
"wb"
)
as
f
:
pkl
.
dump
(
all_data
,
f
)
if
__name__
==
"__main__"
:
def
to_torch_dtype
(
dt
):
if
dt
==
"bfloat16"
:
return
torch
.
bfloat16
if
dt
==
"float16"
:
return
torch
.
float16
raise
ValueError
(
"unsupported dtype"
)
parser
=
FlexibleArgumentParser
(
description
=
"""
Benchmark Machete GEMM.
To run square GEMMs:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
"""
,
# noqa: E501
formatter_class
=
argparse
.
RawTextHelpFormatter
,
)
parser
.
add_argument
(
"--dtype"
,
type
=
to_torch_dtype
,
required
=
True
,
help
=
"Available options are ['bfloat16', 'float16']"
,
)
parser
.
add_argument
(
"--sweep-schedules"
,
action
=
"store_true"
,
help
=
"Run a sweep over all supported schedules"
,
)
subparsers
=
parser
.
add_subparsers
(
dest
=
"cmd"
,
required
=
True
)
square_parser
=
subparsers
.
add_parser
(
"square_bench"
)
square_parser
.
add_argument
(
"--dim-start"
,
type
=
int
,
required
=
True
)
square_parser
.
add_argument
(
"--dim-end"
,
type
=
int
,
required
=
True
)
square_parser
.
add_argument
(
"--dim-increment"
,
type
=
int
,
required
=
True
)
square_parser
.
set_defaults
(
func
=
run_square_bench
)
range_parser
=
subparsers
.
add_parser
(
"range_bench"
)
range_parser
.
add_argument
(
"--dim-start"
,
type
=
int
,
required
=
True
)
range_parser
.
add_argument
(
"--dim-end"
,
type
=
int
,
required
=
True
)
range_parser
.
add_argument
(
"--dim-increment"
,
type
=
int
,
required
=
True
)
range_parser
.
add_argument
(
"--m-constant"
,
type
=
int
,
default
=
None
)
range_parser
.
add_argument
(
"--n-constant"
,
type
=
int
,
default
=
None
)
range_parser
.
add_argument
(
"--k-constant"
,
type
=
int
,
default
=
None
)
range_parser
.
set_defaults
(
func
=
run_range_bench
)
model_parser
=
subparsers
.
add_parser
(
"model_bench"
)
model_parser
.
add_argument
(
"--models"
,
nargs
=
"+"
,
type
=
str
,
default
=
DEFAULT_MODELS
,
choices
=
WEIGHT_SHAPES
.
keys
(),
)
model_parser
.
add_argument
(
"--tp-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_TP_SIZES
)
model_parser
.
add_argument
(
"--batch-sizes"
,
nargs
=
"+"
,
type
=
int
,
default
=
DEFAULT_BATCH_SIZES
)
model_parser
.
set_defaults
(
func
=
run_model_bench
)
args
=
parser
.
parse_args
()
args
.
func
(
args
)
benchmarks/kernels/benchmark_moe.py
View file @
af7f4372
...
...
@@ -30,11 +30,28 @@ def benchmark_config(
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
)
->
float
:
init_dtype
=
torch
.
float16
if
use_fp8
else
dtype
init_dtype
=
torch
.
float16
if
use_fp8
_w8a8
else
dtype
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
if
use_int8_w8a16
:
w1
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
shard_intermediate_size
,
hidden_size
,
),
dtype
=
torch
.
int8
)
w2
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
),
dtype
=
torch
.
int8
)
else
:
w1
=
torch
.
randn
(
num_experts
,
shard_intermediate_size
,
hidden_size
,
...
...
@@ -52,7 +69,11 @@ def benchmark_config(
w2_scale
=
None
a1_scale
=
None
a2_scale
=
None
if
use_fp8
:
if
use_int8_w8a16
:
w1_scale
=
torch
.
randn
((
num_experts
,
2
*
shard_intermediate_size
),
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
)
if
use_fp8_w8a8
:
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
)
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
)
...
...
@@ -76,7 +97,8 @@ def benchmark_config(
renormalize
=
True
,
inplace
=
True
,
override_config
=
config
,
use_fp8
=
use_fp8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
...
...
@@ -155,11 +177,13 @@ class BenchmarkWorker:
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
)
->
Tuple
[
Dict
[
str
,
int
],
float
]:
torch
.
cuda
.
manual_seed_all
(
self
.
seed
)
dtype_str
=
"float8"
if
use_fp8
else
None
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
op_config
=
get_moe_configs
(
num_experts
,
shard_intermediate_size
//
2
,
...
...
@@ -173,7 +197,8 @@ class BenchmarkWorker:
key
=
lambda
x
:
abs
(
x
-
num_tokens
))]
kernel_time
=
benchmark_config
(
config
,
num_tokens
,
num_experts
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
)
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
)
return
config
,
kernel_time
def
tune
(
...
...
@@ -184,9 +209,10 @@ class BenchmarkWorker:
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
search_space
:
List
[
BenchmarkConfig
],
)
->
BenchmarkConfig
:
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
search_space
:
List
[
Dict
[
str
,
int
]],
)
->
Dict
[
str
,
int
]:
best_config
=
None
best_time
=
float
(
"inf"
)
for
config
in
tqdm
(
search_space
):
...
...
@@ -198,7 +224,8 @@ class BenchmarkWorker:
hidden_size
,
topk
,
dtype
,
use_fp8
,
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
10
)
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
# Some configurations may be invalid and fail to compile.
...
...
@@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
}
def
save_configs
(
configs
:
Dict
[
int
,
BenchmarkConfig
],
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8
:
bool
,
)
->
None
:
dtype_str
=
"float8"
if
use_fp8
else
None
def
save_configs
(
configs
:
Dict
[
int
,
BenchmarkConfig
],
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
)
->
None
:
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename
=
get_config_file_name
(
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
)
print
(
f
"Writing best config to
{
filename
}
..."
)
with
open
(
filename
,
"w"
)
as
f
:
json
.
dump
(
configs
,
f
,
indent
=
4
)
...
...
@@ -253,6 +279,11 @@ def main(args: argparse.Namespace):
topk
=
config
.
ffn_config
.
moe_top_k
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
elif
config
.
architectures
[
0
]
==
"JambaForCausalLM"
:
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
args
.
tp_size
else
:
# Default: Mixtral.
E
=
config
.
num_local_experts
...
...
@@ -262,7 +293,8 @@ def main(args: argparse.Namespace):
hidden_size
=
config
.
hidden_size
dtype
=
config
.
torch_dtype
use_fp8
=
args
.
dtype
==
"fp8"
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
if
args
.
batch_size
is
None
:
batch_sizes
=
[
...
...
@@ -294,20 +326,20 @@ def main(args: argparse.Namespace):
start
=
time
.
time
()
configs
=
_distribute
(
"tune"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
,
search_space
)
topk
,
dtype
,
use_fp8
_w8a8
,
use_int8_w8a16
,
search_space
)
for
batch_size
in
batch_sizes
])
best_configs
=
{
M
:
sort_config
(
config
)
for
M
,
config
in
zip
(
batch_sizes
,
configs
)
}
save_configs
(
best_configs
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
)
topk
,
dtype
,
use_fp8
_w8a8
,
use_int8_w8a16
)
end
=
time
.
time
()
print
(
f
"Tuning took
{
end
-
start
:.
2
f
}
seconds"
)
else
:
outputs
=
_distribute
(
"benchmark"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
)
outputs
=
_distribute
(
"benchmark"
,
[(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8
_w8a8
,
use_int8_w8a16
)
for
batch_size
in
batch_sizes
])
for
batch_size
,
(
config
,
kernel_time
)
in
zip
(
batch_sizes
,
outputs
):
...
...
@@ -323,7 +355,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--tp-size"
,
"-tp"
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"auto"
,
"fp8"
],
choices
=
[
"auto"
,
"fp8
_w8a8"
,
"int8_w8a16
"
],
default
=
"auto"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
required
=
False
)
...
...
benchmarks/kernels/benchmark_quant.py
0 → 100644
View file @
af7f4372
import
random
import
time
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
@
torch
.
inference_mode
()
def
main
(
num_tokens
:
int
,
hidden_size
:
int
,
static_scale
:
bool
,
quant_dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
=
0
,
do_profile
:
bool
=
False
,
num_warmup_iters
:
int
=
5
,
num_iters
:
int
=
100
)
->
None
:
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
"cuda"
)
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
scale
=
torch
.
randn
(
1
,
1
,
dtype
=
torch
.
float32
)
if
static_scale
else
None
def
run_cuda_benchmark
(
num_iters
:
int
,
profile
:
bool
=
False
)
->
float
:
torch
.
cuda
.
synchronize
()
if
profile
:
torch
.
cuda
.
cudart
().
cudaProfilerStart
()
start_time
=
time
.
perf_counter
()
for
_
in
range
(
num_iters
):
if
quant_dtype
==
torch
.
int8
:
ops
.
scaled_int8_quant
(
x
,
scale
)
else
:
ops
.
scaled_fp8_quant
(
x
,
scale
)
torch
.
cuda
.
synchronize
()
end_time
=
time
.
perf_counter
()
if
profile
:
torch
.
cuda
.
cudart
().
cudaProfilerStart
()
return
(
end_time
-
start_time
)
/
num_iters
# Warmup.
print
(
"Warming up..."
)
run_benchmark
=
run_cuda_benchmark
run_benchmark
(
num_iters
=
num_warmup_iters
,
profile
=
False
)
# Benchmark.
if
do_profile
:
latency
=
run_benchmark
(
num_iters
=
1
,
profile
=
True
)
else
:
latency
=
run_benchmark
(
num_iters
=
num_iters
,
profile
=
False
)
print
(
f
"Kernel running time:
{
latency
*
1000000
:.
3
f
}
us"
)
if
__name__
==
'__main__'
:
def
to_torch_dtype
(
dt
):
if
dt
==
"int8"
:
return
torch
.
int8
if
dt
==
"fp8"
:
return
torch
.
float8_e4m3fn
raise
ValueError
(
f
"Unsupported dtype:
{
dt
}
"
)
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the quantization (fp8 or int8) kernel."
)
parser
.
add_argument
(
"--num-tokens"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--hidden-size"
,
type
=
int
,
default
=
8192
)
parser
.
add_argument
(
"--static-scale"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--quant-dtype"
,
type
=
str
,
choices
=
[
"fp8"
,
"int8"
],
default
=
"int8"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
default
=
"half"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--num-warmup-iters"
,
type
=
int
,
default
=
5
)
parser
.
add_argument
(
"--num-iters"
,
type
=
int
,
default
=
100
,
help
=
"Number of benchmark iterations. "
"If --profile is set, this number is ignored"
)
args
=
parser
.
parse_args
()
print
(
args
)
main
(
num_tokens
=
args
.
num_tokens
,
hidden_size
=
args
.
hidden_size
,
static_scale
=
args
.
static_scale
,
quant_dtype
=
to_torch_dtype
(
args
.
quant_dtype
),
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
args
.
dtype
],
seed
=
args
.
seed
,
do_profile
=
args
.
profile
,
num_warmup_iters
=
args
.
num_warmup_iters
,
num_iters
=
args
.
num_iters
)
benchmarks/kernels/graph_machete_bench.py
0 → 100644
View file @
af7f4372
import
math
import
pickle
import
re
from
collections
import
defaultdict
from
typing
import
List
import
matplotlib.pyplot
as
plt
import
pandas
as
pd
import
seaborn
as
sns
from
torch.utils.benchmark
import
Measurement
as
TMeasurement
from
vllm.utils
import
FlexibleArgumentParser
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
(
description
=
'Benchmark the latency of processing a single batch of '
'requests till completion.'
)
parser
.
add_argument
(
'filename'
,
type
=
str
)
args
=
parser
.
parse_args
()
with
open
(
args
.
filename
,
'rb'
)
as
f
:
data
:
List
[
TMeasurement
]
=
pickle
.
load
(
f
)
results
=
defaultdict
(
lambda
:
list
())
for
v
in
data
:
result
=
re
.
search
(
r
"MKN=\(\d+x(\d+x\d+)\)"
,
v
.
task_spec
.
sub_label
)
if
result
is
not
None
:
KN
=
result
.
group
(
1
)
else
:
raise
Exception
(
"MKN not found"
)
result
=
re
.
search
(
r
"MKN=\((\d+)x\d+x\d+\)"
,
v
.
task_spec
.
sub_label
)
if
result
is
not
None
:
M
=
result
.
group
(
1
)
else
:
raise
Exception
(
"MKN not found"
)
kernel
=
v
.
task_spec
.
description
results
[
KN
].
append
({
"kernel"
:
kernel
,
"batch_size"
:
M
,
"median"
:
v
.
median
})
rows
=
int
(
math
.
ceil
(
len
(
results
)
/
2
))
fig
,
axs
=
plt
.
subplots
(
rows
,
2
,
figsize
=
(
12
,
5
*
rows
))
axs
=
axs
.
flatten
()
axs_idx
=
0
for
shape
,
data
in
results
.
items
():
plt
.
sca
(
axs
[
axs_idx
])
df
=
pd
.
DataFrame
(
data
)
sns
.
lineplot
(
data
=
df
,
x
=
"batch_size"
,
y
=
"median"
,
hue
=
"kernel"
,
style
=
"kernel"
,
markers
=
True
,
dashes
=
False
,
palette
=
"Dark2"
)
plt
.
title
(
f
"Shape:
{
shape
}
"
)
plt
.
ylabel
(
"time (median, s)"
)
axs_idx
+=
1
plt
.
tight_layout
()
plt
.
savefig
(
"graph_machete_bench.pdf"
)
benchmarks/kernels/weight_shapes.py
0 → 100644
View file @
af7f4372
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES
=
{
"mistralai/Mistral-7B-v0.1"
:
[
([
4096
,
6144
],
1
),
([
4096
,
4096
],
0
),
([
4096
,
28672
],
1
),
([
14336
,
4096
],
0
),
],
"meta-llama/Llama-2-7b-hf"
:
[
([
4096
,
12288
],
1
),
([
4096
,
4096
],
0
),
([
4096
,
22016
],
1
),
([
11008
,
4096
],
0
),
],
"meta-llama/Llama-3-8b"
:
[
([
4096
,
6144
],
1
),
([
4096
,
4096
],
0
),
([
4096
,
28672
],
1
),
([
14336
,
4096
],
0
),
],
"meta-llama/Llama-2-13b-hf"
:
[
([
5120
,
15360
],
1
),
([
5120
,
5120
],
0
),
([
5120
,
27648
],
1
),
([
13824
,
5120
],
0
),
],
"meta-llama/Llama-2-70b-hf"
:
[
([
8192
,
10240
],
1
),
([
8192
,
8192
],
0
),
([
8192
,
57344
],
1
),
([
28672
,
8192
],
0
),
],
}
collect_env.py
View file @
af7f4372
...
...
@@ -66,6 +66,8 @@ DEFAULT_CONDA_PATTERNS = {
"nccl"
,
"transformers"
,
"zmq"
,
"nvidia"
,
"pynvml"
,
}
DEFAULT_PIP_PATTERNS
=
{
...
...
@@ -79,6 +81,8 @@ DEFAULT_PIP_PATTERNS = {
"nccl"
,
"transformers"
,
"zmq"
,
"nvidia"
,
"pynvml"
,
}
...
...
@@ -265,8 +269,9 @@ def get_neuron_sdk_version(run_lambda):
def
get_vllm_version
():
try
:
import
vllm
return
vllm
.
__version__
except
ImportError
:
return
vllm
.
__version__
+
"@"
+
vllm
.
__commit__
except
Exception
:
# old version of vllm does not have __commit__
return
'N/A'
...
...
csrc/attention/attention_utils.cuh
View file @
af7f4372
...
...
@@ -122,7 +122,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
A_vec
qk_vec
=
mul
<
A_vec
,
Vec
,
Vec
>
(
q
[
0
],
k
[
0
]);
#pragma unroll
for
(
int
ii
=
1
;
ii
<
N
;
++
ii
)
{
qk_vec
=
fma
(
q
[
ii
],
k
[
ii
],
qk_vec
);
qk_vec
=
vllm
::
fma
(
q
[
ii
],
k
[
ii
],
qk_vec
);
}
// Finalize the reduction across lanes.
float
qk
=
sum
(
qk_vec
);
...
...
csrc/core/scalar_type.hpp
View file @
af7f4372
...
...
@@ -21,7 +21,7 @@ namespace vllm {
//
class
ScalarType
{
public:
enum
NanRepr
:
int
64
_t
{
enum
NanRepr
:
u
int
8
_t
{
NAN_NONE
=
0
,
// nans are not supported
NAN_IEEE_754
=
1
,
// nans are: exp all 1s, mantissa not all 0s
NAN_EXTD_RANGE_MAX_MIN
=
2
,
// nans are: exp all 1s, mantissa all 1s
...
...
@@ -29,33 +29,33 @@ class ScalarType {
NAN_REPR_ID_MAX
};
constexpr
ScalarType
(
bool
signed_
,
int
64
_t
exponent
,
int
64
_t
mantissa
,
int
64
_t
bias
,
bool
finite_values_only
=
false
,
constexpr
ScalarType
(
u
int
8
_t
exponent
,
u
int
8
_t
mantissa
,
bool
signed_
,
int
32
_t
bias
,
bool
finite_values_only
=
false
,
NanRepr
nan_repr
=
NAN_IEEE_754
)
:
exponent
(
exponent
),
mantissa
(
mantissa
),
bias
(
bias
),
signed_
(
signed_
),
bias
(
bias
),
finite_values_only
(
finite_values_only
),
nan_repr
(
nan_repr
){};
static
constexpr
ScalarType
int_
(
int
64
_t
size_bits
,
int
64
_t
bias
=
0
)
{
return
ScalarType
(
true
,
0
,
size_bits
-
1
,
bias
);
static
constexpr
ScalarType
int_
(
u
int
8
_t
size_bits
,
int
32
_t
bias
=
0
)
{
return
ScalarType
(
0
,
size_bits
-
1
,
true
,
bias
);
}
static
constexpr
ScalarType
uint
(
int
64
_t
size_bits
,
int
64
_t
bias
=
0
)
{
return
ScalarType
(
false
,
0
,
size_bits
,
bias
);
static
constexpr
ScalarType
uint
(
u
int
8
_t
size_bits
,
int
32
_t
bias
=
0
)
{
return
ScalarType
(
0
,
size_bits
,
false
,
bias
);
}
// IEEE 754 compliant floating point type
static
constexpr
ScalarType
float_IEEE754
(
int
64
_t
exponent
,
int
64
_t
mantissa
)
{
static
constexpr
ScalarType
float_IEEE754
(
u
int
8
_t
exponent
,
u
int
8
_t
mantissa
)
{
TORCH_CHECK
(
mantissa
>
0
&&
exponent
>
0
);
return
ScalarType
(
true
,
exponent
,
mantissa
,
0
,
false
,
NAN_IEEE_754
);
return
ScalarType
(
exponent
,
mantissa
,
true
,
0
,
false
,
NAN_IEEE_754
);
}
// IEEE 754 non-compliant floating point type
static
constexpr
ScalarType
float_
(
int
64
_t
exponent
,
int
64
_t
mantissa
,
static
constexpr
ScalarType
float_
(
u
int
8
_t
exponent
,
u
int
8
_t
mantissa
,
bool
finite_values_only
,
NanRepr
nan_repr
)
{
TORCH_CHECK
(
nan_repr
<
NAN_REPR_ID_MAX
,
"Invalid NanRepr"
);
...
...
@@ -63,36 +63,121 @@ class ScalarType {
TORCH_CHECK
(
nan_repr
!=
NAN_IEEE_754
,
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions"
);
return
ScalarType
(
true
,
exponent
,
mantissa
,
0
,
finite_values_only
,
return
ScalarType
(
exponent
,
mantissa
,
true
,
0
,
finite_values_only
,
nan_repr
);
}
int
64
_t
const
exponent
;
// size of the exponent field (0 for integer types)
int
64
_t
const
mantissa
;
// size of the mantissa field (size of the integer
u
int
8
_t
const
exponent
;
// size of the exponent field (0 for integer types)
u
int
8
_t
const
mantissa
;
// size of the mantissa field (size of the integer
// excluding the sign bit for integer types)
int64_t
const
bias
;
// stored values equal value + bias,
// used for quantized type
bool
const
signed_
;
// flag if the type supports negative numbers (i.e. has a
// sign bit)
int32_t
const
bias
;
// stored values equal value + bias,
// used for quantized type
// Extra Floating point info
bool
const
finite_values_only
;
// i.e. no +/-inf if true
NanRepr
const
nan_repr
;
// how NaNs are represented
// (not applicable for integer types)
int64_t
size_bits
()
const
{
return
mantissa
+
exponent
+
is_signed
();
}
bool
is_signed
()
const
{
return
signed_
;
}
bool
is_integer
()
const
{
return
exponent
==
0
;
}
bool
is_floating_point
()
const
{
return
exponent
>
0
;
}
bool
is_ieee_754
()
const
{
using
Id
=
int64_t
;
private:
// Field size in id
template
<
typename
T_
>
static
constexpr
size_t
member_id_field_width
()
{
using
T
=
std
::
decay_t
<
T_
>
;
return
std
::
is_same_v
<
T
,
bool
>
?
1
:
sizeof
(
T
)
*
8
;
}
template
<
typename
Fn
,
typename
Init
,
typename
Member
,
typename
...
Rest
>
static
constexpr
auto
reduce_members_helper
(
Fn
f
,
Init
val
,
Member
member
,
Rest
...
rest
)
{
auto
new_val
=
f
(
val
,
member
);
if
constexpr
(
sizeof
...(
rest
)
>
0
)
{
return
reduce_members_helper
(
f
,
new_val
,
rest
...);
}
else
{
return
new_val
;
};
}
template
<
typename
Fn
,
typename
Init
>
constexpr
auto
reduce_members
(
Fn
f
,
Init
init
)
const
{
// Should be in constructor order for `from_id`
return
reduce_members_helper
(
f
,
init
,
exponent
,
mantissa
,
signed_
,
bias
,
finite_values_only
,
nan_repr
);
};
template
<
typename
Fn
,
typename
Init
>
static
constexpr
auto
reduce_member_types
(
Fn
f
,
Init
init
)
{
constexpr
auto
dummy_type
=
ScalarType
(
0
,
0
,
false
,
0
,
false
,
NAN_NONE
);
return
dummy_type
.
reduce_members
(
f
,
init
);
};
static
constexpr
auto
id_size_bits
()
{
return
reduce_member_types
(
[](
int
acc
,
auto
member
)
->
int
{
return
acc
+
member_id_field_width
<
decltype
(
member
)
>
();
},
0
);
}
public:
// unique id for this scalar type that can be computed at compile time for
// c++17 template specialization this is not needed once we migrate to
// c++20 and can pass literal classes as template parameters
constexpr
Id
id
()
const
{
static_assert
(
id_size_bits
()
<=
sizeof
(
Id
)
*
8
,
"ScalarType id is too large to be stored"
);
auto
or_and_advance
=
[](
std
::
pair
<
Id
,
uint32_t
>
result
,
auto
member
)
->
std
::
pair
<
Id
,
uint32_t
>
{
auto
[
id
,
bit_offset
]
=
result
;
auto
constexpr
bits
=
member_id_field_width
<
decltype
(
member
)
>
();
return
{
id
|
(
int64_t
(
member
)
&
((
uint64_t
(
1
)
<<
bits
)
-
1
))
<<
bit_offset
,
bit_offset
+
bits
};
};
return
reduce_members
(
or_and_advance
,
std
::
pair
<
Id
,
uint32_t
>
{}).
first
;
}
// create a ScalarType from an id, for c++17 template specialization,
// this is not needed once we migrate to c++20 and can pass literal
// classes as template parameters
static
constexpr
ScalarType
from_id
(
Id
id
)
{
auto
extract_and_advance
=
[
id
](
auto
result
,
auto
member
)
{
using
T
=
decltype
(
member
);
auto
[
tuple
,
bit_offset
]
=
result
;
auto
constexpr
bits
=
member_id_field_width
<
T
>
();
auto
extracted_val
=
static_cast
<
T
>
((
int64_t
(
id
)
>>
bit_offset
)
&
((
uint64_t
(
1
)
<<
bits
)
-
1
));
auto
new_tuple
=
std
::
tuple_cat
(
tuple
,
std
::
make_tuple
(
extracted_val
));
return
std
::
pair
<
decltype
(
new_tuple
),
int
>
{
new_tuple
,
bit_offset
+
bits
};
};
auto
[
tuple_args
,
_
]
=
reduce_member_types
(
extract_and_advance
,
std
::
pair
<
std
::
tuple
<>
,
int
>
{});
return
std
::
apply
([](
auto
...
args
)
{
return
ScalarType
(
args
...);
},
tuple_args
);
}
constexpr
int64_t
size_bits
()
const
{
return
mantissa
+
exponent
+
is_signed
();
}
constexpr
bool
is_signed
()
const
{
return
signed_
;
}
constexpr
bool
is_integer
()
const
{
return
exponent
==
0
;
}
constexpr
bool
is_floating_point
()
const
{
return
exponent
>
0
;
}
constexpr
bool
is_ieee_754
()
const
{
return
is_floating_point
()
&&
finite_values_only
==
false
&&
nan_repr
==
NAN_IEEE_754
;
}
bool
has_nans
()
const
{
return
is_floating_point
()
&&
nan_repr
!=
NAN_NONE
;
}
bool
has_infs
()
const
{
constexpr
bool
has_nans
()
const
{
return
is_floating_point
()
&&
nan_repr
!=
NAN_NONE
;
}
constexpr
bool
has_infs
()
const
{
return
is_floating_point
()
&&
finite_values_only
==
false
;
}
bool
has_bias
()
const
{
return
bias
!=
0
;
}
constexpr
bool
has_bias
()
const
{
return
bias
!=
0
;
}
private:
double
_floating_point_max
()
const
{
...
...
@@ -132,7 +217,7 @@ class ScalarType {
return
*
reinterpret_cast
<
double
*>
(
&
double_raw
);
}
std
::
variant
<
int64_t
,
double
>
_raw_max
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
_raw_max
()
const
{
if
(
is_floating_point
())
{
return
{
_floating_point_max
()};
}
else
{
...
...
@@ -142,7 +227,7 @@ class ScalarType {
}
}
std
::
variant
<
int64_t
,
double
>
_raw_min
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
_raw_min
()
const
{
if
(
is_floating_point
())
{
TORCH_CHECK
(
is_signed
(),
"We currently assume all floating point types are signed"
);
...
...
@@ -169,7 +254,7 @@ class ScalarType {
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
std
::
variant
<
int64_t
,
double
>
max
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
max
()
const
{
return
std
::
visit
(
[
this
](
auto
x
)
->
std
::
variant
<
int64_t
,
double
>
{
return
{
x
-
bias
};
},
_raw_max
());
...
...
@@ -177,7 +262,7 @@ class ScalarType {
// Min representable value for this scalar type.
// (accounting for bias if there is one)
std
::
variant
<
int64_t
,
double
>
min
()
const
{
constexpr
std
::
variant
<
int64_t
,
double
>
min
()
const
{
return
std
::
visit
(
[
this
](
auto
x
)
->
std
::
variant
<
int64_t
,
double
>
{
return
{
x
-
bias
};
},
_raw_min
());
...
...
@@ -216,7 +301,7 @@ class ScalarType {
}
}
bool
operator
==
(
ScalarType
const
&
other
)
const
{
constexpr
bool
operator
==
(
ScalarType
const
&
other
)
const
{
return
mantissa
==
other
.
mantissa
&&
exponent
==
other
.
exponent
&&
bias
==
other
.
bias
&&
signed_
==
other
.
signed_
&&
finite_values_only
==
other
.
finite_values_only
&&
...
...
@@ -229,6 +314,8 @@ class ScalarType {
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// constructor at the same time (torch::CustomClassHolder does not have a
// constexpr destructor)
// See also:
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
class
ScalarTypeTorch
:
public
torch
::
CustomClassHolder
,
public
ScalarType
{
public:
ScalarTypeTorch
(
int64_t
exponent
,
int64_t
mantissa
,
int64_t
bias
,
...
...
@@ -241,31 +328,90 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
using
Self
=
ScalarTypeTorch
;
using
SelfPtr
=
c10
::
intrusive_ptr
<
Self
>
;
static
void
check_size_bits
(
int64_t
size_bits
,
bool
signed_
)
{
TORCH_CHECK
(
size_bits
<=
std
::
numeric_limits
<
decltype
(
std
::
declval
<
Self
>
().
mantissa
)
>::
max
(),
"size_bits bit width is too large to be represented"
);
}
static
void
check_bias
(
int64_t
bias
)
{
using
Bias
=
decltype
(
std
::
declval
<
Self
>
().
bias
);
TORCH_CHECK
(
bias
<=
std
::
numeric_limits
<
Bias
>::
max
()
&&
bias
>=
std
::
numeric_limits
<
Bias
>::
min
(),
"bias too large or small to be represented"
);
}
static
void
check_exponent
(
int64_t
exponent
)
{
TORCH_CHECK
(
exponent
<=
std
::
numeric_limits
<
decltype
(
std
::
declval
<
Self
>
().
exponent
)
>::
max
(),
"exponent bit width is too large to be represented"
);
}
static
void
check_mantissa
(
int64_t
mantissa
)
{
TORCH_CHECK
(
mantissa
<=
std
::
numeric_limits
<
decltype
(
std
::
declval
<
Self
>
().
mantissa
)
>::
max
(),
"mantissa bit width is too large to be represented"
);
}
static
SelfPtr
int_
(
int64_t
size_bits
,
c10
::
optional
<
int64_t
>
bias
)
{
check_size_bits
(
size_bits
,
true
);
check_bias
(
bias
.
value_or
(
0
));
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
int_
(
size_bits
,
bias
.
value_or
(
0
)));
}
static
SelfPtr
uint
(
int64_t
size_bits
,
c10
::
optional
<
int64_t
>
bias
)
{
check_size_bits
(
size_bits
,
true
);
check_bias
(
bias
.
value_or
(
0
));
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
uint
(
size_bits
,
bias
.
value_or
(
0
)));
}
static
SelfPtr
float_IEEE754
(
int64_t
exponent
,
int64_t
mantissa
)
{
check_mantissa
(
mantissa
);
check_exponent
(
exponent
);
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
float_IEEE754
(
exponent
,
mantissa
));
}
static
SelfPtr
float_
(
int64_t
exponent
,
int64_t
mantissa
,
bool
finite_values_only
,
int64_t
nan_repr
)
{
check_mantissa
(
mantissa
);
check_exponent
(
exponent
);
return
c10
::
make_intrusive
<
Self
>
(
ScalarType
::
float_
(
exponent
,
mantissa
,
finite_values_only
,
NanRepr
(
nan_repr
)));
}
// This needs to be implemented and throw a TypeError in order for
// PyTorch's opcheck to work on ops that use ScalarTypes.
int64_t
len
()
const
{
throw
c10
::
TypeError
(
"__len__ not implemented"
);
return
0
;
}
// Serialize a ScalarType into a tuple of pairs. Where each pair
// is a (fieldname, value).
// For simplicity, we are just going to convert to a ScalarTypeId.
std
::
tuple
<
std
::
tuple
<
std
::
string
,
int64_t
>>
obj_flatten
()
const
{
return
{{
"ScalarType"
,
id
()}};
}
// Deserialize a scalar type that has been serialized by obj_flatten,
// ostensibly from a tuple of (member name, value) pairs, but in reality
// just a ScalarTypeId.
static
SelfPtr
obj_unflatten
(
std
::
tuple
<
std
::
tuple
<
std
::
string
,
int64_t
>>
const
&
flat_type
)
{
return
c10
::
make_intrusive
<
Self
>
(
from_id
(
std
::
get
<
1
>
(
std
::
get
<
0
>
(
flat_type
))));
}
template
<
typename
T
>
static
void
bind_readonly_property
(
torch
::
class_
<
Self
>&
cls
,
std
::
string
const
&
name
,
T
Base
::*
field
)
{
auto
getter_func
=
[
field
=
std
::
move
(
field
)](
SelfPtr
const
&
self
)
{
auto
getter_func
_helper
=
[
field
=
std
::
move
(
field
)](
SelfPtr
const
&
self
)
{
if
constexpr
(
std
::
is_member_function_pointer_v
<
decltype
(
field
)
>
)
{
return
(
self
.
get
()
->*
field
)();
}
else
{
...
...
@@ -273,6 +419,18 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
}
};
auto
getter_func
=
[
field
=
std
::
move
(
field
),
getter_func_helper
=
std
::
move
(
getter_func_helper
)](
SelfPtr
const
&
self
)
{
auto
val
=
getter_func_helper
(
self
);
// upconvert uint8_t, int32_t etc. to int64_t for python
if
constexpr
(
std
::
is_integral_v
<
T
>
)
{
return
static_cast
<
int64_t
>
(
val
);
}
else
{
return
val
;
}
};
cls
.
def_property
(
name
,
getter_func
);
}
...
...
@@ -325,6 +483,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
self
.
get
()
->
min
());
});
bind_function
(
cls
,
"__len__"
,
&
ScalarTypeTorch
::
len
);
bind_function
(
cls
,
"__str__"
,
&
Base
::
str
);
bind_function
(
cls
,
"__eq__"
,
[](
SelfPtr
const
&
self
,
SelfPtr
const
&
other
)
{
return
*
self
==
*
other
;
...
...
@@ -333,6 +492,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
return
"ScalarType."
+
self
.
get
()
->
str
();
});
bind_function
(
cls
,
"__obj_flatten__"
,
&
ScalarTypeTorch
::
obj_flatten
);
bind_static_function
(
cls
,
"__obj_unflatten__"
,
&
ScalarTypeTorch
::
obj_unflatten
);
// Bind static functions (convenience constructors)
bind_static_function
(
cls
,
"int_"
,
&
ScalarTypeTorch
::
int_
);
bind_static_function
(
cls
,
"uint"
,
&
ScalarTypeTorch
::
uint
);
...
...
@@ -341,6 +504,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
}
};
using
ScalarTypeId
=
int64_t
;
using
ScalarTypeTorchPtr
=
c10
::
intrusive_ptr
<
ScalarTypeTorch
>
;
// "rust style" names generally following:
...
...
@@ -380,4 +544,5 @@ static inline constexpr auto kHalf = kFE5M10;
static
inline
constexpr
auto
kFloat16
=
kHalf
;
static
inline
constexpr
auto
kBFloat16
=
kFE8M7
;
static
inline
constexpr
auto
kFloat16Id
=
kFloat16
.
id
();
};
// namespace vllm
csrc/cuda_utils.h
View file @
af7f4372
#pragma once
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
#define DEVICE_INLINE __forceinline__ __device__
#define HOST_INLINE __forceinline__ __host__
#else
#define HOST_DEVICE_INLINE inline
#define DEVICE_INLINE inline
#define HOST_INLINE inline
#endif
int64_t
get_device_attribute
(
int64_t
attribute
,
int64_t
device_id
);
int64_t
get_max_shared_memory_per_block_device_attribute
(
int64_t
device_id
);
csrc/cutlass_extensions/cute_utils.cuh
0 → 100644
View file @
af7f4372
#pragma once
#include <cute/tensor.hpp>
#include <torch/all.h>
namespace
cute
{
////////////////////////////////////////////////////////////////////
// layout utils
////////////////////////////////////////////////////////////////////
// Permute layout based on indices, example:
// permute_layout<1, 0>(layout) will swap the two dimensions
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
template
<
size_t
...
I
,
typename
Layout
>
CUTE_HOST_DEVICE
static
constexpr
auto
permute_layout
(
Layout
l
)
{
static_assert
(
rank
(
l
)
==
sizeof
...(
I
),
"Invalid permutation, rank mismatch"
);
return
cute
::
make_layout
(
cute
::
get
<
I
>
(
l
)...);
}
// is the layout f(x) = x
template
<
typename
Layout
>
CUTE_HOST_DEVICE
static
constexpr
bool
is_identity_layout
()
{
if
constexpr
(
std
::
is_same_v
<
Layout
,
void
>
)
return
true
;
else
{
constexpr
auto
coalesced_layout
=
coalesce
(
Layout
{});
if
constexpr
(
rank
(
coalesced_layout
)
==
1
&&
stride
<
0
>
(
coalesced_layout
)
==
1
)
{
return
true
;
}
return
false
;
}
}
////////////////////////////////////////////////////////////////////
// Pointer utils
////////////////////////////////////////////////////////////////////
template
<
class
PointerType
>
static
constexpr
auto
get_logical_ptr
(
PointerType
*
ptr
)
{
if
constexpr
(
cute
::
sizeof_bits_v
<
PointerType
>
<
8
)
{
return
cute
::
subbyte_iterator
<
PointerType
>
(
ptr
);
}
else
{
return
ptr
;
}
}
////////////////////////////////////////////////////////////////////
// Misc utils
////////////////////////////////////////////////////////////////////
template
<
typename
T
,
typename
Elements
>
CUTE_HOST_DEVICE
static
constexpr
auto
create_auto_vectorizing_copy
()
{
constexpr
auto
bits
=
sizeof_bits_v
<
T
>
*
Elements
{};
if
constexpr
(
bits
%
128
==
0
)
{
return
AutoVectorizingCopyWithAssumedAlignment
<
128
>
{};
}
else
if
constexpr
(
bits
%
64
==
0
)
{
return
AutoVectorizingCopyWithAssumedAlignment
<
64
>
{};
}
else
if
constexpr
(
bits
%
32
==
0
)
{
return
AutoVectorizingCopyWithAssumedAlignment
<
32
>
{};
}
else
if
constexpr
(
bits
%
16
==
0
)
{
return
AutoVectorizingCopyWithAssumedAlignment
<
16
>
{};
}
else
{
return
AutoVectorizingCopyWithAssumedAlignment
<
8
>
{};
}
}
};
// namespace cute
csrc/cutlass_extensions/torch_utils.hpp
0 → 100644
View file @
af7f4372
#pragma once
#include <torch/all.h>
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
using
ColumnMajor
=
typename
cutlass
::
layout
::
ColumnMajor
;
using
RowMajor
=
typename
cutlass
::
layout
::
RowMajor
;
namespace
cute
{
namespace
detail
{
template
<
class
T
,
class
F
,
class
G
,
int
...
I
>
CUTE_HOST_DEVICE
constexpr
auto
tapply_with_idx
(
T
&&
t
,
F
&&
f
,
G
&&
g
,
seq
<
I
...
>
)
{
return
g
(
f
(
cute
::
get
<
I
>
(
static_cast
<
T
&&>
(
t
)),
I
)...);
}
template
<
class
F
,
int
...
I
>
CUTE_HOST_DEVICE
constexpr
auto
make_shape_from_idx
(
F
&&
f
,
seq
<
I
...
>
)
{
return
make_shape
(
f
(
I
)...);
}
};
// namespace detail
template
<
class
T
,
class
F
>
CUTE_HOST_DEVICE
constexpr
auto
transform_with_idx
(
T
const
&
t
,
F
&&
f
)
{
if
constexpr
(
cute
::
is_tuple
<
T
>::
value
)
{
return
detail
::
tapply_with_idx
(
t
,
f
,
[](
auto
const
&
...
a
)
{
return
cute
::
make_tuple
(
a
...);
},
tuple_seq
<
T
>
{});
}
else
{
return
f
(
t
);
}
CUTE_GCC_UNREACHABLE
;
}
// calls: make_shape(f(0), f(1), ..., f(N-1))
template
<
int
N
,
class
F
>
CUTE_HOST_DEVICE
constexpr
auto
make_shape_from_idx
(
F
&&
f
)
{
return
detail
::
make_shape_from_idx
(
f
,
make_seq
<
N
>
{});
}
};
// namespace cute
// Make a layout from a tensor with `rank(Stride{})`, where the shape is the
// shape of the passed in tensor and the strides are of type `Stride` and
// contain the strides of the passed in tensor, checking that any static strides
// in `Stride{}` match the strides of the passed in tensor.
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template
<
typename
Stride
>
static
inline
auto
make_cute_layout
(
torch
::
Tensor
const
&
tensor
,
std
::
string_view
name
=
"tensor"
)
{
TORCH_CHECK
(
tensor
.
dim
()
<=
rank
(
Stride
{}));
auto
stride
=
cute
::
transform_with_idx
(
Stride
{},
[
&
](
auto
const
&
stride_ele
,
auto
const
&
idx
)
{
using
StrideEle
=
std
::
decay_t
<
decltype
(
stride_ele
)
>
;
if
(
idx
<
tensor
.
dim
())
{
if
constexpr
(
cute
::
is_static_v
<
StrideEle
>
)
{
TORCH_CHECK
(
StrideEle
::
value
==
tensor
.
stride
(
idx
),
"Expected "
,
name
,
".stride("
,
idx
,
") to be "
,
StrideEle
::
value
);
return
StrideEle
{};
}
else
{
return
tensor
.
stride
(
idx
);
}
}
else
{
// Extra strides are assumed to be 0 or 1
if
constexpr
(
cute
::
is_static_v
<
StrideEle
>
)
{
static_assert
(
StrideEle
::
value
==
0
||
StrideEle
::
value
==
1
);
}
return
StrideEle
{};
}
});
auto
shape
=
cute
::
make_shape_from_idx
<
rank
(
Stride
{})
>
([
&
](
auto
const
&
idx
)
{
if
(
idx
<
tensor
.
dim
())
return
tensor
.
size
(
idx
);
else
return
int64_t
(
1
);
});
return
make_layout
(
shape
,
stride
);
}
template
<
typename
Stride
>
static
inline
auto
maybe_make_cute_layout
(
c10
::
optional
<
torch
::
Tensor
>
const
&
tensor
,
std
::
string_view
name
=
"tensor"
)
{
using
Layout
=
decltype
(
make_cute_layout
<
Stride
>
(
*
tensor
));
if
(
tensor
)
{
return
std
::
optional
<
Layout
>
{
make_cute_layout
<
Stride
>
(
*
tensor
,
name
)};
}
else
{
return
std
::
optional
<
Layout
>
{};
}
}
//
// Torch Type to Cutlass Type (equivalent_cutlass_type)
//
template
<
typename
T
>
struct
equivalent_cutlass_type
{
using
type
=
T
;
};
template
<
typename
T
>
using
equivalent_cutlass_type_t
=
typename
equivalent_cutlass_type
<
T
>::
type
;
template
<
>
struct
equivalent_cutlass_type
<
c10
::
Half
>
{
using
type
=
cutlass
::
half_t
;
};
template
<
>
struct
equivalent_cutlass_type
<
c10
::
BFloat16
>
{
using
type
=
cutlass
::
bfloat16_t
;
};
//
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
template
<
typename
T
>
struct
equivalent_scalar_type
{
using
type
=
T
;
};
template
<
typename
T
>
using
equivalent_scalar_type_t
=
typename
equivalent_scalar_type
<
T
>::
type
;
template
<
>
struct
equivalent_scalar_type
<
cutlass
::
half_t
>
{
using
type
=
c10
::
Half
;
};
template
<
>
struct
equivalent_scalar_type
<
cutlass
::
bfloat16_t
>
{
using
type
=
c10
::
BFloat16
;
};
// get equivalent c10::ScalarType tag from compile time type
template
<
typename
T
>
static
inline
constexpr
c10
::
ScalarType
equivalent_scalar_type_v
=
c10
::
CppTypeToScalarType
<
equivalent_scalar_type_t
<
T
>>::
value
;
\ No newline at end of file
csrc/cutlass_extensions/vllm_collective_builder.cuh
0 → 100644
View file @
af7f4372
#pragma once
#include "cutlass/gemm/collective/collective_builder.hpp"
namespace
cutlass
::
gemm
::
collective
{
using
namespace
cute
;
//
// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for
// for custom kernel tags, allowing you to build custom collectives. Without
// touching the cutlass library headers, using `CutlassKernelTag` will mean it
// will resort to using the standard cutlass collective builder.
//
// Use the default Cutlass collective builder, i.e. use an unmodified cutless
// collective
struct
CutlassKernelTag
{};
template
<
class
KernelTag
,
class
ArchTag
,
class
OpClass
,
class
ElementA
,
class
GmemLayoutA
,
int
AlignmentA
,
class
ElementB
,
class
GmemLayoutB
,
int
AlignmentB
,
class
ElementAccumulator
,
class
TileShape_MNK
,
class
ClusterShape_MNK
,
class
StageCountType
,
class
KernelScheduleType
,
class
Enable
=
void
>
struct
VLLMCollectiveBuilder
{
static_assert
(
sizeof
(
ElementA
)
==
0
,
"Could not build a collective for given parameters."
);
};
template
<
class
ArchTag
,
class
OpClass
,
class
ElementA
,
class
GmemLayoutA
,
int
AlignmentA
,
class
ElementB
,
class
GmemLayoutB
,
int
AlignmentB
,
class
ElementAccumulator
,
class
TileShape_MNK
,
class
ClusterShape_MNK
,
class
StageCountType
,
class
KernelScheduleType
>
struct
VLLMCollectiveBuilder
<
CutlassKernelTag
,
ArchTag
,
OpClass
,
ElementA
,
GmemLayoutA
,
AlignmentA
,
ElementB
,
GmemLayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
KernelScheduleType
>
{
using
CollectiveOp
=
typename
CollectiveBuilder
<
ArchTag
,
OpClass
,
ElementA
,
GmemLayoutA
,
AlignmentA
,
ElementB
,
GmemLayoutB
,
AlignmentB
,
ElementAccumulator
,
TileShape_MNK
,
ClusterShape_MNK
,
StageCountType
,
KernelScheduleType
>::
CollectiveOp
;
};
};
// namespace cutlass::gemm::collective
\ No newline at end of file
csrc/cutlass_extensions/vllm_custom_types.cuh
0 → 100644
View file @
af7f4372
#pragma once
#include "cutlass/integer_subbyte.h"
namespace
cutlass
{
///////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
Bits
,
int
Bias
,
bool
Signed
=
false
>
struct
vllm_biased_integer_subbyte
:
public
integer_subbyte
<
Bits
,
Signed
>
{
using
Base
=
integer_subbyte
<
Bits
,
Signed
>
;
using
Storage
=
typename
Base
::
Storage
;
using
xint_t
=
typename
Base
::
xint_t
;
using
Base
::
bits_mask_
;
using
Base
::
sign_mask_
;
using
Base
::
storage
;
//
// Methods
//
/// No operation
vllm_biased_integer_subbyte
()
=
default
;
/// Conversion from integer type
CUTLASS_HOST_DEVICE
explicit
vllm_biased_integer_subbyte
(
int
value
)
:
Base
(
value
)
{}
CUTLASS_HOST_DEVICE
explicit
vllm_biased_integer_subbyte
(
unsigned
value
)
:
Base
(
value
)
{}
CUTLASS_HOST_DEVICE
explicit
vllm_biased_integer_subbyte
(
double
value
)
:
Base
(
value
)
{}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// "GPTQ" types, i.e. symmetric quantization
using
vllm_uint4b8_t
=
vllm_biased_integer_subbyte
<
4
,
8
>
;
// u4b8
using
vllm_uint8b128_t
=
vllm_biased_integer_subbyte
<
8
,
128
>
;
// u8b128
///////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
Bits
,
int
Bias
,
bool
Signed
>
struct
sizeof_bits
<
vllm_biased_integer_subbyte
<
Bits
,
Bias
,
Signed
>>
{
static
constexpr
int
value
=
Bits
;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass
csrc/cutlass_extensions/vllm_cutlass_library_extension.py
0 → 100644
View file @
af7f4372
import
enum
from
typing
import
Dict
,
Union
from
cutlass_library
import
*
#
# Extend cutlass library with custom types, and missing values
#
class
VLLMDataType
(
enum
.
Enum
):
u4b8
=
enum_auto
()
u8b128
=
enum_auto
()
class
MixedInputKernelScheduleType
(
enum
.
Enum
):
TmaWarpSpecializedMixedInput
=
enum_auto
()
TmaWarpSpecializedPingpongMixedInput
=
enum_auto
()
TmaWarpSpecializedCooperativeMixedInput
=
enum_auto
()
VLLMDataTypeNames
:
Dict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
**
DataTypeNames
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
"u4b8"
,
VLLMDataType
.
u8b128
:
"u8b128"
,
}
}
VLLMDataTypeTag
:
Dict
[
Union
[
VLLMDataType
,
DataType
],
str
]
=
{
**
DataTypeTag
,
# type: ignore
**
{
VLLMDataType
.
u4b8
:
"cutlass::vllm_uint4b8_t"
,
VLLMDataType
.
u8b128
:
"cutlass::vllm_uint8b128_t"
,
}
}
VLLMKernelScheduleTag
:
Dict
[
Union
[
MixedInputKernelScheduleType
,
KernelScheduleType
],
str
]
=
{
**
KernelScheduleTag
,
# type: ignore
**
{
MixedInputKernelScheduleType
.
TmaWarpSpecializedMixedInput
:
"cutlass::gemm::KernelTmaWarpSpecializedMixedInput"
,
MixedInputKernelScheduleType
.
TmaWarpSpecializedPingpongMixedInput
:
"cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput"
,
MixedInputKernelScheduleType
.
TmaWarpSpecializedCooperativeMixedInput
:
"cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput"
,
}
}
csrc/cutlass_extensions/vllm_numeric_conversion.cuh
0 → 100644
View file @
af7f4372
#pragma once
#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/vllm_custom_types.cuh"
#include "cutlass_extensions/cute_utils.cuh"
// this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t
// as well as adds interleaved numeric array converters for specific types.
// (interleaved numeric array converters can be more efficient for subbyte
// types)
namespace
cutlass
{
// InterleavedNumericArrayConverter is like NumericArrayConverter but also
// deinterleaves converted elements based on IlvBlkLayout, interleaving can
// make subbyte converts more efficient by allowing for efficient extraction
// of subbyte elements from a 32bit register.
template
<
typename
IlvBlkLayout
,
typename
T
,
typename
S
,
int
N
,
FloatRoundStyle
Round
=
FloatRoundStyle
::
round_to_nearest
,
class
Enable
=
void
>
struct
InterleavedNumericArrayConverter
{
using
Converter
=
NumericArrayConverter
<
T
,
S
,
N
,
Round
>
;
using
result_type
=
typename
Converter
::
result_type
;
using
source_type
=
typename
Converter
::
source_type
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
CUTE_INVALID_CONTROL_PATH
(
"InterleavedNumericArrayConverter not implemented
\n
"
);
return
{};
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
template
<
typename
IlvBlkLayout
,
typename
T
,
typename
S
,
int
N
,
FloatRoundStyle
Round
>
struct
InterleavedNumericArrayConverter
<
IlvBlkLayout
,
T
,
S
,
N
,
Round
,
std
::
enable_if_t
<
is_identity_layout
<
IlvBlkLayout
>
()
>>
{
using
Converter
=
NumericArrayConverter
<
T
,
S
,
N
,
Round
>
;
using
result_type
=
typename
Converter
::
result_type
;
using
source_type
=
typename
Converter
::
source_type
;
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
Converter
::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// TODO (LucasWilkinson): Implement
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>
// ....
template
<
typename
RegConvert32bit
,
typename
T
,
typename
S
,
int
N
>
struct
ArrayConverterPacked32Bit
{
using
result_type
=
Array
<
T
,
N
>
;
using
source_type
=
Array
<
S
,
N
>
;
using
result_packed_8_t
=
Array
<
T
,
8
>
;
using
result_packed_4_t
=
Array
<
T
,
4
>
;
using
result_packed_2_t
=
Array
<
T
,
2
>
;
using
src_packed_8_t
=
Array
<
S
,
8
>
;
using
src_packed_4_t
=
Array
<
S
,
4
>
;
using
src_packed_2_t
=
Array
<
S
,
2
>
;
static_assert
(
N
%
2
==
0
,
"N must be a multiple of 2"
);
static_assert
(
cutlass
::
sizeof_bits_v
<
S
>
>=
4
);
// TODO: add 16 packed sources
static_assert
(
32
%
cutlass
::
sizeof_bits_v
<
S
>
==
0
);
static
constexpr
auto
src_elems_per_32bit_reg
=
32
/
cutlass
::
sizeof_bits_v
<
S
>
;
// Maybe not Valid. ScalarConverter will not actually work unless
// NumericConverter<T, S, Round> is implemented. However it won't be used
// anyways since we assert N % 2 == 0, just here for compliance with
// VectorizedConverter.
using
ScalarConverter
=
NumericConverter
<
T
,
S
>
;
template
<
typename
PackedSrc
>
CUTLASS_DEVICE
static
uint32_t
to_reg
(
PackedSrc
const
&
source
)
{
if
constexpr
(
sizeof
(
PackedSrc
)
==
1
)
{
return
static_cast
<
uint32_t
>
(
reinterpret_cast
<
const
uint8_t
&>
(
source
));
}
else
if
constexpr
(
sizeof
(
PackedSrc
)
==
2
)
{
return
static_cast
<
uint32_t
>
(
reinterpret_cast
<
const
uint16_t
&>
(
source
));
}
else
{
static_assert
(
sizeof
(
PackedSrc
)
==
4
);
return
reinterpret_cast
<
const
uint32_t
&>
(
source
);
}
}
// The core converter uses bit tricks to construct a known FP16 number, then
// does a subtraction in FP16 for the final result.
template
<
typename
PackedResultType
,
typename
PackedSrcType
>
CUTLASS_DEVICE
static
PackedResultType
packed_convert
(
PackedSrcType
const
&
source
)
{
static_assert
(
PackedSrcType
::
kElements
==
PackedResultType
::
kElements
);
static_assert
(
PackedResultType
::
kElements
==
2
||
PackedResultType
::
kElements
==
4
||
PackedResultType
::
kElements
==
8
,
"Invalid PackedResultType must be 2, 4 or 8."
);
static_assert
(
std
::
is_same_v
<
typename
PackedSrcType
::
Element
,
S
>
);
static_assert
(
std
::
is_same_v
<
typename
PackedResultType
::
Element
,
T
>
);
return
RegConvert32bit
::
template
convert
<
PackedResultType
>(
to_reg
(
source
));
}
friend
class
detail
::
VectorizedConverter
;
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
result_type
result
;
using
ConverterType
=
ArrayConverterPacked32Bit
<
RegConvert32bit
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>
;
if
constexpr
(
src_elems_per_32bit_reg
>=
8
)
{
detail
::
VectorizedConverter
::
convert
<
ConverterType
,
result_packed_8_t
,
src_packed_8_t
,
result_packed_4_t
,
src_packed_4_t
,
result_packed_2_t
,
src_packed_2_t
>
(
result
,
source
);
}
else
if
constexpr
(
src_elems_per_32bit_reg
>=
4
)
{
detail
::
VectorizedConverter
::
convert
<
ConverterType
,
result_packed_4_t
,
src_packed_4_t
,
result_packed_2_t
,
src_packed_2_t
>
(
result
,
source
);
}
else
{
detail
::
VectorizedConverter
::
convert
<
ConverterType
,
result_packed_2_t
,
src_packed_2_t
>
(
result
,
source
);
}
return
result
;
}
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
template
<
FloatRoundStyle
Round
,
int
N
>
struct
NumericArrayConverter
<
cutlass
::
half_t
,
vllm_uint4b8_t
,
N
,
Round
>
{
using
result_type
=
Array
<
cutlass
::
half_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint4b8_t
,
N
>
;
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
// Below constructs the following temporary:
// fp16s_01 = {0x00, i4_01, 0x00, i4_01}
// fp16s_23 = {0x00, i4_23, 0x00, i4_23}
// fp16s_45 = {0x00, i4_45, 0x00, i4_45}
// fp16s_67 = {0x00, i4_67, 0x00, i4_67}
// We use inline asm instead of __byte_perm intrinsic since we don't want
// the documented (& 0x7) on the index. NVCC might be able to optimize it
// out since the index is a constexpr, but we choose to be safe about it
// here.
uint32_t
prmt_indices
[
4
]
=
{
0x4040
,
0x4141
,
0x4242
,
0x4343
};
static_assert
(
RegArray
::
kElements
<=
4
,
"Too many inputs for F16 -> I4 vector converter"
);
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
asm
volatile
(
"{
\n
"
" prmt.b32 %0, %1, %2, %3;
\n
"
"}
\n
"
:
"=r"
(
r
[
ii
])
:
"r"
(
src
),
"n"
(
0
),
"r"
(
prmt_indices
[
ii
]));
}
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
// we are trying to construct x and a fp16 value
// The below XOR does the following:
// 1) Sets the exponent bits of the FP16 to the correct value for the
// FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)},
// where x1 in the high nibble and x0 is the low nibble then using hfma
// to subtract 1032 from that
// The AND does the following:
// 1) Clear the set bits for the int4 we will ignore.
// We use lop3 so that we can use 1 instruction for AND and XOR.
static
constexpr
uint32_t
xor_mask
=
0x64006400
;
static
constexpr
uint32_t
and_mask
=
0xFFF0FF0F
;
static
constexpr
uint32_t
immLut
=
(
0xf0
&
0xcc
)
^
0xaa
;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
])
:
"n"
(
and_mask
),
"n"
(
xor_mask
),
"n"
(
immLut
));
}
// We will issue 2 hfmas that do the following:
// {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032}
// = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032}
static
constexpr
uint32_t
hfma_bias_rep
=
0xD480E408
;
// {72, 1032}
static
constexpr
uint32_t
hfma_scale_rep
=
0x2C003C00
;
// {1 / 16, 1}
const
half2
&
hfma_bias
=
reinterpret_cast
<
const
half2
&>
(
hfma_bias_rep
);
const
half2
&
hfma_scale
=
reinterpret_cast
<
const
half2
&>
(
hfma_scale_rep
);
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
]);
fp16x2_val
=
__hfma2
(
hfma_scale
,
fp16x2_val
,
hfma_bias
);
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template
<
FloatRoundStyle
Round
,
int
N
>
struct
InterleavedNumericArrayConverter
<
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
,
cutlass
::
half_t
,
vllm_uint4b8_t
,
N
,
Round
,
void
>
{
using
IlvdLayout
=
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
;
static_assert
(
N
%
size
(
IlvdLayout
{})
==
0
);
using
result_type
=
Array
<
cutlass
::
half_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint4b8_t
,
N
>
;
static
FloatRoundStyle
const
round_style
=
Round
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
static_assert
(
PackedResultType
::
kElements
<=
size
(
IlvdLayout
{}));
static
constexpr
uint32_t
xor_mask
=
0x64006400
;
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
ii
+=
2
)
{
auto
src_
=
src
>>
(
4
*
(
ii
));
r
[
ii
+
0
]
=
src_
;
r
[
ii
+
1
]
=
src_
;
static
constexpr
uint32_t
and_xor_imm_lut
=
(
0xf0
&
0xcc
)
^
0xaa
;
static
constexpr
uint32_t
low_nib_mask
=
0x000F000F
;
static
constexpr
uint32_t
high_nib_mask
=
0x00F000F0
;
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
+
0
])
:
"n"
(
low_nib_mask
),
"n"
(
xor_mask
),
"n"
(
and_xor_imm_lut
));
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
+
1
])
:
"n"
(
high_nib_mask
),
"n"
(
xor_mask
),
"n"
(
and_xor_imm_lut
));
// For low nibble:
// {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032}
// For high nibble:
// {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16}
// - {72, 72}
static
constexpr
uint32_t
low_nib_bias
=
0x64086408
;
// {1032, 1032}
static
constexpr
uint32_t
high_nib_scale
=
0x2C002C00
;
// {1/16, 1/16}
static
constexpr
uint32_t
high_nib_bias
=
0xD480D480
;
// {-72, -72}
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
+
0
]);
fp16x2_val
=
__hsub2
(
fp16x2_val
,
reinterpret_cast
<
const
half2
&>
(
low_nib_bias
));
}
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
+
1
]);
fp16x2_val
=
__hfma2
(
fp16x2_val
,
reinterpret_cast
<
const
half2
&>
(
high_nib_scale
),
reinterpret_cast
<
const
half2
&>
(
high_nib_bias
));
}
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::half_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template
<
FloatRoundStyle
Round
,
int
N
>
struct
InterleavedNumericArrayConverter
<
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
,
cutlass
::
half_t
,
uint4_t
,
N
,
Round
,
void
>
{
using
IlvdLayout
=
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
;
static_assert
(
N
%
size
(
IlvdLayout
{})
==
0
);
using
result_type
=
Array
<
cutlass
::
half_t
,
N
>
;
using
source_type
=
Array
<
uint4_t
,
N
>
;
static
FloatRoundStyle
const
round_style
=
Round
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
static_assert
(
PackedResultType
::
kElements
<=
size
(
IlvdLayout
{}));
static
constexpr
uint32_t
xor_mask
=
0x64006400
;
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
ii
+=
2
)
{
auto
src_
=
src
>>
(
4
*
(
ii
));
r
[
ii
+
0
]
=
src_
;
r
[
ii
+
1
]
=
src_
;
static
constexpr
uint32_t
and_xor_imm_lut
=
(
0xf0
&
0xcc
)
^
0xaa
;
static
constexpr
uint32_t
low_nib_mask
=
0x000F000F
;
static
constexpr
uint32_t
high_nib_mask
=
0x00F000F0
;
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
+
0
])
:
"n"
(
low_nib_mask
),
"n"
(
xor_mask
),
"n"
(
and_xor_imm_lut
));
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
+
1
])
:
"n"
(
high_nib_mask
),
"n"
(
xor_mask
),
"n"
(
and_xor_imm_lut
));
// For low nibble:
// {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024}
// For high nibble:
// {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64}
static
constexpr
uint32_t
low_nib_bias
=
0x64006400
;
// {1024, 1024}
static
constexpr
uint32_t
high_nib_scale
=
0x2C002C00
;
// {1/16, 1/16}
static
constexpr
uint32_t
high_nib_bias
=
0xD400D400
;
// {-64, -64}
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
+
0
]);
fp16x2_val
=
__hsub2
(
fp16x2_val
,
reinterpret_cast
<
const
half2
&>
(
low_nib_bias
));
}
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
+
1
]);
fp16x2_val
=
__hfma2
(
fp16x2_val
,
reinterpret_cast
<
const
half2
&>
(
high_nib_scale
),
reinterpret_cast
<
const
half2
&>
(
high_nib_bias
));
}
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint8b128_t, N>
template
<
FloatRoundStyle
Round
,
int
N
>
struct
NumericArrayConverter
<
cutlass
::
half_t
,
vllm_uint8b128_t
,
N
,
Round
>
{
using
result_type
=
Array
<
cutlass
::
half_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint8b128_t
,
N
>
;
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
// Hold output FP16s in reg. We need 1 reg for every 2 elements
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
uint32_t
const
prmt_indices
[
2
]
=
{
0x5150
,
0x5352
};
static
constexpr
uint32_t
start_byte_for_fp16
=
0x64646464
;
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
asm
volatile
(
"prmt.b32 %0,%1,%2,%3;
\n
"
:
"=r"
(
r
[
ii
])
:
"r"
(
src
),
"n"
(
start_byte_for_fp16
),
"r"
(
prmt_indices
[
ii
]));
}
// -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes
static
constexpr
uint32_t
bias_rep
=
0x64806480
;
const
half2
&
bias
=
reinterpret_cast
<
const
half2
&>
(
bias_rep
);
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
half2
&
fp16x2_val
=
reinterpret_cast
<
__half2
&>
(
r
[
ii
]);
fp16x2_val
=
__hsub2
(
fp16x2_val
,
bias
);
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::float, N> <= Array<vllm_uint8b128_t, N>
template
<
FloatRoundStyle
Round
,
int
N
>
struct
NumericArrayConverter
<
float
,
vllm_uint8b128_t
,
N
,
Round
>
{
using
result_type
=
Array
<
float
,
N
>
;
using
source_type
=
Array
<
vllm_uint8b128_t
,
N
>
;
static
FloatRoundStyle
const
round_style
=
Round
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
PackedResultType
r
;
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
// u8x4 source and stores the result in r (without introducing extra
// cvt.u32.u8 instruction)
uint32_t
const
prmt_indices
[
4
]
=
{
0x7650
,
0x7651
,
0x7652
,
0x7653
};
uint32_t
*
result_as_int
=
reinterpret_cast
<
uint32_t
*>
(
&
r
);
for
(
int
ii
=
0
;
ii
<
PackedResultType
::
kElements
;
++
ii
)
{
result_as_int
[
ii
]
=
__byte_perm
(
src
,
0x4B000000
,
prmt_indices
[
ii
]);
// Subtract the magic number 0x4B000000 from tmp in floating-point
// arithmetic to obtain final result
r
[
ii
]
-=
(
8388608.
f
+
128.
f
);
// fold in -128 bias
}
return
r
;
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
template
<
FloatRoundStyle
Round
,
int
N
>
struct
NumericArrayConverter
<
cutlass
::
bfloat16_t
,
vllm_uint4b8_t
,
N
,
Round
>
{
using
result_type
=
Array
<
cutlass
::
bfloat16_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint4b8_t
,
N
>
;
static
FloatRoundStyle
const
round_style
=
Round
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src_reg
)
{
// Hold output BF16s in reg. We need 1 reg for every 2 elements
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
uint32_t
src_reg_shifted
=
src_reg
>>
4
;
// Below constructs the following temporary:
uint32_t
const
prmt_indices
[
4
]
=
{
0xF4F0
,
0xF5F1
,
0xF6F2
,
0xF7F3
};
static_assert
(
RegArray
::
kElements
<=
4
,
"Too many inputs for uint4b8_t -> BF16 vector converter"
);
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
asm
volatile
(
"{
\n
"
" prmt.b32 %0, %1, %2, %3;
\n
"
"}
\n
"
:
"=r"
(
r
[
ii
])
:
"r"
(
src_reg
),
"r"
(
src_reg_shifted
),
"r"
(
prmt_indices
[
ii
]));
}
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
// we are trying to construct x and a BF16 value
// The below XOR does the following:
// 1) Sets the exponent bits of the BF16 to the correct value for the
// BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)}
// and subtracting 136 to get {x1, x0}
static
constexpr
uint32_t
xor_mask
=
0x43004300
;
static
constexpr
uint32_t
and_mask
=
0x000F000F
;
static
constexpr
uint32_t
immLut
=
(
0xf0
&
0xcc
)
^
0xaa
;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
])
:
"n"
(
and_mask
),
"n"
(
xor_mask
),
"n"
(
immLut
));
}
// We will issue 2 bfmas that do the following:
// high BF16:
// hi_bf16 - 136, lo_bf16 - 136
// This is the BF16 {136, 136} represented as an integer.
static
constexpr
uint32_t
bias_rep
=
0x43084308
;
const
__nv_bfloat162
&
bias
=
reinterpret_cast
<
const
__nv_bfloat162
&>
(
bias_rep
);
CUTLASS_PRAGMA_UNROLL
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
__nv_bfloat162
&
bf16x2_val
=
reinterpret_cast
<
__nv_bfloat162
&>
(
r
[
ii
]);
bf16x2_val
=
__hsub2
(
bf16x2_val
,
bias
);
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
}
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template
<
FloatRoundStyle
Round
,
int
N
>
struct
InterleavedNumericArrayConverter
<
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
,
cutlass
::
bfloat16_t
,
vllm_uint4b8_t
,
N
,
Round
,
void
>
{
using
IlvdLayout
=
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
;
static_assert
(
N
%
size
(
IlvdLayout
{})
==
0
);
using
result_type
=
Array
<
cutlass
::
bfloat16_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint4b8_t
,
N
>
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
static_assert
(
PackedResultType
::
kElements
<=
size
(
IlvdLayout
{}));
static
constexpr
uint32_t
or_mask
=
0x43004300
;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
r
[
ii
]
=
src
>>
(
4
*
ii
);
static
constexpr
uint32_t
and_or_imm_lut
=
(
0xf0
&
0xcc
)
|
0xaa
;
static
constexpr
uint32_t
low_nib_mask
=
0x000F000F
;
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
+
0
])
:
"n"
(
low_nib_mask
),
"n"
(
or_mask
),
"n"
(
and_or_imm_lut
));
// For low nibble:
// {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136}
static
constexpr
uint32_t
low_nib_bias
=
0x43084308
;
// {136, 136}
{
__nv_bfloat162
&
fp16x2_val
=
reinterpret_cast
<
__nv_bfloat162
&>
(
r
[
ii
]);
fp16x2_val
=
__hsub2
(
fp16x2_val
,
reinterpret_cast
<
const
__nv_bfloat162
&>
(
low_nib_bias
));
}
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template
<
FloatRoundStyle
Round
,
int
N
>
struct
InterleavedNumericArrayConverter
<
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
,
cutlass
::
bfloat16_t
,
uint4_t
,
N
,
Round
,
void
>
{
using
IlvdLayout
=
Layout
<
Shape
<
_2
,
_4
>
,
Stride
<
_4
,
_1
>>
;
static_assert
(
N
%
size
(
IlvdLayout
{})
==
0
);
using
result_type
=
Array
<
cutlass
::
bfloat16_t
,
N
>
;
using
source_type
=
Array
<
uint4_t
,
N
>
;
private:
struct
RegConvert
{
template
<
typename
PackedResultType
>
CUTLASS_DEVICE
static
PackedResultType
convert
(
uint32_t
src
)
{
using
RegArray
=
cutlass
::
AlignedArray
<
uint32_t
,
PackedResultType
::
kElements
/
2
,
sizeof
(
PackedResultType
)
>
;
RegArray
r
;
static_assert
(
PackedResultType
::
kElements
<=
size
(
IlvdLayout
{}));
static
constexpr
uint32_t
or_mask
=
0x43004300
;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for
(
int
ii
=
0
;
ii
<
RegArray
::
kElements
;
++
ii
)
{
r
[
ii
]
=
src
>>
(
4
*
ii
);
static
constexpr
uint32_t
and_or_imm_lut
=
(
0xf0
&
0xcc
)
|
0xaa
;
static
constexpr
uint32_t
low_nib_mask
=
0x000F000F
;
asm
volatile
(
"{
\n
"
" lop3.b32 %0, %0, %1, %2, %3;
\n
"
"}
\n
"
:
"+r"
(
r
[
ii
])
:
"n"
(
low_nib_mask
),
"n"
(
or_mask
),
"n"
(
and_or_imm_lut
));
// For low nibble:
// {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128}
static
constexpr
uint32_t
low_nib_bias
=
0x43004300
;
// {128, 128}
{
__nv_bfloat162
&
fp16x2_val
=
reinterpret_cast
<
__nv_bfloat162
&>
(
r
[
ii
]);
fp16x2_val
=
__hsub2
(
fp16x2_val
,
reinterpret_cast
<
const
__nv_bfloat162
&>
(
low_nib_bias
));
}
}
return
reinterpret_cast
<
PackedResultType
&>
(
r
);
};
};
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
return
ArrayConverterPacked32Bit
<
RegConvert
,
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
>::
convert
(
source
);
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint8b128_t, N>
template
<
FloatRoundStyle
Round
,
int
N
>
struct
NumericArrayConverter
<
cutlass
::
bfloat16_t
,
vllm_uint8b128_t
,
N
,
Round
>
{
using
result_type
=
Array
<
cutlass
::
bfloat16_t
,
N
>
;
using
source_type
=
Array
<
vllm_uint8b128_t
,
N
>
;
static
FloatRoundStyle
const
round_style
=
Round
;
private:
using
result_packed_4_t
=
Array
<
cutlass
::
bfloat16_t
,
4
>
;
using
result_packed_2_t
=
Array
<
cutlass
::
bfloat16_t
,
2
>
;
using
src_packed_4_t
=
Array
<
vllm_uint8b128_t
,
4
>
;
using
src_packed_2_t
=
Array
<
vllm_uint8b128_t
,
2
>
;
// Not Valid, not supported, only here to satisfy the interface and to avoid
// a compile error. ScalarConverter will not actually work until
// NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round> is
// implemented
using
ScalarConverter
=
NumericConverter
<
cutlass
::
bfloat16_t
,
vllm_uint8b128_t
,
Round
>
;
template
<
typename
PackedResultType
,
typename
PackedSrcType
>
CUTLASS_DEVICE
static
PackedResultType
packed_convert
(
PackedSrcType
const
&
source
)
{
static_assert
(
(
platform
::
is_same
<
PackedSrcType
,
src_packed_2_t
>::
value
&&
platform
::
is_same
<
PackedResultType
,
result_packed_2_t
>::
value
)
||
(
platform
::
is_same
<
PackedSrcType
,
src_packed_4_t
>::
value
&&
platform
::
is_same
<
PackedResultType
,
result_packed_4_t
>::
value
),
"Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
"convert dispatch."
);
NumericArrayConverter
<
float
,
vllm_uint8b128_t
,
PackedResultType
::
kElements
,
Round
>
convert_uint8_to_f32
;
Array
<
float
,
PackedResultType
::
kElements
>
tmp
=
convert_uint8_to_f32
(
source
);
NumericArrayConverter
<
cutlass
::
bfloat16_t
,
float
,
PackedResultType
::
kElements
,
Round
>
convert_f32_to_bf16_
;
return
convert_f32_to_bf16_
(
tmp
);
}
friend
class
detail
::
VectorizedConverter
;
public:
CUTLASS_DEVICE
static
result_type
convert
(
source_type
const
&
source
)
{
result_type
result
;
using
ConverterType
=
NumericArrayConverter
<
typename
result_type
::
Element
,
typename
source_type
::
Element
,
N
,
Round
>
;
detail
::
VectorizedConverter
::
convert
<
ConverterType
,
result_packed_4_t
,
src_packed_4_t
,
result_packed_2_t
,
src_packed_2_t
>
(
result
,
source
);
return
result
;
}
CUTLASS_DEVICE
result_type
operator
()(
source_type
const
&
s
)
const
{
return
convert
(
s
);
}
};
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
}
// namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
csrc/layernorm_kernels.cu
View file @
af7f4372
...
...
@@ -3,13 +3,16 @@
#include <c10/cuda/CUDAGuard.h>
#include "dispatch_utils.h"
#include "reduction_utils.cuh"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
...
...
@@ -31,7 +34,11 @@ __global__ void rms_norm_kernel(
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
variance
+=
x
*
x
;
}
variance
=
blockReduceSum
<
float
>
(
variance
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
...
...
@@ -228,12 +235,11 @@ fused_add_rms_norm_kernel(
variance
+=
temp
.
sum_squares
();
residual_v
[
id
]
=
temp
;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
...
...
@@ -268,12 +274,11 @@ fused_add_rms_norm_kernel(
variance
+=
x
*
x
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
...
...
csrc/ops.h
View file @
af7f4372
...
...
@@ -105,12 +105,12 @@ void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
);
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
std
::
vector
<
int64_t
>
&
codebook_partition_sizes
);
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
...
...
@@ -125,6 +125,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
namespace
machete
{
std
::
vector
<
std
::
string
>
supported_schedules
(
vllm
::
ScalarTypeTorchPtr
const
&
btype
);
torch
::
Tensor
gemm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
vllm
::
ScalarTypeTorchPtr
const
&
btype
,
c10
::
optional
<
torch
::
Tensor
>
const
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
zeros
,
c10
::
optional
<
int64_t
>
group_size
,
c10
::
optional
<
torch
::
Tensor
>
const
&
C
,
c10
::
optional
<
double
>
alpha
,
c10
::
optional
<
double
>
beta
,
c10
::
optional
<
std
::
string
>
schedule
);
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
vllm
::
ScalarTypeTorchPtr
const
&
btype
);
};
// namespace machete
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
...
...
@@ -149,6 +168,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
int64_t
n
);
torch
::
Tensor
ggml_mul_mat_vec_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
...
...
@@ -161,6 +189,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
torch
::
Tensor
marlin_qqq_gemm
(
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b_q_weight
,
torch
::
Tensor
const
&
s_tok
,
...
...
csrc/opt/layernorm_kernels_opt.cu
View file @
af7f4372
...
...
@@ -6,13 +6,17 @@
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include "../dispatch_utils.h"
#include "../reduction_utils.cuh"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
...
...
@@ -34,7 +38,11 @@ __global__ void rms_norm_kernel(
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
variance
+=
x
*
x
;
}
variance
=
blockReduceSum
<
float
>
(
variance
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
...
...
@@ -231,12 +239,11 @@ fused_add_rms_norm_kernel(
variance
+=
temp
.
sum_squares
();
residual_v
[
id
]
=
temp
;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
...
...
@@ -271,12 +278,11 @@ fused_add_rms_norm_kernel(
variance
+=
x
*
x
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
...
...
Prev
1
2
3
4
5
6
7
…
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment