Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d29c39ca
Commit
d29c39ca
authored
Apr 30, 2026
by
chenzk
Browse files
vllm kvprune wo:v1.1.0
parent
f81ce56b
Changes
246
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2439 additions
and
0 deletions
+2439
-0
vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
.../matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
+119
-0
vllm/kvprune_legacy_save/triton_kernels/numerics.py
vllm/kvprune_legacy_save/triton_kernels/numerics.py
+42
-0
vllm/kvprune_legacy_save/triton_kernels/numerics_details/__init__.py
...e_legacy_save/triton_kernels/numerics_details/__init__.py
+0
-0
vllm/kvprune_legacy_save/triton_kernels/numerics_details/flexpoint.py
..._legacy_save/triton_kernels/numerics_details/flexpoint.py
+204
-0
vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp.py
...prune_legacy_save/triton_kernels/numerics_details/mxfp.py
+303
-0
vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/__init__.py
.../triton_kernels/numerics_details/mxfp_details/__init__.py
+0
-0
vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
...ernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
+158
-0
vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
...ernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
+125
-0
vllm/kvprune_legacy_save/triton_kernels/proton_opts.py
vllm/kvprune_legacy_save/triton_kernels/proton_opts.py
+19
-0
vllm/kvprune_legacy_save/triton_kernels/reduction_details/__init__.py
..._legacy_save/triton_kernels/reduction_details/__init__.py
+0
-0
vllm/kvprune_legacy_save/triton_kernels/reduction_details/reduce_bitmatrix.py
...save/triton_kernels/reduction_details/reduce_bitmatrix.py
+133
-0
vllm/kvprune_legacy_save/triton_kernels/routing.py
vllm/kvprune_legacy_save/triton_kernels/routing.py
+521
-0
vllm/kvprune_legacy_save/triton_kernels/routing_details/__init__.py
...ne_legacy_save/triton_kernels/routing_details/__init__.py
+0
-0
vllm/kvprune_legacy_save/triton_kernels/routing_details/_expt_data.py
..._legacy_save/triton_kernels/routing_details/_expt_data.py
+75
-0
vllm/kvprune_legacy_save/triton_kernels/routing_details/_routing_compute.py
...y_save/triton_kernels/routing_details/_routing_compute.py
+241
-0
vllm/kvprune_legacy_save/triton_kernels/specialize.py
vllm/kvprune_legacy_save/triton_kernels/specialize.py
+143
-0
vllm/kvprune_legacy_save/triton_kernels/swiglu.py
vllm/kvprune_legacy_save/triton_kernels/swiglu.py
+99
-0
vllm/kvprune_legacy_save/triton_kernels/swiglu_details/__init__.py
...une_legacy_save/triton_kernels/swiglu_details/__init__.py
+0
-0
vllm/kvprune_legacy_save/triton_kernels/swiglu_details/_swiglu.py
...rune_legacy_save/triton_kernels/swiglu_details/_swiglu.py
+141
-0
vllm/kvprune_legacy_save/triton_kernels/target_info.py
vllm/kvprune_legacy_save/triton_kernels/target_info.py
+116
-0
No files found.
vllm/kvprune_legacy_save/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
0 → 100644
View file @
d29c39ca
import
torch
import
triton
from
vllm.kvprune.triton_kernels
import
target_info
from
vllm.kvprune.triton_kernels.tensor
import
get_layout
,
bitwidth
,
FP4
from
vllm.kvprune.triton_kernels.tensor_details.layout
import
HopperMXScaleLayout
from
vllm.kvprune.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp
import
(
MXFP_BLOCK_SIZE
,
)
def
compute_grid_size
(
routing_data
,
m
,
n
,
block_m
,
block_n
):
if
routing_data
is
not
None
:
grid_m
=
routing_data
.
n_blocks
(
m
,
block_m
)
else
:
grid_m
=
triton
.
cdiv
(
m
,
block_m
)
grid_n
=
(
n
+
block_n
-
1
)
//
block_n
return
grid_m
*
grid_n
def
compute_block_n
(
n
:
int
,
arch
,
precision_config
):
# block_n:
layout
=
get_layout
(
precision_config
.
weight_scale
)
if
isinstance
(
layout
,
HopperMXScaleLayout
)
and
layout
.
num_warps
==
4
:
return
128
elif
precision_config
.
max_num_imprecise_acc
is
None
and
n
>
128
:
return
256
else
:
return
max
(
16
,
min
(
128
,
triton
.
next_power_of_2
(
n
)))
def
compute_block_k
(
m
:
int
,
k
:
int
|
None
,
is_persistent
:
bool
,
lhs_dtype
,
rhs_dtype
,
precision_config
):
lhs_width
=
bitwidth
(
lhs_dtype
)
rhs_width
=
bitwidth
(
rhs_dtype
)
# block_k needs to match the cacheline size (1024 bits)
block_k
=
int
(
1024
//
min
(
lhs_width
,
rhs_width
))
has_native_mxfp
=
target_info
.
cuda_capability_geq
(
10
,
0
)
if
rhs_width
==
4
and
not
has_native_mxfp
:
block_k
=
128
elif
k
is
not
None
:
block_k
=
max
(
32
,
min
(
triton
.
next_power_of_2
(
k
),
block_k
))
has_mx_weight_scale
=
(
precision_config
is
not
None
and
precision_config
.
weight_scale
is
not
None
)
if
has_native_mxfp
and
is_persistent
and
has_mx_weight_scale
:
block_k
=
min
(
block_k
,
128
)
return
block_k
def
compute_split_k
(
block_k
:
int
,
k
:
int
|
None
,
grid_size
:
int
)
->
int
:
device_props
=
torch
.
cuda
.
get_device_properties
(
0
)
n_sms
=
device_props
.
multi_processor_count
split_k
=
n_sms
//
grid_size
if
k
is
not
None
:
# avoid split_k for small k
num_block_k
=
triton
.
cdiv
(
k
,
block_k
)
split_k
=
min
(
split_k
,
num_block_k
//
4
)
split_k
=
max
(
split_k
,
1
)
return
split_k
def
compute_num_warps
(
block_m
,
block_n
,
precision_config
):
layout
=
get_layout
(
precision_config
.
weight_scale
)
if
isinstance
(
layout
,
HopperMXScaleLayout
):
return
layout
.
num_warps
return
max
(
block_m
*
block_n
//
4096
,
4
)
def
compute_num_stages
(
precision_config
,
is_persistent
,
block_m
,
block_n
,
block_k
,
out_dtype
,
lhs_dtype
,
rhs_dtype
,
epilogue_subtile
,
epilogue_effective_itemsize
,
):
if
precision_config
.
max_num_imprecise_acc
is
not
None
:
return
3
weight_size
=
bitwidth
(
rhs_dtype
)
/
8
stage_size
=
(
block_m
*
block_k
*
lhs_dtype
.
itemsize
+
block_k
*
block_n
*
weight_size
)
device_props
=
torch
.
cuda
.
get_device_properties
(
0
)
smem_capacity
=
device_props
.
shared_memory_per_block_optin
has_native_mxfp
=
target_info
.
cuda_capability_geq
(
10
,
0
)
if
has_native_mxfp
and
getattr
(
precision_config
,
"weight_scale"
,
None
)
is
not
None
:
if
rhs_dtype
==
FP4
:
# 4-bit e2m1 weights are padded 2x
# https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
stage_size
+=
block_k
*
block_n
*
weight_size
if
is_persistent
:
# Per-stage wait barrier
stage_size
+=
8
if
target_info
.
cuda_capability_geq
(
10
,
0
):
acc_size
=
epilogue_effective_itemsize
or
out_dtype
.
itemsize
else
:
acc_size
=
out_dtype
.
itemsize
if
target_info
.
cuda_capability_geq
(
10
,
0
)
and
epilogue_subtile
is
not
None
:
acc_block_n
=
block_n
//
epilogue_subtile
else
:
acc_block_n
=
block_n
# pipelined TMA store local to global, or
# pipelined layout conversion before store of the accumulator
# note: layout conversion has some padding
smem_capacity
-=
int
((
block_m
+
4
)
*
acc_block_n
*
acc_size
)
if
precision_config
.
weight_scale
is
not
None
:
# mx scales
stage_size
+=
block_n
*
(
block_k
//
int
(
MXFP_BLOCK_SIZE
))
elif
has_native_mxfp
:
# mx scales
stage_size
+=
block_n
*
(
block_k
//
int
(
MXFP_BLOCK_SIZE
))
num_stages
=
min
(
4
,
smem_capacity
//
int
(
stage_size
))
return
num_stages
vllm/kvprune_legacy_save/triton_kernels/numerics.py
0 → 100644
View file @
d29c39ca
import
torch
from
dataclasses
import
dataclass
MAX_FINITE_FLOAT8E5
=
57344.0
MAX_FINITE_FLOAT8E4NV
=
448.0
MAX_FINITE_FLOAT8E4B8
=
240.0
@
dataclass
(
frozen
=
True
)
class
BaseFlexData
:
dtype
:
torch
.
dtype
|
None
=
None
def
view
(
self
,
x
:
torch
.
Tensor
):
if
self
.
dtype
is
None
:
return
x
return
x
.
view
(
self
.
dtype
)
def
reinterpret
(
self
,
x
):
if
self
.
dtype
is
None
or
x
.
dtype
.
itemsize
>
1
:
return
x
return
x
.
view
(
self
.
dtype
)
@
dataclass
(
frozen
=
True
)
class
InFlexData
(
BaseFlexData
):
scale
:
torch
.
Tensor
|
None
=
None
@
property
def
is_per_batch
(
self
):
return
False
if
self
.
scale
is
None
else
len
(
self
.
scale
)
>
1
@
dataclass
(
frozen
=
True
)
class
OutFlexData
(
BaseFlexData
):
expected_scale
:
torch
.
Tensor
|
None
=
None
actual_scale
:
torch
.
Tensor
|
None
=
None
checksum_scale
:
torch
.
Tensor
|
None
=
None
def
__iter__
(
self
):
yield
self
.
expected_scale
yield
self
.
actual_scale
yield
self
.
checksum_scale
vllm/kvprune_legacy_save/triton_kernels/numerics_details/__init__.py
0 → 100644
View file @
d29c39ca
vllm/kvprune_legacy_save/triton_kernels/numerics_details/flexpoint.py
0 → 100644
View file @
d29c39ca
from
..numerics
import
MAX_FINITE_FLOAT8E4B8
,
MAX_FINITE_FLOAT8E4NV
,
MAX_FINITE_FLOAT8E5
import
triton
import
triton.language
as
tl
from
vllm.kvprune.triton_kernels.target_info
import
cuda_capability_geq
# -------------------------------
# Kernels stuff
# -------------------------------
TL_MAX_FINITE_FLOAT8E5
=
tl
.
constexpr
(
MAX_FINITE_FLOAT8E5
)
TL_MAX_FINITE_FLOAT8E4NV
=
tl
.
constexpr
(
MAX_FINITE_FLOAT8E4NV
)
TL_MAX_FINITE_FLOAT8E4B8
=
tl
.
constexpr
(
MAX_FINITE_FLOAT8E4B8
)
TL_MAX_FINITE_FLOAT8E4B15
=
tl
.
constexpr
(
1.750
)
TL_MAX_FINITE_FLOAT16
=
tl
.
constexpr
(
65472.0
)
TL_RCP_MAX_FINITE_FLOAT8E5
=
tl
.
constexpr
(
0x37924925
)
# 0x1.24924Ap-16
TL_RCP_MAX_FINITE_FLOAT8E4NV
=
tl
.
constexpr
(
0x3B124925
)
# 0x1.24924Ap-9
TL_RCP_MAX_FINITE_FLOAT8E4B8
=
tl
.
constexpr
(
0x3B888889
)
# 0x1.111112p-8
TL_RCP_MAX_FINITE_FLOAT8E4B15
=
tl
.
constexpr
(
0x3F124925
)
# 0x1.24924Ap-1
TL_RCP_MAX_FINITE_FLOAT16
=
tl
.
constexpr
(
0x37802008
)
# 0x1.004010p-16
@
triton
.
jit
def
max_finite
(
dtype
):
if
dtype
==
tl
.
constexpr
(
tl
.
float8e5
):
return
TL_MAX_FINITE_FLOAT8E5
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4nv
):
return
TL_MAX_FINITE_FLOAT8E4NV
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4b8
):
return
TL_MAX_FINITE_FLOAT8E4B8
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4b15
):
return
TL_MAX_FINITE_FLOAT8E4B15
elif
dtype
==
tl
.
constexpr
(
tl
.
float16
):
return
TL_MAX_FINITE_FLOAT16
else
:
tl
.
static_assert
(
tl
.
constexpr
(
False
),
f
"
{
dtype
}
not supported in flexpoint"
)
@
triton
.
jit
def
rcp_max_finite
(
dtype
):
if
dtype
==
tl
.
constexpr
(
tl
.
float8e5
):
return
TL_RCP_MAX_FINITE_FLOAT8E5
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4nv
):
return
TL_RCP_MAX_FINITE_FLOAT8E4NV
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4b8
):
return
TL_RCP_MAX_FINITE_FLOAT8E4B8
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4b15
):
return
TL_RCP_MAX_FINITE_FLOAT8E4B15
elif
dtype
==
tl
.
constexpr
(
tl
.
float16
):
return
TL_RCP_MAX_FINITE_FLOAT16
else
:
tl
.
static_assert
(
tl
.
constexpr
(
False
),
f
"
{
dtype
}
not supported in flexpoint"
)
@
triton
.
jit
def
sm86_min_nan_xorsign_abs_f32
(
a
,
b
):
"""Wrapper for min.NaN.xorsign.abs.f32 PTX instruction.
Computes the minimum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
NaN inputs are propagated to the output.
Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
"""
tl
.
static_assert
(
cuda_capability_geq
(
8
,
6
),
"min.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+"
,
)
tl
.
static_assert
(
a
.
dtype
==
tl
.
float32
,
"min.NaN.xorsign.abs.f32 requires float32 inputs"
)
tl
.
static_assert
(
b
.
dtype
==
tl
.
float32
,
"min.NaN.xorsign.abs.f32 requires float32 inputs"
)
return
tl
.
inline_asm_elementwise
(
"""{
min.NaN.xorsign.abs.f32 $0, $1, $2;
}"""
,
"=r,r,r"
,
[
a
,
b
],
dtype
=
tl
.
float32
,
is_pure
=
True
,
pack
=
1
,
)
@
triton
.
jit
def
sm86_max_nan_xorsign_abs_f32
(
a
,
b
):
"""Wrapper for max.NaN.xorsign.abs.f32 PTX instruction.
Computes the maximum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
NaN inputs are propagated to the output.
Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
"""
tl
.
static_assert
(
cuda_capability_geq
(
8
,
6
),
"max.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+"
,
)
tl
.
static_assert
(
a
.
dtype
==
tl
.
float32
,
"max.NaN.xorsign.abs.f32 requires float32 inputs"
)
tl
.
static_assert
(
b
.
dtype
==
tl
.
float32
,
"max.NaN.xorsign.abs.f32 requires float32 inputs"
)
return
tl
.
inline_asm_elementwise
(
"""{
max.NaN.xorsign.abs.f32 $0, $1, $2;
}"""
,
"=r,r,r"
,
[
a
,
b
],
dtype
=
tl
.
float32
,
is_pure
=
True
,
pack
=
1
,
)
@
triton
.
jit
def
load_scale
(
scale_ptr
):
return
1.0
if
scale_ptr
is
None
else
tl
.
load
(
scale_ptr
)
@
triton
.
jit
def
flex_to_float
(
x
,
scale_ptr
):
scale
=
load_scale
(
scale_ptr
)
return
x
.
to
(
tl
.
float32
)
*
scale
@
triton
.
jit
def
clip
(
x
,
limit
):
res
=
tl
.
minimum
(
x
,
limit
)
res
=
tl
.
maximum
(
-
limit
,
res
)
return
res
@
triton
.
jit
def
nan_propagating_absmax_reduce
(
x
,
axis
=
None
):
if
cuda_capability_geq
(
8
,
6
):
# abs-max-reduce as floating-point if `max.NaN.xorsign.abs.f32` is supported.
x_absmax
=
tl
.
reduce
(
x
,
axis
,
sm86_max_nan_xorsign_abs_f32
)
# Note: sign of reduction result is the xor of signs of all inputs, explicitly clear the sign bit to fix it.
x_absmax
=
x_absmax
.
to
(
tl
.
uint32
,
bitcast
=
True
)
&
0x7FFFFFFF
else
:
# Clear the sign bit, max-reduce as integer (same as NaN-propagating max-reduce as float)
masked_abs_x
=
x
.
to
(
tl
.
uint32
,
bitcast
=
True
)
&
0x7FFFFFFF
x_absmax
=
tl
.
max
(
masked_abs_x
,
axis
)
return
x_absmax
@
triton
.
jit
def
compute_scale
(
x
,
Out
):
x_absmax
=
nan_propagating_absmax_reduce
(
tl
.
ravel
(
x
,
can_reorder
=
True
))
# atomic_max does not propagate NaNs, so we replace them with +inf (0x7f800000).
# We use integer minimum because NaNs are above +inf in integer representation.
x_absmax
=
tl
.
minimum
(
x_absmax
,
0x7F800000
).
to
(
tl
.
float32
,
bitcast
=
True
)
RCP_MAX_VALUE
=
rcp_max_finite
(
Out
.
dtype
.
element_ty
)
return
tl
.
fma
(
x_absmax
,
RCP_MAX_VALUE
.
to
(
tl
.
float32
,
bitcast
=
True
),
1.0e-30
)
@
triton
.
jit
def
update_scale
(
x
,
scale_ptr
,
Out
)
->
None
:
if
scale_ptr
is
not
None
:
scale
=
compute_scale
(
x
,
Out
)
tl
.
atomic_max
(
scale_ptr
,
scale
,
sem
=
"relaxed"
)
@
triton
.
jit
def
float_to_flex
(
x
,
expected_scale_ptr_or_val
,
actual_scale_ptr
,
checksum_scale_ptr
,
mask
,
Out
,
saturate_infs
:
tl
.
constexpr
,
):
if
expected_scale_ptr_or_val
is
not
None
:
if
expected_scale_ptr_or_val
.
dtype
.
is_ptr
():
invscale
=
1.0
/
tl
.
load
(
expected_scale_ptr_or_val
)
else
:
invscale
=
1.0
/
expected_scale_ptr_or_val
else
:
invscale
=
1.0
if
checksum_scale_ptr
is
not
None
:
x_int32
=
x
.
to
(
tl
.
int32
,
bitcast
=
True
)
zero
=
tl
.
cast
(
0.0
,
tl
.
int32
)
if
mask
is
not
None
:
x_int32
=
tl
.
where
(
mask
,
x_int32
,
zero
)
checksum_local
=
tl
.
xor_sum
(
tl
.
ravel
(
x_int32
,
can_reorder
=
True
),
0
)
tl
.
atomic_add
(
checksum_scale_ptr
,
checksum_local
)
if
mask
is
not
None
:
if
actual_scale_ptr
is
not
None
:
x
=
tl
.
where
(
mask
,
x
,
0.0
)
update_scale
(
x
,
actual_scale_ptr
,
Out
)
x
=
x
*
invscale
# if expected_scale_ptr is not None, we applied flexpoint scale. We only want to clip in this case.
if
expected_scale_ptr_or_val
is
not
None
:
if
saturate_infs
:
CLIP_VALUE
=
max_finite
(
Out
.
dtype
.
element_ty
)
x
=
clip
(
x
,
CLIP_VALUE
)
return
x
vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp.py
0 → 100644
View file @
d29c39ca
# isort: off
# fmt: off
from
enum
import
Enum
import
triton
import
torch
import
torch.nn.functional
as
F
from
.mxfp_details._upcast_from_mxfp
import
_upcast_from_mxfp
from
.mxfp_details._downcast_to_mxfp
import
_downcast_to_mxfp
,
MXFP_BLOCK_SIZE
,
_quantize_mxfp8_fn
# -----------------------------------------------------------------------------
# Dequantization / Quantization Utilities
# -----------------------------------------------------------------------------
class
DequantScaleRoundingMode
(
Enum
):
ROUND_UP
=
0
ROUND_DOWN
=
1
def
downcast_to_mxfp
(
src_tensor
:
torch
.
Tensor
,
out_quant_type
:
torch
.
dtype
,
axis
:
int
,
DEQUANT_SCALE_ROUNDING_MODE
:
DequantScaleRoundingMode
=
DequantScaleRoundingMode
.
ROUND_UP
):
"""
Convert the src weights to mx format. The src weight is quantized along the axis dimension.
If weight_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte.
Note that this means the k_dim of the tensor will be half of the logical k_dim.
If weight_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8s are stored
in their respective formats.
"""
ndim
=
src_tensor
.
ndim
assert
-
ndim
<=
axis
<
ndim
,
f
"Invalid axis
{
axis
=
}
"
axis
=
axis
if
axis
>=
0
else
axis
+
ndim
# downcast
src_tensor
=
src_tensor
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
)
is_fp4
=
out_quant_type
==
torch
.
uint8
is_fp8
=
out_quant_type
in
(
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
)
assert
is_fp4
or
is_fp8
divisor
=
2
if
is_fp4
else
1
L
=
src_tensor
.
shape
[
-
1
]
if
is_fp4
:
assert
L
%
2
==
0
,
f
"axis dim must be divisible by 2 for e2m1. Got
{
L
}
"
out_shape
=
src_tensor
.
shape
[:
-
1
]
+
(
L
//
divisor
,
)
out_scale_shape
=
src_tensor
.
shape
[:
-
1
]
+
(
triton
.
cdiv
(
L
,
MXFP_BLOCK_SIZE
),
)
out_quant_tensor
=
src_tensor
.
new_empty
(
out_shape
,
dtype
=
out_quant_type
)
out_scale
=
src_tensor
.
new_empty
(
out_scale_shape
,
dtype
=
torch
.
uint8
)
if
src_tensor
.
numel
()
>
0
:
kernel_src_tensor
=
src_tensor
.
reshape
(
-
1
,
src_tensor
.
shape
[
-
1
])
kernel_quant_tensor
=
out_quant_tensor
.
view
(
-
1
,
out_quant_tensor
.
shape
[
-
1
])
kernel_scale
=
out_scale
.
view
(
-
1
,
out_scale
.
shape
[
-
1
])
BLOCK_OUT_DIM
=
128
BLOCK_QUANT_DIM
=
MXFP_BLOCK_SIZE
.
value
grid_out
=
triton
.
cdiv
(
kernel_src_tensor
.
shape
[
0
],
BLOCK_OUT_DIM
)
grid_quant
=
triton
.
cdiv
(
kernel_src_tensor
.
shape
[
1
],
BLOCK_QUANT_DIM
)
_downcast_to_mxfp
[(
grid_out
,
grid_quant
)](
kernel_quant_tensor
,
*
kernel_quant_tensor
.
stride
(),
kernel_scale
,
*
kernel_scale
.
stride
(),
kernel_src_tensor
,
*
kernel_src_tensor
.
stride
(),
*
kernel_src_tensor
.
shape
,
BLOCK_OUT_DIM
,
BLOCK_QUANT_DIM
,
DEQUANT_SCALE_ROUNDING_MODE
.
value
,
num_warps
=
8
)
out_quant_tensor
=
out_quant_tensor
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
)
out_scale
=
out_scale
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
)
return
out_quant_tensor
,
out_scale
def
upcast_from_mxfp
(
tensor
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
target_dtype
:
torch
.
dtype
,
axis
:
int
):
"""
Upcasts an mxfp (packed) weight tensor back to float16 or bfloat16.
The function assumes that the tensors were quantized along the given axis.
It permutes the tensor so that the quantized axis is last, reshapes to 2D,
launches the Triton upcast kernel, and then unpermutes back to the original order.
"""
ndim
=
tensor
.
ndim
assert
-
ndim
<=
axis
<
ndim
,
f
"Invalid axis
{
axis
=
}
"
axis
=
axis
if
axis
>=
0
else
axis
+
ndim
assert
tensor
.
ndim
==
scale
.
ndim
,
(
f
"Weight and scale must have the same number of dimensions. "
f
"Got
{
tensor
.
ndim
=
}
and
{
scale
.
ndim
=
}
"
)
# dtype checks
assert
tensor
.
dtype
in
{
torch
.
uint8
,
torch
.
float8_e5m2
,
torch
.
float8_e4m3fn
},
\
f
"Invalid tensor dtype
{
tensor
.
dtype
=
}
"
assert
scale
.
dtype
==
torch
.
uint8
,
f
"Invalid scale dtype
{
scale
.
dtype
=
}
"
assert
target_dtype
in
(
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
),
f
"Invalid output dtype
{
target_dtype
=
}
"
# upcast
logical_quant_dim
=
tensor
.
shape
[
axis
]
*
(
2
if
tensor
.
dtype
==
torch
.
uint8
else
1
)
tensor
=
tensor
.
transpose
(
axis
,
tensor
.
ndim
-
1
).
contiguous
()
scale
=
scale
.
transpose
(
axis
,
scale
.
ndim
-
1
).
contiguous
()
out
=
torch
.
empty
((
*
tensor
.
shape
[:
-
1
],
logical_quant_dim
),
dtype
=
target_dtype
,
device
=
tensor
.
device
)
reshaped_out
=
out
.
view
(
-
1
,
out
.
shape
[
-
1
])
reshaped_tensor
=
tensor
.
view
(
-
1
,
tensor
.
shape
[
-
1
])
reshaped_scale
=
scale
.
view
(
-
1
,
scale
.
shape
[
-
1
])
BLOCK_OUT_DIM
=
128
BLOCK_QUANT_DIM
=
MXFP_BLOCK_SIZE
.
value
blocks_out_dim
=
triton
.
cdiv
(
reshaped_out
.
shape
[
0
],
BLOCK_OUT_DIM
)
blocks_quant_dim
=
triton
.
cdiv
(
reshaped_out
.
shape
[
1
],
BLOCK_QUANT_DIM
)
_upcast_from_mxfp
[(
blocks_out_dim
,
blocks_quant_dim
)](
reshaped_out
,
*
reshaped_out
.
stride
(),
reshaped_scale
,
*
reshaped_scale
.
stride
(),
reshaped_tensor
,
*
reshaped_tensor
.
stride
(),
*
reshaped_out
.
shape
,
BLOCK_OUT_DIM
,
BLOCK_QUANT_DIM
,
num_warps
=
8
)
out
=
out
.
transpose
(
axis
,
scale
.
ndim
-
1
).
contiguous
()
return
out
# ------------
def
right_shift_unsigned
(
x
,
shift
):
# CUDA torch does not support bit ops on uint32, so we need to mask to get unsigned right shift
return
(
x
>>
shift
)
&
((
1
<<
(
32
-
shift
))
-
1
)
def
get_max_quant_val
(
dtype
:
torch
.
dtype
):
d
=
{
torch
.
uint8
:
6.0
,
torch
.
float8_e5m2
:
57344.0
,
torch
.
float8_e4m3fn
:
448.0
}
assert
dtype
in
d
return
d
[
dtype
]
def
downcast_to_mxfp_torch
(
src_tensor
:
torch
.
Tensor
,
out_quant_type
:
torch
.
dtype
,
axis
:
int
,
DEQUANT_SCALE_ROUNDING_MODE
:
DequantScaleRoundingMode
=
DequantScaleRoundingMode
.
ROUND_UP
):
"""
Converts the src tensor to the output format specified by out_quant_type.
axis: The axis along which the tensors are contiguous and quantization is applied.
DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN.
Returns:
out_quant_tensor: Quantized tensor in mx format.
• For mxfp8, the output has the same shape as src_tensor.
• For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
where L is the original length along that axis.
"""
# This should probably be packed into its own tiny class
ndim
=
src_tensor
.
ndim
assert
-
ndim
<=
axis
<
ndim
,
f
"Invalid axis
{
axis
=
}
"
assert
src_tensor
.
dtype
in
{
torch
.
float32
,
torch
.
bfloat16
,
torch
.
float16
},
f
"Invalid input tensor dtype
{
src_tensor
.
dtype
}
"
axis
=
axis
if
axis
>=
0
else
axis
+
ndim
is_fp4
=
out_quant_type
==
torch
.
uint8
is_fp8
=
"float8"
in
str
(
out_quant_type
)
assert
is_fp4
or
is_fp8
,
f
"Invalid input tensor dtype
{
out_quant_type
}
"
device
=
src_tensor
.
device
# For mxfp4 conversion, we assume the contiguous axis length is even.
if
is_fp4
:
axis_shape
=
src_tensor
.
size
(
axis
)
assert
axis_shape
%
2
==
0
,
"For mxfp4 conversion the contiguous axis length must be even."
# Permute the tensor so that the contiguous axis becomes the last dimension.
src
=
src_tensor
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
).
to
(
torch
.
float32
)
axis_shape
=
src
.
shape
[
-
1
]
# Pad the axis to be divisible by 32, in case it is not.
next_multiple
=
triton
.
cdiv
(
axis_shape
,
MXFP_BLOCK_SIZE
)
*
MXFP_BLOCK_SIZE
pad_amount
=
next_multiple
-
axis_shape
padded_src
=
F
.
pad
(
src
,
(
0
,
pad_amount
))
valid_mask
=
F
.
pad
(
torch
.
ones_like
(
src
,
dtype
=
torch
.
bool
),
(
0
,
pad_amount
))
padded_axis_shape
=
padded_src
.
size
(
-
1
)
# now divisible by 32
# --- Compute per-group maximums for scale ---
# Set padded entries to -1 so they don’t affect the max.
abs_f
=
torch
.
abs
(
padded_src
)
abs_f
=
torch
.
where
(
valid_mask
,
abs_f
,
torch
.
tensor
(
-
1.0
,
device
=
device
,
dtype
=
padded_src
.
dtype
))
# Reshape the last dimension into groups of 32.
new_shape
=
padded_src
.
shape
[:
-
1
]
+
(
padded_axis_shape
//
MXFP_BLOCK_SIZE
,
MXFP_BLOCK_SIZE
)
abs_groups
=
abs_f
.
view
(
*
new_shape
)
# Compute maximum along the group dimension (of size 32).
max_val
,
_
=
abs_groups
.
max
(
dim
=-
1
,
keepdim
=
True
)
# Choose a max quantization value depending on type.
max_quant_val
=
get_max_quant_val
(
out_quant_type
)
dequant_scale
=
max_val
/
max_quant_val
# shape: (..., padded_axis_shape//32, 1)
# Convert to int to round the FP32 scale, prior to quantization!
ds_int
=
dequant_scale
.
view
(
torch
.
int32
)
if
DEQUANT_SCALE_ROUNDING_MODE
==
DequantScaleRoundingMode
.
ROUND_UP
:
ds_int_rounded
=
(
ds_int
+
0x007FFFFF
)
&
0x7F800000
else
:
ds_int_rounded
=
ds_int
&
0x7F800000
# Reinterpret back as float32.
dequant_scale_rounded
=
ds_int_rounded
.
view
(
torch
.
float32
)
# Compute the quantization scale.
quant_scale
=
torch
.
where
(
dequant_scale_rounded
==
0
,
torch
.
tensor
(
0.0
,
device
=
device
),
1.0
/
dequant_scale_rounded
)
# Quantize the tensor
orig_padded_shape
=
padded_src
.
shape
padded_src_groups
=
padded_src
.
view
(
*
new_shape
)
quant_tensor
=
padded_src_groups
*
quant_scale
# Reshape back to the original shape and trim padding
quant_tensor
=
quant_tensor
.
view
(
orig_padded_shape
)
quant_tensor
=
quant_tensor
[...,
:
axis_shape
]
# Finally, convert the quantized tensor to the target format
if
is_fp8
:
# Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
quant_tensor
=
torch
.
clamp
(
quant_tensor
,
-
max_quant_val
,
max_quant_val
)
out_weight
=
quant_tensor
.
to
(
out_quant_type
)
else
:
assert
is_fp4
,
f
"Invalid output quantization type
{
out_quant_type
}
"
# For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
# First, reinterpret the quantized tensor bits.
q_int
=
quant_tensor
.
contiguous
().
view
(
torch
.
int32
)
# Extract sign, exponent, and mantissa.
signs
=
q_int
&
0x80000000
exponents
=
right_shift_unsigned
(
q_int
,
23
)
&
0xFF
mantissas
=
q_int
&
0x7FFFFF
E8_BIAS
=
127
E2_BIAS
=
1
# Adjust mantissas for subnormals.
mantissas
=
torch
.
where
(
exponents
<
E8_BIAS
,
(
0x400000
|
right_shift_unsigned
(
mantissas
,
1
))
>>
(
E8_BIAS
-
exponents
-
1
),
mantissas
)
exponents
=
torch
.
maximum
(
exponents
,
torch
.
tensor
(
E8_BIAS
-
E2_BIAS
,
device
=
device
))
-
(
E8_BIAS
-
E2_BIAS
)
e2m1_tmp
=
right_shift_unsigned
(((
exponents
<<
2
)
|
right_shift_unsigned
(
mantissas
,
21
))
+
1
,
1
)
e2m1_tmp
=
torch
.
minimum
(
e2m1_tmp
,
torch
.
tensor
(
0x7
,
device
=
device
))
e2m1_value
=
(
right_shift_unsigned
(
signs
,
28
)
|
e2m1_tmp
).
to
(
torch
.
uint8
)
# shape: (..., even_axis_shape)
# Pack pairs of 4-bit values along the last dimension.
e2m1_value
=
e2m1_value
.
view
(
*
e2m1_value
.
shape
[:
-
1
],
axis_shape
//
2
,
2
)
evens
=
e2m1_value
[...,
0
]
odds
=
e2m1_value
[...,
1
]
out_weight
=
evens
|
(
odds
<<
4
)
# shape: (..., axis_shape//2)
# --- Process and output the scale ---
dq_scale
=
(
ds_int_rounded
.
view
(
*
dequant_scale
.
shape
)
>>
23
).
to
(
torch
.
uint8
)
# shape: (..., axis_shape//32, 1)
dq_scale
=
dq_scale
.
squeeze
(
-
1
)
out_weight
=
out_weight
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
)
dq_scale
=
dq_scale
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
)
return
out_weight
,
dq_scale
def
cvt_e2m1_to_fp32
(
input_tensor
):
assert
input_tensor
.
dtype
==
torch
.
uint8
input_tensor
=
input_tensor
.
to
(
torch
.
int32
)
evens
=
input_tensor
&
0xF
odds
=
(
input_tensor
>>
4
)
&
0xF
vals
=
[
0.0
,
0.5
,
1
,
1.5
,
2
,
3
,
4
,
6
]
outputs
=
torch
.
tensor
(
vals
,
dtype
=
torch
.
float32
,
device
=
input_tensor
.
device
)
outputs
=
torch
.
cat
([
outputs
,
-
outputs
])
even_floats
=
outputs
[
evens
]
odd_floats
=
outputs
[
odds
]
output_tensor
=
torch
.
stack
([
even_floats
,
odd_floats
],
dim
=-
1
)
output_tensor
=
output_tensor
.
view
(
*
input_tensor
.
shape
[:
-
1
],
-
1
)
return
output_tensor
def
upcast_from_mxfp_torch
(
tensor
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
target_dtype
:
torch
.
dtype
,
axis
:
int
):
"""
Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype.
axis: The axis along which dequantization is applied.
Returns:
out_weight: Tensor in the target format.
"""
ndim
=
tensor
.
ndim
assert
-
ndim
<=
axis
<
ndim
,
f
"Invalid axis
{
axis
=
}
"
is_fp8
=
tensor
.
dtype
==
torch
.
float8_e4m3fn
or
tensor
.
dtype
==
torch
.
float8_e5m2
assert
is_fp8
or
tensor
.
dtype
==
torch
.
uint8
,
f
"Invalid input quantization type
{
tensor
.
dtype
}
"
# Permute the tensor and scale so that the quantization axis becomes the last dimension
axis
=
axis
if
axis
>=
0
else
axis
+
ndim
scale
=
scale
.
transpose
(
axis
,
scale
.
ndim
-
1
)
tensor
=
tensor
.
transpose
(
axis
,
tensor
.
ndim
-
1
)
dq_scale
=
(
scale
.
to
(
torch
.
int32
)
<<
23
).
view
(
torch
.
float32
)
# Shift to the exponent and bitcast to fp32
if
tensor
.
dtype
==
torch
.
uint8
:
fp32_tensor
=
cvt_e2m1_to_fp32
(
tensor
)
else
:
fp32_tensor
=
tensor
.
to
(
torch
.
float32
)
logical_quant_dim
=
tensor
.
shape
[
-
1
]
*
(
2
if
tensor
.
dtype
==
torch
.
uint8
else
1
)
axis_shape
=
fp32_tensor
.
size
(
-
1
)
padded_axis_shape
=
triton
.
cdiv
(
logical_quant_dim
,
MXFP_BLOCK_SIZE
)
*
MXFP_BLOCK_SIZE
pad_size
=
padded_axis_shape
-
axis_shape
padded_tensor
=
F
.
pad
(
fp32_tensor
,
(
0
,
pad_size
))
new_axis_shape
=
padded_tensor
.
shape
[
-
1
]
new_shape
=
padded_tensor
.
shape
[:
-
1
]
+
(
new_axis_shape
//
MXFP_BLOCK_SIZE
,
MXFP_BLOCK_SIZE
)
padded_tensor
=
padded_tensor
.
view
(
*
new_shape
)
dq_scale_padded
=
dq_scale
.
unsqueeze
(
-
1
)
# shape: [..., ceil(axis_shape/32), 1]
out_padded
=
padded_tensor
*
dq_scale_padded
# Flatten back and remove the padded tail
out_padded
=
out_padded
.
view
(
*
fp32_tensor
.
shape
[:
-
1
],
new_axis_shape
)
out_tensor
=
out_padded
[...,
:
axis_shape
]
out_tensor
=
out_tensor
.
to
(
target_dtype
).
contiguous
()
out_tensor
=
out_tensor
.
transpose
(
axis
,
tensor
.
ndim
-
1
)
return
out_tensor
quantize_mxfp8_fn
=
_quantize_mxfp8_fn
vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/__init__.py
0 → 100644
View file @
d29c39ca
vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
0 → 100644
View file @
d29c39ca
import
triton
import
triton.language
as
tl
# fmt: off
MXFP_BLOCK_SIZE
=
tl
.
constexpr
(
32
)
@
triton
.
jit
def
_get_max_quant_val
(
dtype
:
tl
.
constexpr
):
if
dtype
==
tl
.
uint8
:
return
6.0
elif
dtype
==
tl
.
float8e5
:
return
57344.0
elif
dtype
==
tl
.
float8e4nv
:
return
448.0
else
:
tl
.
static_assert
(
False
,
f
"Invalid
{
dtype
=
}
"
)
@
triton
.
jit
def
_compute_quant_and_scale
(
src_tensor
,
valid_src_mask
,
mx_tensor_dtype
:
tl
.
constexpr
,
DEQUANT_SCALE_ROUNDING_MODE
:
tl
.
constexpr
=
0
):
is_fp8
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
BLOCK_SIZE_OUT_DIM
:
tl
.
constexpr
=
src_tensor
.
shape
[
0
]
BLOCK_SIZE_QUANT_DIM
:
tl
.
constexpr
=
src_tensor
.
shape
[
1
]
BLOCK_SIZE_QUANT_MX_SCALE
:
tl
.
constexpr
=
src_tensor
.
shape
[
1
]
//
MXFP_BLOCK_SIZE
# Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
f32_tensor
=
src_tensor
.
to
(
tl
.
float32
)
abs_tensor
=
tl
.
abs
(
f32_tensor
)
abs_tensor
=
tl
.
where
(
valid_src_mask
,
abs_tensor
,
-
1.0
)
# Don't consider padding tensors in scale computation
abs_tensor
=
tl
.
reshape
(
abs_tensor
,
[
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
MXFP_BLOCK_SIZE
])
max_val
=
tl
.
max
(
abs_tensor
,
axis
=
2
,
keep_dims
=
True
)
dequant_scale
=
max_val
/
_get_max_quant_val
(
mx_tensor_dtype
)
if
DEQUANT_SCALE_ROUNDING_MODE
==
0
:
# DequantScaleRoundingMode.ROUND_UP
# compute 2 ** ceil(log2(dequant_scale))
# Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
# A corner case: exponent is 0xFF that will overflow but that's already
# NaN so assume we don't care.
dequant_scale_exponent
=
(
dequant_scale
.
to
(
tl
.
uint32
,
bitcast
=
True
)
+
0x007FFFFF
)
&
0x7F800000
else
:
# DequantScaleRoundingMode.ROUND_DOWN
# compute 2 ** floor(log2(dequant_scale))
assert
DEQUANT_SCALE_ROUNDING_MODE
==
1
dequant_scale_exponent
=
dequant_scale
.
to
(
tl
.
uint32
,
bitcast
=
True
)
&
0x7F800000
dequant_scale_rounded
=
dequant_scale_exponent
.
to
(
tl
.
float32
,
bitcast
=
True
)
quant_scale
=
tl
.
where
(
dequant_scale_rounded
==
0
,
0
,
1.0
/
dequant_scale_rounded
)
f32_tensor
=
tl
.
reshape
(
f32_tensor
,
[
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
MXFP_BLOCK_SIZE
])
quant_tensor
=
f32_tensor
*
quant_scale
# Reshape the tensors after scaling
quant_tensor
=
quant_tensor
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_DIM
])
# Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
quant_tensor
=
tl
.
where
(
valid_src_mask
,
quant_tensor
,
0
)
dequant_scale_exponent
=
dequant_scale_exponent
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
])
# First, we simply extract the exponent part of the scales and store the result
dequant_scale_exponent
=
(
dequant_scale_exponent
>>
23
).
to
(
tl
.
uint8
)
# Now we must convert the tensors to the mx format.
if
is_fp8
:
out_tensor
=
quant_tensor
.
to
(
mx_tensor_dtype
)
else
:
quant_tensor
=
quant_tensor
.
to
(
tl
.
uint32
,
bitcast
=
True
)
signs
=
quant_tensor
&
0x80000000
exponents
=
(
quant_tensor
>>
23
)
&
0xFF
mantissas
=
(
quant_tensor
&
0x7FFFFF
)
# 0.25 <= x < 0.75 maps to 0.5, a denormal number
E8_BIAS
=
127
E2_BIAS
=
1
# Move implicit bit 1 at the beginning to mantissa for denormals
adjusted_exponents
=
tl
.
core
.
sub
(
E8_BIAS
,
exponents
+
1
,
sanitize_overflow
=
False
)
mantissas
=
tl
.
where
(
exponents
<
E8_BIAS
,
(
0x400000
|
(
mantissas
>>
1
))
>>
adjusted_exponents
,
mantissas
)
# For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
exponents
=
tl
.
maximum
(
exponents
,
E8_BIAS
-
E2_BIAS
)
-
(
E8_BIAS
-
E2_BIAS
)
# Combine sign, exponent, and mantissa, while saturating
# rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
e2m1_tmp
=
tl
.
minimum
((((
exponents
<<
2
)
|
(
mantissas
>>
21
))
+
1
)
>>
1
,
0x7
)
e2m1_value
=
((
signs
>>
28
)
|
e2m1_tmp
).
to
(
tl
.
uint8
)
e2m1_value
=
tl
.
reshape
(
e2m1_value
,
[
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_DIM
//
2
,
2
])
evens
,
odds
=
tl
.
split
(
e2m1_value
)
out_tensor
=
evens
|
(
odds
<<
4
)
return
out_tensor
,
dequant_scale_exponent
@
triton
.
jit
def
_downcast_to_mxfp
(
mx_tensor_ptr
,
stride_mxt_outer
,
stride_mxt_quant
:
tl
.
constexpr
,
mx_scale_ptr
,
stride_mx_scale_outer
,
stride_mx_scale_quant
,
src_ptr
,
stride_src_outer
,
stride_src_quant
,
outer_dim
,
quant_dim
,
BLOCK_SIZE_OUT_DIM
:
tl
.
constexpr
,
BLOCK_SIZE_QUANT_DIM
:
tl
.
constexpr
,
DEQUANT_SCALE_ROUNDING_MODE
:
tl
.
constexpr
):
tl
.
static_assert
(
stride_mxt_quant
==
1
,
f
"Output stride,
{
stride_mxt_quant
=
}
must be 1."
)
tl
.
static_assert
(
BLOCK_SIZE_QUANT_DIM
%
MXFP_BLOCK_SIZE
==
0
,
f
"
{
BLOCK_SIZE_QUANT_DIM
=
}
must be a multiple of 32"
)
# uint8 signifies two fp4 e2m1 values packed into a single byte
mx_tensor_dtype
:
tl
.
constexpr
=
mx_tensor_ptr
.
dtype
.
element_ty
tl
.
static_assert
(
mx_tensor_dtype
==
tl
.
uint8
or
(
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
),
f
"Invalid
{
mx_tensor_dtype
=
}
. Must be uint8 or float8."
)
src_dtype
:
tl
.
constexpr
=
src_ptr
.
dtype
.
element_ty
tl
.
static_assert
(
mx_scale_ptr
.
dtype
.
element_ty
==
tl
.
uint8
,
f
"
{
mx_scale_ptr
.
dtype
.
element_ty
=
}
must be uint8"
)
tl
.
static_assert
((
src_dtype
==
tl
.
bfloat16
)
or
(
src_dtype
==
tl
.
float16
)
or
(
src_dtype
==
tl
.
float32
),
f
"
{
src_dtype
=
}
must be bfloat16 or float16 or float32"
)
is_fp4
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
uint8
outer_block
=
tl
.
program_id
(
0
).
to
(
tl
.
int64
)
quant_block
=
tl
.
program_id
(
1
).
to
(
tl
.
int64
)
K_DIVISOR
:
tl
.
constexpr
=
2
if
is_fp4
else
1
BLOCK_SIZE_QUANT_MX_SCALE
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
MXFP_BLOCK_SIZE
BLOCK_SIZE_QUANT_MX_TENSOR
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
K_DIVISOR
start_src_quant
=
quant_block
*
BLOCK_SIZE_QUANT_DIM
start_mx_scale_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_SCALE
start_mx_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_TENSOR
start_out
=
outer_block
*
BLOCK_SIZE_OUT_DIM
src_ptr
+=
start_src_quant
*
stride_src_quant
+
start_out
*
stride_src_outer
mx_scale_ptr
+=
start_mx_scale_quant
*
stride_mx_scale_quant
+
start_out
*
stride_mx_scale_outer
mx_tensor_ptr
+=
start_mx_quant
*
stride_mxt_quant
+
start_out
*
stride_mxt_outer
offs_src_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_DIM
)[
None
,
:].
to
(
tl
.
int64
)
offs_mxt_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_TENSOR
)[
None
,
:].
to
(
tl
.
int64
)
offs_scale_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_SCALE
)[
None
,
:].
to
(
tl
.
int64
)
offs_outer
=
tl
.
arange
(
0
,
BLOCK_SIZE_OUT_DIM
)[:,
None
].
to
(
tl
.
int64
)
mask_src_quant
=
start_src_quant
+
offs_src_quant
<
quant_dim
mask_n
=
start_out
+
offs_outer
<
outer_dim
full_mask_src
=
mask_src_quant
&
mask_n
mask_mxt_quant
=
start_mx_quant
+
offs_mxt_quant
<
tl
.
cdiv
(
quant_dim
,
K_DIVISOR
)
full_mask_mxt
=
mask_mxt_quant
&
mask_n
scale_mask_k
=
start_mx_scale_quant
+
offs_scale_quant
<
tl
.
cdiv
(
quant_dim
,
MXFP_BLOCK_SIZE
)
full_scale_mask
=
scale_mask_k
&
mask_n
src_tensor_offsets
=
offs_src_quant
*
stride_src_quant
+
offs_outer
*
stride_src_outer
mx_scale_offsets
=
offs_scale_quant
*
stride_mx_scale_quant
+
offs_outer
*
stride_mx_scale_outer
mx_tensor_offsets
=
offs_mxt_quant
*
stride_mxt_quant
+
offs_outer
*
stride_mxt_outer
src_tensor
=
tl
.
load
(
src_ptr
+
src_tensor_offsets
,
mask
=
full_mask_src
)
out_tensor
,
scale_tensor
=
_compute_quant_and_scale
(
src_tensor
,
full_mask_src
,
mx_tensor_dtype
,
DEQUANT_SCALE_ROUNDING_MODE
)
tl
.
store
(
mx_scale_ptr
+
mx_scale_offsets
,
scale_tensor
,
mask
=
full_scale_mask
)
tl
.
store
(
mx_tensor_ptr
+
mx_tensor_offsets
,
out_tensor
,
mask
=
full_mask_mxt
)
@
triton
.
jit
(
repr
=
lambda
_
:
"_dequantize_mxfp8"
)
def
_quantize_mxfp8_fn
(
input
,
mask
,
pid
=
None
):
return
_compute_quant_and_scale
(
input
,
mask
,
tl
.
float8e4nv
)
vllm/kvprune_legacy_save/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
0 → 100644
View file @
d29c39ca
import
triton
import
triton.language
as
tl
from
._downcast_to_mxfp
import
MXFP_BLOCK_SIZE
# fmt: off
@
triton
.
jit
def
_upcast_from_mxfp
(
out_ptr
,
stride_o_outer
,
stride_o_quant
:
tl
.
constexpr
,
mx_scale_ptr
,
stride_scale_outer
,
stride_scale_quant
,
mx_tensor_ptr
,
stride_tensor_outer
,
stride_tensor_quant
:
tl
.
constexpr
,
outer_dim
,
quant_dim
,
BLOCK_SIZE_OUT_DIM
:
tl
.
constexpr
,
BLOCK_SIZE_QUANT_DIM
:
tl
.
constexpr
):
tl
.
static_assert
(
stride_o_quant
==
1
,
"the weight must be contiguous in the k dimension for mx"
)
tl
.
static_assert
(
BLOCK_SIZE_QUANT_DIM
%
MXFP_BLOCK_SIZE
==
0
,
"BLOCK_SIZE_K must be a multiple of 32"
)
# uint8 signifies two fp4 e2m1 values packed into a single byte
mx_tensor_dtype
:
tl
.
constexpr
=
mx_tensor_ptr
.
dtype
.
element_ty
dst_dtype
:
tl
.
constexpr
=
out_ptr
.
dtype
.
element_ty
tl
.
static_assert
(
dst_dtype
==
tl
.
float16
or
dst_dtype
==
tl
.
bfloat16
or
dst_dtype
==
tl
.
float32
)
tl
.
static_assert
(
mx_tensor_dtype
==
tl
.
uint8
or
((
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
)
or
mx_tensor_dtype
==
dst_dtype
),
"mx_tensor_ptr must be uint8 or float8 or dst_dtype"
)
tl
.
static_assert
(
mx_scale_ptr
.
dtype
.
element_ty
==
tl
.
uint8
,
"mx_scale_ptr must be uint8"
)
# Determine if we are dealing with fp8 types.
is_fp4
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
uint8
is_fp8
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
K_DIVISOR
:
tl
.
constexpr
=
2
if
is_fp4
else
1
BLOCK_SIZE_QUANT_MX_SCALE
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
MXFP_BLOCK_SIZE
BLOCK_SIZE_QUANT_MX_TENSOR
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
K_DIVISOR
# Compute starting indices for the quantized (packed) dimension and the outer dimension.
outer_block
=
tl
.
program_id
(
0
).
to
(
tl
.
int64
)
quant_block
=
tl
.
program_id
(
1
).
to
(
tl
.
int64
)
start_mxt_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_TENSOR
start_out_quant
=
quant_block
*
BLOCK_SIZE_QUANT_DIM
start_mx_scale_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_SCALE
start_out
=
outer_block
*
BLOCK_SIZE_OUT_DIM
mx_tensor_ptr
+=
start_mxt_quant
*
stride_tensor_quant
+
start_out
*
stride_tensor_outer
mx_scale_ptr
+=
start_mx_scale_quant
*
stride_scale_quant
+
start_out
*
stride_scale_outer
out_ptr
+=
start_out
*
stride_o_outer
+
start_out_quant
*
stride_o_quant
# Compute offsets and masks.
offs_src_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_TENSOR
)[
None
,
:].
to
(
tl
.
int64
)
offs_out_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_DIM
)[
None
,
:].
to
(
tl
.
int64
)
offs_outer
=
tl
.
arange
(
0
,
BLOCK_SIZE_OUT_DIM
)[:,
None
].
to
(
tl
.
int64
)
offs_scale
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_SCALE
)[
None
,
:].
to
(
tl
.
int64
)
mask_outer
=
start_out
+
offs_outer
<
outer_dim
mask_out_quant
=
start_out_quant
+
offs_out_quant
<
quant_dim
full_mask_out
=
mask_out_quant
&
mask_outer
mask_src_quant
=
start_mxt_quant
+
offs_src_quant
<
tl
.
cdiv
(
quant_dim
,
K_DIVISOR
)
full_mask_src
=
mask_src_quant
&
mask_outer
mask_scale
=
start_mx_scale_quant
+
offs_scale
<
tl
.
cdiv
(
quant_dim
,
MXFP_BLOCK_SIZE
)
full_scale_mask
=
mask_scale
&
mask_outer
tensor_offsets
=
offs_src_quant
*
stride_tensor_quant
+
offs_outer
*
stride_tensor_outer
scale_offsets
=
offs_scale
*
stride_scale_quant
+
offs_outer
*
stride_scale_outer
out_offsets
=
offs_out_quant
*
stride_o_quant
+
offs_outer
*
stride_o_outer
# Load the packed tensor and scale.
tensor
=
tl
.
load
(
mx_tensor_ptr
+
tensor_offsets
,
mask
=
full_mask_src
)
scale
=
tl
.
load
(
mx_scale_ptr
+
scale_offsets
,
mask
=
full_scale_mask
)
# Upcast the scale to the destination type.
if
dst_dtype
==
tl
.
bfloat16
:
dst_scale
=
(
scale
.
to
(
tl
.
uint16
)
<<
7
).
to
(
dst_dtype
,
bitcast
=
True
)
else
:
dst_scale
=
(
scale
.
to
(
tl
.
uint32
)
<<
23
).
to
(
tl
.
float32
,
bitcast
=
True
)
if
dst_dtype
==
tl
.
float16
:
dst_scale
=
dst_scale
.
to
(
tl
.
float16
)
# Now upcast the tensor.
intermediate_dtype
:
tl
.
constexpr
=
tl
.
bfloat16
if
dst_dtype
==
tl
.
float32
else
dst_dtype
if
is_fp8
:
dst_tensor
=
tensor
.
to
(
intermediate_dtype
)
if
tensor
.
dtype
==
tl
.
float8e5
:
from_e_bits
:
tl
.
constexpr
=
5
from_m_bits
:
tl
.
constexpr
=
2
to_e_bits
:
tl
.
constexpr
=
8
if
intermediate_dtype
==
tl
.
bfloat16
else
5
to_m_bits
:
tl
.
constexpr
=
7
if
intermediate_dtype
==
tl
.
bfloat16
else
10
# Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
non_finite_mask_src
:
tl
.
constexpr
=
((
1
<<
from_e_bits
)
-
1
)
<<
from_m_bits
non_finite_mask_dst
:
tl
.
constexpr
=
((
1
<<
to_e_bits
)
-
1
)
<<
to_m_bits
dst_tensor
=
tl
.
where
(
(
tensor
.
to
(
tl
.
uint8
,
bitcast
=
True
)
&
non_finite_mask_src
)
==
non_finite_mask_src
,
(
dst_tensor
.
to
(
tl
.
uint16
,
bitcast
=
True
)
|
non_finite_mask_dst
).
to
(
intermediate_dtype
,
bitcast
=
True
),
dst_tensor
,
)
else
:
assert
is_fp4
dst_bias
:
tl
.
constexpr
=
127
if
intermediate_dtype
==
tl
.
bfloat16
else
15
dst_0p5
:
tl
.
constexpr
=
16128
if
intermediate_dtype
==
tl
.
bfloat16
else
0x3800
dst_m_bits
:
tl
.
constexpr
=
7
if
intermediate_dtype
==
tl
.
bfloat16
else
10
# e2m1
em0
=
tensor
&
0x07
em1
=
tensor
&
0x70
x0
=
(
em0
.
to
(
tl
.
uint16
)
<<
(
dst_m_bits
-
1
))
|
((
tensor
&
0x08
).
to
(
tl
.
uint16
)
<<
12
)
x1
=
(
em1
.
to
(
tl
.
uint16
)
<<
(
dst_m_bits
-
5
))
|
((
tensor
&
0x80
).
to
(
tl
.
uint16
)
<<
8
)
# Three cases:
# 1) x is normal and non-zero: Correct bias
x0
=
tl
.
where
((
em0
&
0x06
)
!=
0
,
x0
+
((
dst_bias
-
1
)
<<
dst_m_bits
),
x0
)
x1
=
tl
.
where
((
em1
&
0x60
)
!=
0
,
x1
+
((
dst_bias
-
1
)
<<
dst_m_bits
),
x1
)
# 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
x0
=
tl
.
where
(
em0
==
0x01
,
dst_0p5
|
(
x0
&
0x8000
),
x0
)
x1
=
tl
.
where
(
em1
==
0x10
,
dst_0p5
|
(
x1
&
0x8000
),
x1
)
# 3) x is zero, do nothing
dst_tensor
=
tl
.
interleave
(
x0
,
x1
).
to
(
intermediate_dtype
,
bitcast
=
True
)
dst_tensor
=
dst_tensor
.
to
(
dst_dtype
)
# Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping.
dst_tensor
=
dst_tensor
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
MXFP_BLOCK_SIZE
])
dst_scale
=
dst_scale
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
1
])
scale
=
scale
.
reshape
(
dst_scale
.
shape
)
out_tensor
=
dst_tensor
*
dst_scale
# Correct any NaNs encoded via the scale.
out_tensor
=
tl
.
where
(
scale
==
0xFF
,
float
(
"nan"
),
out_tensor
)
out_tensor
=
out_tensor
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_DIM
])
tl
.
store
(
out_ptr
+
out_offsets
,
out_tensor
,
mask
=
full_mask_out
)
vllm/kvprune_legacy_save/triton_kernels/proton_opts.py
0 → 100644
View file @
d29c39ca
# proton options
import
os
_launch_metadata_allow_sync
=
None
def
launch_metadata_allow_sync
():
global
_launch_metadata_allow_sync
if
_launch_metadata_allow_sync
is
None
:
_launch_metadata_allow_sync
=
not
(
os
.
getenv
(
"PROTON_LAUNCH_METADATA_NOSYNC"
)
==
"1"
)
return
_launch_metadata_allow_sync
def
set_launch_metadata_allow_sync
(
allow_sync
:
bool
):
global
_launch_metadata_allow_sync
_launch_metadata_allow_sync
=
allow_sync
vllm/kvprune_legacy_save/triton_kernels/reduction_details/__init__.py
0 → 100644
View file @
d29c39ca
vllm/kvprune_legacy_save/triton_kernels/reduction_details/reduce_bitmatrix.py
0 → 100644
View file @
d29c39ca
import
torch
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
vpopc
(
x
):
"""
Vertical popcount
Input x : uint32[..., N]
Output y : uint32[..., 32]
semantics : y[..., i] = sum_j((x[..., j] >> i) & 1)
credits: @apgoucher
"""
tl
.
static_assert
(
x
.
dtype
==
tl
.
uint32
,
"x should consist of 32-bit unsigned integers"
)
BLOCK_N
:
tl
.
constexpr
=
x
.
shape
[
-
1
]
# summation axis
BATCHES
:
tl
.
constexpr
=
x
.
numel
//
BLOCK_N
# number of batches
if
BLOCK_N
>=
8
:
sa1
:
tl
.
constexpr
=
8
else
:
sa1
:
tl
.
constexpr
=
BLOCK_N
# create 8-way sums in 4-bit fields:
y
=
tl
.
reshape
(
x
,
[
BATCHES
,
BLOCK_N
//
sa1
,
sa1
,
1
])
y
=
(
y
>>
tl
.
arange
(
0
,
4
)[
None
,
None
,
None
,
:])
&
0x11111111
y
=
tl
.
sum
(
y
,
2
)
# [BATCHES, BLOCK_N // sa1, 4]
if
BLOCK_N
>=
128
:
sa2
:
tl
.
constexpr
=
16
else
:
sa2
:
tl
.
constexpr
=
BLOCK_N
//
sa1
# create 128-way sums in 8-bit fields:
y
=
tl
.
reshape
(
y
,
[
BATCHES
,
BLOCK_N
//
(
sa1
*
sa2
),
sa2
,
1
,
4
])
y
=
(
y
>>
(
4
*
tl
.
arange
(
0
,
2
))[
None
,
None
,
None
,
:,
None
])
&
0x0F0F0F0F
y
=
tl
.
sum
(
y
,
2
)
# [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4]
sa3
:
tl
.
constexpr
=
BLOCK_N
//
(
sa1
*
sa2
)
# create N-way sums in 32-bit fields:
y
=
tl
.
reshape
(
y
,
[
BATCHES
,
1
,
sa3
,
8
])
y
=
(
y
>>
(
8
*
tl
.
arange
(
0
,
4
))[
None
,
:,
None
,
None
])
&
0x000000FF
y
=
tl
.
sum
(
y
,
2
)
# [BATCHES, 4, 8]
y
=
tl
.
reshape
(
y
,
x
.
shape
[:
-
1
]
+
[
32
])
return
y
@
triton
.
jit
def
_sum_bitmatrix_memset
(
Ret
,
BLOCK
:
tl
.
constexpr
):
pid
=
tl
.
program_id
(
0
)
offs
=
pid
*
BLOCK
+
tl
.
arange
(
0
,
BLOCK
)
tl
.
store
(
Ret
+
offs
,
0
)
@
triton
.
jit
def
_sum_bitmatrix_rows
(
B
,
shape_bm
,
stride_bm
:
tl
.
constexpr
,
stride_bn
:
tl
.
constexpr
,
# input bitmatrix
Ret
,
Partials
,
stride_pm
:
tl
.
constexpr
,
stride_pn
,
shape_pn
,
# outputs
BLOCK_MM
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
):
tl
.
static_assert
(
BLOCK_MM
%
BLOCK_M
==
0
)
TILE_SIZE
:
tl
.
constexpr
=
BLOCK_MM
//
BLOCK_M
if
isinstance
(
shape_bm
,
tl
.
tensor
)
and
shape_bm
.
dtype
.
is_ptr
():
shape_bm
=
tl
.
load
(
shape_bm
)
pid_m
=
tl
.
program_id
(
0
)
pid_n
=
tl
.
program_id
(
1
)
offs_m
=
pid_m
*
BLOCK_MM
+
tl
.
arange
(
0
,
BLOCK_MM
)
offs_n
=
pid_n
*
32
+
tl
.
arange
(
0
,
32
)
n_rows
=
shape_bm
bits
=
tl
.
load
(
B
+
pid_n
*
stride_bn
+
offs_m
*
stride_bm
,
mask
=
offs_m
<
n_rows
,
other
=
0
)
bits
=
tl
.
reshape
(
bits
,
[
TILE_SIZE
,
BLOCK_M
])
ret
=
vpopc
(
bits
)
# [TILE_SIZE, 32]
offs_t
=
pid_m
*
TILE_SIZE
+
tl
.
arange
(
0
,
TILE_SIZE
)
tl
.
atomic_add
(
Ret
+
offs_n
,
tl
.
sum
(
ret
,
0
),
sem
=
"relaxed"
)
tl
.
store
(
Partials
+
offs_t
[:,
None
]
*
stride_pm
+
offs_n
[
None
,
:]
*
stride_pn
,
ret
)
def
clear_sums
(
n_cols
,
device
,
MEMSET_BLOCK
=
512
):
cdiv
=
triton
.
cdiv
blocks
=
cdiv
(
n_cols
,
MEMSET_BLOCK
)
out_ret
=
torch
.
empty
((
blocks
*
MEMSET_BLOCK
,),
device
=
device
,
dtype
=
torch
.
int32
)
_sum_bitmatrix_memset
[(
blocks
,)](
out_ret
,
MEMSET_BLOCK
)
return
out_ret
def
sum_bitmatrix_rows
(
x
,
out_ret
,
partials_block_size
=
None
):
assert
partials_block_size
is
not
None
cdiv
=
triton
.
cdiv
PARTIALS_BLOCK_M
=
partials_block_size
n_rows
,
n_cols
=
x
.
shape
n_rows_max
=
x
.
shape_max
[
0
]
assert
out_ret
.
shape
==
(
n_cols
,)
TILE_SIZE
=
max
(
1
,
128
//
PARTIALS_BLOCK_M
)
BLOCK_MM
=
PARTIALS_BLOCK_M
*
TILE_SIZE
pids_x
=
cdiv
(
n_rows_max
,
BLOCK_MM
)
pids_y
=
cdiv
(
n_cols
,
32
)
out_partials
=
torch
.
empty
(
(
pids_y
*
32
,
pids_x
*
TILE_SIZE
),
device
=
out_ret
.
device
,
dtype
=
torch
.
int32
)
out_partials
=
torch
.
transpose
(
out_partials
,
0
,
1
)
# output tensors
_sum_bitmatrix_rows
[(
pids_x
,
pids_y
)](
x
.
storage
.
data
,
n_rows
,
x
.
stride
(
0
),
x
.
stride
(
1
),
# input
out_ret
,
# output [final reduction]
out_partials
,
out_partials
.
stride
(
0
),
out_partials
.
stride
(
1
),
out_partials
.
shape
[
1
],
# output [partial reductions]
BLOCK_M
=
PARTIALS_BLOCK_M
,
BLOCK_MM
=
BLOCK_MM
,
# constants
num_warps
=
8
,
)
out_partials
=
out_partials
[:
cdiv
(
n_rows_max
,
PARTIALS_BLOCK_M
),
:]
return
out_ret
,
out_partials
vllm/kvprune_legacy_save/triton_kernels/routing.py
0 → 100644
View file @
d29c39ca
import
torch
import
triton
from
dataclasses
import
dataclass
,
field
from
.routing_details._routing_compute
import
_combined_routing_compute
from
.routing_details._routing_compute
import
_combined_routing_memset
from
.routing_details._routing_compute
import
_routing_clear_bitmatrix
from
.routing_details._expt_data
import
_expt_data_memset
from
.routing_details._expt_data
import
_expt_data_compute
from
.target_info
import
is_hip
@
dataclass
class
GatherIndx
:
"""
Indices for an operation that performs:
Y = X[src_idx, :]
"""
# array such that `dst_idx[src_idx] = arange(0, N)`
src_indx
:
torch
.
Tensor
dst_indx
:
torch
.
Tensor
@
dataclass
class
ScatterIndx
:
"""
Indices for an operation that performs:
Y[dst_idx, :] = X
"""
# array such that `dst_idx[src_idx] = arange(0, N)`
src_indx
:
torch
.
Tensor
dst_indx
:
torch
.
Tensor
@
dataclass
class
ExptData
:
# hist[i] is the number of tokens routed to expert i
hist
:
torch
.
Tensor
# token_offs_raw[i] is the offset of the first token routed
# to expert i in an expert-sorted array
token_offs_raw
:
torch
.
Tensor
# token_offs_pad[block][i] is the offset of the first token routed
# to expert i in an expert-sorted array, assuming histogram
# rounded to the next multiple of `block`
token_offs_pad
:
dict
[
int
,
torch
.
Tensor
]
# block_id_map[block] contain one value for each `pid`` launched by
# the matrix multiplication kernel launched with BLOCK_M=block:
# - the value is -1 if the `pid` has no work to do
# - otherwise, the value is two int16 (packed as an int32) that
# correspond respectively to (1) the expert assigned to
# the tokens processed by this pid; (2) the block assigned to the
# tokens processed by this pid (think `pid_m` in a regular matmul)
# see `test_routing.py` for a reference implementation and more details
block_pid_map
:
dict
[
int
,
torch
.
Tensor
]
def
__post_init__
(
self
):
if
self
.
hist
is
not
None
:
assert
self
.
hist
.
dtype
==
torch
.
int32
if
self
.
token_offs_raw
is
not
None
:
assert
self
.
token_offs_raw
.
dtype
==
torch
.
int32
if
self
.
token_offs_pad
is
not
None
:
for
v
in
self
.
token_offs_pad
.
values
():
assert
v
.
dtype
==
torch
.
int32
if
self
.
block_pid_map
is
not
None
:
for
v
in
self
.
block_pid_map
.
values
():
assert
v
.
dtype
==
torch
.
int32
@
dataclass
class
RoutingData
:
gate_scal
:
torch
.
Tensor
=
field
()
expt_hist
:
torch
.
Tensor
=
field
()
n_expts_tot
:
int
=
field
()
n_expts_act
:
int
=
field
()
expt_data
:
ExptData
=
None
# Used to make perf annotation cleaner: when we use expert sharding, we can
# use this to tell the "expected" number of local tokens per expert, because
# the actual number can vary per each input.
expected_tokens_per_expt
:
int
=
field
(
default
=
None
)
def
n_blocks
(
self
,
n_rows
,
block_m
):
if
n_rows
<=
self
.
n_expts_tot
:
return
n_rows
else
:
return
(
triton
.
cdiv
(
max
(
n_rows
-
self
.
n_expts_tot
+
1
,
0
),
block_m
)
+
self
.
n_expts_tot
-
1
)
# --------------------------
# sort tokens by expert
# --------------------------
class
SortTokens
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
expt_scal
,
expt_indx
,
n_expts_tot
,
bitmatrix
):
HIST_BLOCK_M
=
32
INDX_OFFS_BLOCK_M
=
512
MEMSET_BLOCK
=
1024
cdiv
=
triton
.
cdiv
device
=
expt_scal
.
device
dtype
=
expt_scal
.
dtype
n_tokens_raw
,
_
=
bitmatrix
.
shape
n_tokens_pad
,
n_expts_act
=
expt_scal
.
shape
n_gates_pad
=
n_tokens_pad
*
n_expts_act
hist
,
partial_hist
=
bitmatrix
.
sum
(
partials_block_size
=
HIST_BLOCK_M
)
hist
=
hist
[:
n_expts_tot
]
assert
hist
.
dtype
==
torch
.
int32
# scratchpad
expt_offs
=
torch
.
empty
(
n_expts_tot
,
dtype
=
torch
.
int32
,
device
=
device
)
combined_indx
=
torch
.
empty
(
n_gates_pad
*
2
,
dtype
=
torch
.
int32
,
device
=
device
)
# output
topk_indx
=
combined_indx
[:
n_gates_pad
]
gate_indx
=
combined_indx
[
n_gates_pad
:]
gate_scal
=
torch
.
empty
(
n_gates_pad
,
dtype
=
dtype
,
device
=
device
)
(
token_offs_combined
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
,
blocks1a
,
blocks2a
,
MEMSET_BLOCK_A
,
HIST2_BLOCK_M
,
block_m_log2_start
,
block_m_num
,
)
=
_compute_expt_data_internal
(
hist
,
n_expts_tot
,
n_gates_pad
)
blocks1b
=
cdiv
(
n_gates_pad
*
2
,
MEMSET_BLOCK
)
+
n_expts_tot
+
1
blocks2b
=
cdiv
(
n_tokens_pad
,
HIST_BLOCK_M
)
_combined_routing_memset
[(
blocks1a
+
blocks1b
,)](
combined_indx
,
n_gates_pad
*
2
,
-
1
,
MEMSET_BLOCK
,
hist
,
#
expt_offs
,
hist
.
shape
[
0
],
n_expts_tot
,
partial_hist
,
# inputs
partial_hist
.
shape
[
0
],
partial_hist
.
stride
(
0
),
partial_hist
.
stride
(
1
),
# outputs
token_offs_combined
,
token_offs_combined
.
stride
(
0
),
#
blocks1a
,
block_pid_map
,
#
block_m_log2_start
,
SIZES
=
block_m_num
,
BLOCK_A
=
MEMSET_BLOCK_A
,
# optimization parameters
BLOCK_N
=
512
,
BLOCK_M
=
INDX_OFFS_BLOCK_M
,
# tunable parameters
)
indx_offs
=
partial_hist
_combined_routing_compute
[(
blocks2a
+
blocks2b
,)](
topk_indx
,
gate_indx
,
gate_scal
,
# outputs
expt_scal
,
expt_indx
,
indx_offs
,
indx_offs
.
stride
(
0
),
indx_offs
.
stride
(
1
),
# inputs
expt_offs
,
n_tokens_raw
,
# input shape
HIST_BLOCK_M
,
n_expts_act
,
# constants
hist
,
token_offs_pad
,
token_offs_pad
.
stride
(
0
),
block_pid_map
,
block_pid_map
.
stride
(
0
),
# outputs
block_m_log2_start
,
block_m_num
,
HIST2_BLOCK_M
,
blocks2a
,
# etc.
)
ctx
.
n_tokens_raw
=
n_tokens_raw
ctx
.
n_tokens_pad
=
n_tokens_pad
ctx
.
n_expts_act
=
n_expts_act
ctx
.
save_for_backward
(
gate_indx
)
return
(
hist
,
topk_indx
,
gate_indx
,
gate_scal
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
,
)
@
staticmethod
def
backward
(
ctx
,
_0
,
_1
,
_2
,
dgate_scal
,
_3
,
_4
,
_5
):
(
gate_indx
,)
=
ctx
.
saved_tensors
dgate_scal
=
dgate_scal
[
gate_indx
]
dgate_scal
=
dgate_scal
.
reshape
(
ctx
.
n_tokens_pad
,
ctx
.
n_expts_act
)
return
dgate_scal
,
None
,
None
,
None
def
sort_tokens
(
expt_scal
,
expt_indx
,
n_expts_tot
,
bitmatrix
):
return
SortTokens
.
apply
(
expt_scal
,
expt_indx
,
n_expts_tot
,
bitmatrix
)
# --------------------------
# prune routing
# --------------------------
class
PruneRouting
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
expt_scal
,
expt_indx
,
bitmatrix
,
n_expts_tot
,
simulated_ep
):
from
.compaction
import
compaction
n_tokens_pad
=
expt_scal
.
shape
[
0
]
assert
n_expts_tot
%
simulated_ep
==
0
_routing_clear_bitmatrix
[(
n_tokens_pad
,)](
bitmatrix
.
storage
.
data
,
bitmatrix
.
storage
.
data
.
stride
(
0
),
bitmatrix
.
storage
.
data
.
stride
(
1
),
bitmatrix
.
storage
.
data
.
shape
[
1
],
n_expts_tot
//
simulated_ep
,
BLOCK_N
=
512
,
)
# perform compaction to update expt_scal / expt_indx
expt_scal
,
expt_indx
=
compaction
(
expt_scal
,
expt_indx
,
bitmatrix
)
n_expts_tot
=
n_expts_tot
//
simulated_ep
bitmatrix
.
shape
[
-
1
]
=
n_expts_tot
return
expt_scal
,
expt_indx
,
bitmatrix
def
prune_routing
(
expt_scal
,
expt_indx
,
bitmatrix
,
n_expts_tot
,
simulated_ep
):
return
PruneRouting
.
apply
(
expt_scal
,
expt_indx
,
bitmatrix
,
n_expts_tot
,
simulated_ep
)
# --------------------------
# expt_data
# --------------------------
def
log2_power_of_two
(
x
):
assert
x
>
0
and
(
x
&
(
x
-
1
))
==
0
,
"x must be a power of two"
return
x
.
bit_length
()
-
1
block_m_log2_start
=
4
def
_compute_expt_data_internal
(
expt_hist
,
n_expts_tot
,
n_gates
):
MEMSET_BLOCK
=
512
HIST2_BLOCK_M
=
512
device
=
expt_hist
.
device
n_expts_tot
=
n_expts_tot
cdiv
=
triton
.
cdiv
# block_ms are all powers-of-two between 16 and 128 (inclusive)
block_m_log2_end
=
9
if
is_hip
()
else
8
block_m_num
=
block_m_log2_end
-
block_m_log2_start
if
n_gates
<=
n_expts_tot
:
max_n_tiles
=
n_gates
else
:
max_n_tiles
=
(
n_expts_tot
-
1
-
((
n_expts_tot
-
n_gates
-
1
)
//
2
**
block_m_log2_start
)
)
# allocate memory
pad
=
lambda
x
:
cdiv
(
x
,
MEMSET_BLOCK
)
*
MEMSET_BLOCK
dtype
=
torch
.
int32
token_offs_combined
=
torch
.
empty
(
(
block_m_num
+
1
,
pad
(
n_expts_tot
+
1
)),
dtype
=
dtype
,
device
=
device
)
token_offs_raw
=
token_offs_combined
[
0
][:
n_expts_tot
+
1
]
token_offs_pad
=
token_offs_combined
[
1
:]
block_pid_map
=
torch
.
empty
(
(
block_m_num
,
pad
(
max_n_tiles
)),
dtype
=
dtype
,
device
=
device
)
memset_grid
=
torch
.
numel
(
block_pid_map
)
//
MEMSET_BLOCK
# exact division
# compute outputs
token_offs_pad
=
token_offs_pad
[:,
:
n_expts_tot
+
1
]
block_pid_map
=
block_pid_map
[:,
:
max_n_tiles
]
blocks1
=
memset_grid
+
block_m_num
+
1
blocks2
=
n_expts_tot
*
block_m_num
return
(
token_offs_combined
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
,
blocks1
,
blocks2
,
MEMSET_BLOCK
,
HIST2_BLOCK_M
,
block_m_log2_start
,
block_m_num
,
)
def
_unpack_into_dict
(
x
):
block_m_log2_end
=
block_m_log2_start
+
x
.
shape
[
0
]
x
=
{
2
**
j
:
x
[
i
,
:]
for
i
,
j
in
enumerate
(
range
(
block_m_log2_start
,
block_m_log2_end
))
}
return
x
def
compute_expt_data
(
expt_hist
,
n_expts_tot
,
n_gates
):
if
expt_hist
is
None
:
return
ExptData
(
None
,
None
,
None
,
None
)
# this just computes the kernel arguments:
(
token_offs_combined
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
,
blocks1
,
blocks2
,
MEMSET_BLOCK
,
HIST2_BLOCK_M
,
block_m_log2_start
,
block_m_num
,
)
=
_compute_expt_data_internal
(
expt_hist
,
n_expts_tot
,
n_gates
)
_expt_data_memset
[(
blocks1
,)](
expt_hist
,
n_expts_tot
,
#
token_offs_combined
,
token_offs_combined
.
stride
(
0
),
#
block_pid_map
,
#
block_m_log2_start
,
SIZES
=
block_m_num
,
BLOCK
=
MEMSET_BLOCK
,
# optimization parameters
num_warps
=
4
,
)
_expt_data_compute
[(
blocks2
,)](
expt_hist
,
token_offs_pad
,
token_offs_pad
.
stride
(
0
),
block_pid_map
,
block_pid_map
.
stride
(
0
),
# outputs
block_m_log2_start
,
SIZES
=
block_m_num
,
BLOCK
=
HIST2_BLOCK_M
,
# optimization parameters
num_warps
=
4
,
)
token_offs_pad
=
_unpack_into_dict
(
token_offs_pad
)
block_pid_map
=
_unpack_into_dict
(
block_pid_map
)
return
ExptData
(
expt_hist
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
)
# --------------------------
# routing
# --------------------------
def
routing_from_bitmatrix
(
bitmatrix
,
expt_scal
,
expt_indx
,
n_expts_tot
,
n_expts_act
):
(
hist
,
topk_indx
,
gate_indx
,
gate_scal
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
,
)
=
sort_tokens
(
expt_scal
,
expt_indx
,
n_expts_tot
,
bitmatrix
)
token_offs_pad
=
_unpack_into_dict
(
token_offs_pad
)
block_pid_map
=
_unpack_into_dict
(
block_pid_map
)
expt_data
=
ExptData
(
hist
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
)
# pack the matmul data structure
gather_indx
=
GatherIndx
(
src_indx
=
topk_indx
,
dst_indx
=
gate_indx
)
scatter_indx
=
ScatterIndx
(
src_indx
=
gate_indx
,
dst_indx
=
topk_indx
)
return
(
RoutingData
(
gate_scal
,
hist
,
n_expts_tot
,
n_expts_act
,
expt_data
),
gather_indx
,
scatter_indx
,
)
def
routing
(
logits
,
n_expts_act
,
sm_first
=
False
,
expt_indx
=
None
,
simulated_ep
=
1
,
n_rows
=
None
):
from
.topk
import
topk
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
expt_scal
,
expt_indx
,
bitmatrix
=
topk
(
logits
,
n_expts_act
,
#
apply_softmax
=
not
sm_first
,
y_indx
=
expt_indx
,
n_rows
=
n_rows
,
)
n_expts_tot
=
logits
.
shape
[
-
1
]
//
simulated_ep
# mutate bitmatrix
if
simulated_ep
>
1
:
expt_scal
,
expt_indx
,
bitmatrix
=
prune_routing
(
expt_scal
,
expt_indx
,
bitmatrix
,
logits
.
shape
[
-
1
],
simulated_ep
)
return
routing_from_bitmatrix
(
bitmatrix
,
expt_scal
,
expt_indx
,
n_expts_tot
,
n_expts_act
)
# --------------------------
# torch reference
# --------------------------
def
compute_expt_data_torch
(
hist
,
n_expts_tot
,
n_gates
):
# offset for each experts
device
=
hist
.
device
token_offs_raw
=
torch
.
cumsum
(
hist
,
dim
=
0
)
token_offs_raw
=
torch
.
cat
((
torch
.
zeros
(
1
,
device
=
device
),
token_offs_raw
))
token_offs_raw
=
token_offs_raw
.
int
()
# maximum number of tiles for all values of `block_m` considered
block_ms
=
[
16
,
32
,
64
,
128
]
if
is_hip
():
block_ms
.
append
(
256
)
if
n_gates
<=
n_expts_tot
:
max_n_tiles
=
n_gates
else
:
# ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
# ceil_div(x, y): -(-x // y)
max_n_tiles
=
n_expts_tot
-
1
-
((
n_expts_tot
-
n_gates
-
1
)
//
min
(
block_ms
))
# fill up tile offset/infos for each block
token_offs_pad
=
dict
()
block_pid_map
=
dict
()
for
block_m
in
block_ms
:
n_tiles
=
(
hist
+
block_m
-
1
)
//
block_m
# matmul blocks needed
token_offs_pad
[
block_m
]
=
torch
.
cumsum
(
n_tiles
,
dim
=
0
)
token_offs_pad
[
block_m
]
=
torch
.
cat
(
(
torch
.
zeros
(
1
,
device
=
device
),
token_offs_pad
[
block_m
])
)
token_offs_pad
[
block_m
]
=
token_offs_pad
[
block_m
].
int
()
# compute data required to drive ragged batch matmul
block_pid_map
[
block_m
]
=
-
torch
.
ones
(
max_n_tiles
,
dtype
=
torch
.
int32
,
device
=
device
)
# for e in range(n_expts_tot):
# offset = token_offs_pad[block_m][e]
# for b in range(n_tiles[e]):
# block_pid_map[block_m][offset + b] = (b << 16) + e
col
=
torch
.
arange
(
max_n_tiles
,
device
=
device
)
map_vals
=
(
torch
.
arange
(
n_expts_tot
,
device
=
device
)[:,
None
]
+
(
col
<<
16
)[
None
,
:]
)
map_idxs
=
token_offs_pad
[
block_m
][:
-
1
,
None
]
+
col
[
None
,
:]
mask
=
col
[
None
,
:]
<
n_tiles
[:,
None
]
block_pid_map
[
block_m
].
index_put_
((
map_idxs
[
mask
],),
map_vals
.
int
()[
mask
])
return
ExptData
(
hist
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
)
def
topk_torch
(
vals
,
k
,
expt_indx
,
has_user_provided_indx
=
False
):
# topk of experts
if
has_user_provided_indx
:
tk_indx
=
expt_indx
else
:
tk_indx
=
torch
.
argsort
(
-
vals
,
dim
=
1
,
stable
=
True
)[:,
:
k
]
tk_indx
=
tk_indx
.
long
()
tk_val
=
torch
.
take_along_dim
(
vals
,
tk_indx
,
dim
=
1
)
tk_indx
=
tk_indx
.
int
()
return
tk_val
,
tk_indx
def
routing_torch
(
logits
,
n_expts_act
,
sm_first
=
False
,
expt_indx
=
None
,
n_rows
=
None
):
has_user_provided_indx
=
expt_indx
is
not
None
n_gates_pad
=
logits
.
shape
[
0
]
*
n_expts_act
if
n_rows
is
not
None
:
logits
=
logits
[:
n_rows
,
:]
_
,
n_expts_tot
=
logits
.
shape
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
expt_scal
,
expt_indx
=
topk_torch
(
logits
,
n_expts_act
,
expt_indx
,
has_user_provided_indx
=
has_user_provided_indx
)
if
not
sm_first
:
expt_scal
=
torch
.
softmax
(
expt_scal
,
dim
=-
1
)
# sort each token's selections by expert
if
not
has_user_provided_indx
:
expt_indx
,
sort_indices
=
torch
.
sort
(
expt_indx
,
dim
=
1
)
expt_scal
=
torch
.
gather
(
expt_scal
,
1
,
sort_indices
)
# flatten topk data
expt_scal
=
expt_scal
.
reshape
(
-
1
)
expt_indx
=
expt_indx
.
reshape
(
-
1
).
to
(
torch
.
int32
)
# sort by expert_id so experts are contiguous for the matmul
topk_indx
=
torch
.
argsort
(
expt_indx
,
stable
=
True
)
gate_indx
=
torch
.
argsort
(
topk_indx
,
stable
=
True
)
gate_scal
=
expt_scal
[
topk_indx
]
hist
=
torch
.
histc
(
expt_indx
,
bins
=
n_expts_tot
,
max
=
n_expts_tot
-
1
).
int
()
# histogram of tokens over experts
# pack the matmul data structure
gather_indx
=
GatherIndx
(
src_indx
=
topk_indx
.
int
(),
dst_indx
=
gate_indx
.
int
())
scatter_indx
=
ScatterIndx
(
src_indx
=
gate_indx
.
int
(),
dst_indx
=
topk_indx
.
int
())
# compute expt_data
expt_data
=
compute_expt_data_torch
(
hist
,
n_expts_tot
,
n_gates_pad
)
return
(
RoutingData
(
gate_scal
,
hist
,
n_expts_tot
,
n_expts_act
,
expt_data
),
gather_indx
,
scatter_indx
,
)
vllm/kvprune_legacy_save/triton_kernels/routing_details/__init__.py
0 → 100644
View file @
d29c39ca
vllm/kvprune_legacy_save/triton_kernels/routing_details/_expt_data.py
0 → 100644
View file @
d29c39ca
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_cdiv_pow2
(
n
,
log2_k
):
return
(
n
+
((
1
<<
log2_k
)
-
1
))
>>
log2_k
@
triton
.
jit
def
_expt_data_memset
(
Hist
,
n_expts_tot
,
MDStarts
,
tile_starts_stridem
,
MDTileInfo
,
first_tile_dim_log2
,
SIZES
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
if
pid
<=
SIZES
:
MDStarts
+=
pid
*
tile_starts_stridem
x_tile
=
tl
.
zeros
([
BLOCK
],
dtype
=
MDStarts
.
dtype
.
element_ty
)
Tile_ptrs
=
MDStarts
+
tl
.
arange
(
0
,
BLOCK
)
tile_dim_log2
=
tl
.
where
(
pid
==
0
,
0
,
pid
+
first_tile_dim_log2
-
1
)
for
i
in
range
(
0
,
n_expts_tot
+
1
,
BLOCK
):
offs_n
=
tl
.
arange
(
0
,
BLOCK
)
+
i
mask_n0
=
offs_n
<
n_expts_tot
hist_tok
=
tl
.
load
(
Hist
+
offs_n
,
mask
=
mask_n0
,
other
=
0
)
hist_tile
=
_cdiv_pow2
(
hist_tok
,
tile_dim_log2
)
tile_starts
=
tl
.
cumsum
(
hist_tile
,
0
)
+
x_tile
x_tile
+=
tl
.
sum
(
hist_tile
,
0
).
to
(
MDStarts
.
dtype
.
element_ty
)
tl
.
store
(
Tile_ptrs
,
tile_starts
-
hist_tile
)
Tile_ptrs
+=
BLOCK
else
:
pid
-=
SIZES
+
1
TileInfoOut
=
MDTileInfo
+
pid
*
BLOCK
+
tl
.
arange
(
0
,
BLOCK
)
tl
.
store
(
TileInfoOut
,
0xFFFFFFFF
)
@
triton
.
jit
def
_expt_data_compute
(
Hist
,
MDTileStarts
,
tile_starts_stridem
,
MDTileInfo
,
tile_info_stridem
,
first_tile_dim_log2
,
SIZES
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
expt_id
=
pid
//
SIZES
buff_id
=
pid
%
SIZES
MDTileStarts
+=
buff_id
*
tile_starts_stridem
MDTileInfo
+=
buff_id
*
tile_info_stridem
n_tokens
=
tl
.
load
(
Hist
+
expt_id
)
tile_dim_log2
=
first_tile_dim_log2
+
buff_id
n_blocks
=
_cdiv_pow2
(
n_tokens
,
tile_dim_log2
)
tile_off
=
tl
.
load
(
MDTileStarts
+
expt_id
)
MDTileInfo
+=
tile_off
for
block_off
in
range
(
0
,
n_blocks
,
BLOCK
):
block_offs
=
block_off
+
tl
.
arange
(
0
,
BLOCK
)
data
=
(
block_offs
<<
16
)
+
expt_id
tl
.
store
(
MDTileInfo
+
block_offs
,
data
,
mask
=
block_offs
<
n_blocks
)
vllm/kvprune_legacy_save/triton_kernels/routing_details/_routing_compute.py
0 → 100644
View file @
d29c39ca
import
triton
import
triton.language
as
tl
from
._expt_data
import
_expt_data_compute
,
_expt_data_memset
@
triton
.
jit
def
_routing_compute_expt_offs
(
ExpertHist
,
FinalExpertOffs
,
hist_size
,
# histogram
BLOCK_N
:
tl
.
constexpr
,
):
loop_iterations
=
(
hist_size
+
BLOCK_N
-
1
)
//
BLOCK_N
x
=
tl
.
zeros
([
BLOCK_N
],
ExpertHist
.
dtype
.
element_ty
)
for
i
in
range
(
loop_iterations
):
offs_n
=
i
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
mask_n
=
offs_n
<
hist_size
hist2
=
tl
.
load
(
ExpertHist
+
offs_n
,
mask
=
mask_n
)
tok_starts
=
tl
.
cumsum
(
hist2
,
0
)
-
hist2
+
x
x
+=
tl
.
sum
(
hist2
,
0
)
tl
.
store
(
FinalExpertOffs
+
offs_n
,
tok_starts
,
mask
=
mask_n
)
offs_n
+=
BLOCK_N
@
triton
.
jit
def
_routing_compute_indx_offs
(
PartialHist
,
shape_pm
,
stride_pm
,
stride_pn
,
BLOCK_M
:
tl
.
constexpr
,
expt_id
):
offs_m
=
tl
.
arange
(
0
,
BLOCK_M
)
# iterate over input data
curr_sum
=
0
for
_
in
range
(
0
,
shape_pm
,
BLOCK_M
):
offs
=
offs_m
*
stride_pm
+
expt_id
*
stride_pn
curr
=
tl
.
load
(
PartialHist
+
offs
,
mask
=
offs_m
<
shape_pm
)
out
=
tl
.
cumsum
(
curr
,
0
)
+
curr_sum
curr_sum
+=
tl
.
sum
(
curr
,
0
)
tl
.
store
(
PartialHist
+
offs
,
out
-
curr
,
mask
=
offs_m
<
shape_pm
)
offs_m
+=
BLOCK_M
@
triton
.
jit
def
_keyed_add
(
x
,
y
):
# we keep the key in the upper 16 bits of a uint32:
key_mask
:
tl
.
constexpr
=
0xFFFF0000
kx
=
x
&
key_mask
ky
=
y
&
key_mask
z
=
tl
.
where
(
kx
==
ky
,
x
+
y
-
kx
,
y
)
return
z
@
triton
.
jit
def
_routing_compute_indx
(
pid_m
,
GatherIndx
,
ScatterIndx
,
GateScal
,
ExptScal
,
ExptIndx
,
PartialOffs
,
stride_pm
,
stride_pn
,
TokensStart
,
n_tokens
,
BLOCK_M
:
tl
.
constexpr
,
N_EXPTS_ACT
:
tl
.
constexpr
,
):
if
isinstance
(
n_tokens
,
tl
.
tensor
)
and
n_tokens
.
dtype
.
is_ptr
():
n_tokens
=
tl
.
load
(
n_tokens
)
n_gates
=
n_tokens
*
N_EXPTS_ACT
tl
.
static_assert
(
N_EXPTS_ACT
*
BLOCK_M
<=
32768
)
local_offs
=
tl
.
arange
(
0
,
N_EXPTS_ACT
*
BLOCK_M
)
offs
=
pid_m
*
BLOCK_M
*
N_EXPTS_ACT
+
local_offs
expert
=
tl
.
load
(
ExptIndx
+
offs
,
mask
=
(
offs
<
n_gates
),
other
=-
1
).
to
(
tl
.
uint32
)
# stable-sort by expert ID:
kv_pairs
=
((
expert
<<
16
)
|
local_offs
).
to
(
tl
.
uint32
)
kv_pairs
=
tl
.
sort
(
kv_pairs
,
0
)
expert
=
kv_pairs
>>
16
offs
=
pid_m
*
BLOCK_M
*
N_EXPTS_ACT
+
(
kv_pairs
&
0xFFFF
)
mask
=
expert
!=
0xFFFF
gate_scal
=
tl
.
load
(
ExptScal
+
offs
,
mask
=
mask
)
# compute run lengths in expert-sorted order:
x
=
kv_pairs
&
0xFFFF0000
|
0x00000001
expts_and_inclusive_run_lengths
=
tl
.
associative_scan
(
x
,
0
,
_keyed_add
)
exclusive_run_lengths
=
(
expts_and_inclusive_run_lengths
-
1
)
&
0xFFFF
gates
=
tl
.
load
(
PartialOffs
+
pid_m
*
stride_pm
+
expert
*
stride_pn
,
mask
=
mask
)
gates
+=
tl
.
load
(
TokensStart
+
expert
,
mask
=
mask
)
gates
+=
exclusive_run_lengths
tl
.
store
(
ScatterIndx
+
offs
,
gates
,
mask
=
mask
)
tl
.
store
(
GatherIndx
+
gates
,
offs
,
mask
=
mask
)
tl
.
store
(
GateScal
+
gates
,
gate_scal
,
mask
=
mask
)
@
triton
.
jit
def
_combined_routing_compute
(
GatherIndx
,
ScatterIndx
,
GateScal
,
ExptScal
,
ExptIndx
,
PartialOffs
,
stride_pm
,
stride_pn
,
TokensStart
,
n_tokens
,
BLOCK_M
:
tl
.
constexpr
,
N_EXPTS_ACT
:
tl
.
constexpr
,
Hist
,
MDTileStarts
,
tile_starts_stridem
,
MDTileInfo
,
tile_info_stridem
,
first_tile_dim_log2
,
SIZES
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
blocks2a
,
):
pid
=
tl
.
program_id
(
0
)
if
pid
<
blocks2a
:
_expt_data_compute
(
Hist
,
MDTileStarts
,
tile_starts_stridem
,
MDTileInfo
,
tile_info_stridem
,
first_tile_dim_log2
,
SIZES
,
BLOCK
,
)
else
:
pid
-=
blocks2a
_routing_compute_indx
(
pid
,
GatherIndx
,
ScatterIndx
,
GateScal
,
ExptScal
,
ExptIndx
,
PartialOffs
,
stride_pm
,
stride_pn
,
TokensStart
,
n_tokens
,
BLOCK_M
,
N_EXPTS_ACT
,
)
@
triton
.
jit
def
_routing_clear_bitmatrix
(
Bitmatrix
,
stride_bm
,
stride_bn
,
shape_bn
,
cutoff
,
BLOCK_N
:
tl
.
constexpr
):
pid_m
=
tl
.
program_id
(
0
)
cutoff_word
=
cutoff
//
32
cutoff_bit
=
cutoff
%
32
cutoff_mask
=
(
1
<<
(
cutoff_bit
))
-
1
for
start_n
in
range
(
0
,
shape_bn
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
values
=
tl
.
load
(
Bitmatrix
+
pid_m
*
stride_bm
+
offs_n
*
stride_bn
,
mask
=
offs_n
<
shape_bn
)
values
=
tl
.
where
(
offs_n
==
cutoff_word
,
values
&
cutoff_mask
,
values
)
values
=
tl
.
where
(
offs_n
>
cutoff_word
,
0
,
values
)
tl
.
store
(
Bitmatrix
+
pid_m
*
stride_bm
+
offs_n
*
stride_bn
,
values
,
mask
=
offs_n
<
shape_bn
,
)
@
triton
.
jit
def
_combined_routing_memset
(
Indx
,
size
,
sentinel
,
BLOCK
:
tl
.
constexpr
,
ExpertHist
,
FinalExpertOffs
,
hist_size
,
n_expts_tot
,
PartialHist
,
shape_pm
,
stride_pm
,
stride_pn
,
MDStarts
,
tile_starts_stridem
,
blocks1a
,
MDTileInfo
,
first_tile_dim_log2
,
SIZES
:
tl
.
constexpr
,
BLOCK_A
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
):
"""
This kernel essentially combines 6 different pieces of functionality,
statically branching on the value of tl.program_id(0) to decide which
codepath to take.
pid == 0: create the token cumsum
1 <= pid <= SIZES: create a tile cumsum
SIZES < pid < blocks1a: initialise MDTileInfo to 0xffffffff
blocks1a <= pid < blocks1a + n_expts_tot: compute_indx_offs
pid == blocks1a + n_expts_tot: compute_expt_offs
pid > blocks1a + n_expts_tot: initialise Indx to sentinel
As each of these is a relatively trivial workload, launching them from
this single trampoline is beneficial as they can execute on different
streaming multiprocesses in parallel.
"""
pid
=
tl
.
program_id
(
0
)
if
pid
<
blocks1a
:
_expt_data_memset
(
ExpertHist
,
n_expts_tot
,
MDStarts
,
tile_starts_stridem
,
MDTileInfo
,
first_tile_dim_log2
,
SIZES
,
BLOCK_A
,
)
elif
pid
==
n_expts_tot
+
blocks1a
:
_routing_compute_expt_offs
(
ExpertHist
,
FinalExpertOffs
,
hist_size
,
BLOCK_N
)
elif
pid
<
n_expts_tot
+
blocks1a
:
_routing_compute_indx_offs
(
PartialHist
,
shape_pm
,
stride_pm
,
stride_pn
,
BLOCK_M
,
pid
-
blocks1a
)
else
:
offs
=
(
pid
-
n_expts_tot
-
blocks1a
-
1
)
*
BLOCK
+
tl
.
arange
(
0
,
BLOCK
)
mask
=
offs
<
size
tl
.
store
(
Indx
+
offs
,
sentinel
,
mask
=
mask
)
vllm/kvprune_legacy_save/triton_kernels/specialize.py
0 → 100644
View file @
d29c39ca
import
inspect
import
re
import
textwrap
import
types
import
triton
def
cacheable
(
f
):
"""
A decorator that allow you to write something of the form:
@cacheable
def my_kernel(): return (expression dynamically defining a kernel)
such that it interacts gracefully with triton cache and preload.
"""
g
=
f
()
g
.
fn
.
__name__
=
f
.
__name__
g
.
fn
.
__module__
=
f
.
__module__
g
.
fn
.
__qualname__
=
f
.
__qualname__
g
.
__name__
=
f
.
__name__
g
.
__module__
=
f
.
__module__
g
.
__qualname__
=
f
.
__qualname__
g
.
_fn_name
=
f
"
{
f
.
__module__
}
.
{
f
.
__qualname__
}
"
return
g
def
define_kernel
(
src
,
module
,
attrs
=
None
,
**
extra_globals
):
"""
Dynamically create a Triton function or kernel from a src string,
linking any symbols in the kernel to objects specified by extra_globals.
"""
# create templace function
def
_empty_fn
():
pass
gdict
=
dict
(
**
(
_empty_fn
.
__globals__
))
gdict
.
update
(
extra_globals
)
f
=
types
.
FunctionType
(
_empty_fn
.
__code__
,
gdict
)
f
.
__module__
=
module
.
__name__
src
=
textwrap
.
dedent
(
src
)
src
=
src
[
src
.
find
(
"def "
)
:]
stored_functions
=
[]
function_name
=
src
[
4
:].
split
(
"("
)[
0
].
strip
()
exec_globals
=
gdict
exec_globals
.
update
({
"stored_functions"
:
stored_functions
})
exec
(
src
+
"
\n\n
stored_functions.append("
+
function_name
+
")
\n
"
,
exec_globals
)
f
.
__signature__
=
inspect
.
signature
(
stored_functions
[
0
])
f
.
__name__
=
function_name
f
.
__doc__
=
stored_functions
[
0
].
__doc__
if
attrs
is
None
:
attrs
=
dict
()
f
=
triton
.
JITFunction
(
f
,
**
attrs
)
f
.
_unsafe_update_src
(
src
)
return
f
def
specialize
(
fn
,
module
,
constants
,
tuples
,
name
=
None
,
do_not_specialize
=
tuple
()):
assert
isinstance
(
fn
,
triton
.
runtime
.
jit
.
JITFunction
)
if
name
is
None
:
name
=
f
"
{
fn
.
__name__
}
"
# Get original source code
src
=
inspect
.
getsource
(
fn
.
fn
)
src
=
textwrap
.
dedent
(
src
)
lines
=
src
.
split
(
"
\n
"
)
# Skip decorator and def line
def_idx
=
next
(
i
for
i
,
line
in
enumerate
(
lines
)
if
line
.
strip
().
startswith
(
"def"
))
# separate header vs body LOC
header_end
=
def_idx
while
not
lines
[
header_end
].
rstrip
().
endswith
(
":"
):
header_end
+=
1
body_lines
=
lines
[
header_end
+
1
:]
header_lines
=
lines
[
def_idx
:
header_end
+
1
]
# clean-up header
header_clean
=
[
l
.
split
(
"#"
,
1
)[
0
].
strip
()
# keep code, discard comment
for
l
in
header_lines
if
l
.
split
(
"#"
,
1
)[
0
].
strip
()
# skip blank‑after‑comment lines
]
# decompose arguments
header_src
=
" "
.
join
(
header_clean
)
# turn it into a single line
m
=
re
.
search
(
r
"\((.*)\)\s*:"
,
header_src
)
if
not
m
:
raise
ValueError
(
"Could not parse function header"
)
args_str
=
m
.
group
(
1
)
args
=
[
arg
.
strip
()
for
arg
in
args_str
.
split
(
","
)
if
arg
.
strip
()]
non_specialized_args
=
[]
for
arg
in
args
:
arg_key
=
arg
.
split
(
":"
)[
0
].
split
(
"="
)[
0
].
strip
()
new_args
=
tuples
.
get
(
arg_key
,
[
arg
])
if
arg_key
not
in
constants
:
non_specialized_args
+=
new_args
# add global symbols
spec_fns
=
{
v
.
__name__
:
v
for
k
,
v
in
constants
.
items
()
if
isinstance
(
v
,
triton
.
runtime
.
jit
.
JITFunction
)
}
globals
=
spec_fns
|
fn
.
get_capture_scope
()
# build new source code and define kernel dynamically
new_signature
=
f
"def
{
name
}
(
{
', '
.
join
(
non_specialized_args
)
}
):"
constexpr_lines
=
[
f
"
{
key
}
: tl.constexpr =
{
value
.
__name__
if
callable
(
value
)
else
value
}
"
for
key
,
value
in
constants
.
items
()
]
tuple_lines
=
[
f
"
{
key
}
=
{
'('
+
','
.
join
(
value
)
+
(
','
if
len
(
value
)
>=
1
else
''
)
+
')'
}
"
for
key
,
value
in
tuples
.
items
()
]
new_src
=
"
\n
"
.
join
(
[
"@triton.jit"
,
new_signature
]
+
constexpr_lines
+
tuple_lines
+
body_lines
)
# find function parameters
sig
=
inspect
.
signature
(
triton
.
runtime
.
jit
.
JITFunction
.
__init__
)
params
=
list
(
sig
.
parameters
.
values
())[
2
:]
attrs
=
{
param
.
name
:
getattr
(
fn
,
param
.
name
,
param
.
default
)
for
param
in
params
}
# make a new repr which appends the repr of the specialized functions.
base_repr
=
attrs
[
"repr"
]
def
new_repr
(
specialization
):
ret
=
base_repr
(
specialization
)
for
spec_fn
in
spec_fns
.
values
():
spec_repr
=
spec_fn
.
repr
(
None
)
if
spec_repr
:
spec_repr
=
spec_repr
.
strip
(
"_"
)
if
spec_repr
:
ret
+=
f
"_
{
spec_repr
}
"
return
ret
attrs
[
"repr"
]
=
new_repr
if
do_not_specialize
:
attrs
[
"do_not_specialize"
]
=
do_not_specialize
ret
=
define_kernel
(
new_src
,
module
,
attrs
,
**
globals
)
return
ret
vllm/kvprune_legacy_save/triton_kernels/swiglu.py
0 → 100644
View file @
d29c39ca
from
dataclasses
import
dataclass
from
vllm.kvprune.triton_kernels.numerics
import
InFlexData
,
OutFlexData
import
torch
import
triton
from
.swiglu_details._swiglu
import
_swiglu
,
_swiglu_fn
from
vllm.kvprune.triton_kernels
import
target_info
@
dataclass
(
frozen
=
True
)
class
FlexCtx
:
out_data
:
OutFlexData
=
OutFlexData
()
inp_data
:
InFlexData
=
InFlexData
()
saturate_inf
:
bool
=
False
@
dataclass
(
frozen
=
True
)
class
PrecisionConfig
:
limit
:
float
flex_ctx
:
FlexCtx
=
FlexCtx
()
swiglu_fn
=
_swiglu_fn
class
SwiGLU
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
a
,
alpha
,
precision_config
,
routing_data
):
N
=
a
.
shape
[
-
1
]
M
=
a
.
numel
()
//
N
assert
a
.
stride
()[
-
1
]
==
1
assert
a
.
shape
[
-
1
]
%
2
==
0
out
=
torch
.
empty
(
size
=
(
M
,
N
//
2
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
flex_ctx
=
precision_config
.
flex_ctx
# optimization hyperparameters
BLOCK_M
,
BLOCK_N
=
32
//
a
.
itemsize
,
128
num_warps
=
4
kwargs
=
{
"maxnreg"
:
64
}
if
not
target_info
.
is_hip
()
else
{}
# launch semi-persistent kernel
N_BLOCKS
=
triton
.
cdiv
(
N
//
2
,
BLOCK_N
)
num_sms
=
target_info
.
num_sms
()
if
routing_data
is
not
None
:
waves_per_sm
=
32
if
target_info
.
is_hip
()
else
128
num_pid
=
num_sms
*
(
waves_per_sm
//
num_warps
)
M_BLOCKS
=
max
(
1
,
triton
.
cdiv
(
num_pid
,
N_BLOCKS
))
grid
=
(
min
(
M_BLOCKS
*
N_BLOCKS
,
4
*
num_sms
),)
else
:
M_BLOCKS
=
triton
.
cdiv
(
M
,
BLOCK_M
)
if
M_BLOCKS
*
N_BLOCKS
>=
8
*
num_sms
:
grid
=
(
8
*
num_sms
,)
else
:
grid
=
(
min
(
M_BLOCKS
*
N_BLOCKS
,
4
*
num_sms
),)
n_tokens
=
None
if
routing_data
is
not
None
:
n_tokens
=
routing_data
.
expt_data
.
token_offs_raw
[
routing_data
.
n_expts_tot
]
_swiglu
[
grid
](
flex_ctx
.
out_data
.
reinterpret
(
out
),
flex_ctx
.
out_data
.
expected_scale
,
flex_ctx
.
out_data
.
actual_scale
,
flex_ctx
.
out_data
.
checksum_scale
,
flex_ctx
.
inp_data
.
reinterpret
(
a
),
flex_ctx
.
inp_data
.
scale
,
alpha
,
M
,
N
//
2
,
a
.
shape
[
-
1
],
1
,
out
.
shape
[
-
1
],
1
,
precision_config
.
limit
,
n_tokens
,
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
EVEN_N
=
(
N
//
2
)
%
BLOCK_N
==
0
,
M_BLOCKS
=
M_BLOCKS
,
N_BLOCKS
=
N_BLOCKS
,
flexpoint_saturate_inf
=
flex_ctx
.
saturate_inf
,
num_warps
=
num_warps
,
**
kwargs
,
)
out
=
out
.
view
(
a
.
shape
[:
-
1
]
+
out
.
shape
[
-
1
:])
return
out
def
swiglu
(
a
,
alpha
,
precision_config
,
routing_data
=
None
):
return
SwiGLU
.
apply
(
a
,
alpha
,
precision_config
,
routing_data
)
def
swiglu_torch
(
a
,
alpha
,
precision_config
):
limit
=
precision_config
.
limit
a_gelu
=
a
[...,
::
2
]
if
limit
is
not
None
:
a_gelu
=
a_gelu
.
clamp
(
max
=
limit
)
a_linear
=
a
[...,
1
::
2
]
if
limit
is
not
None
:
a_linear
=
a_linear
.
clamp
(
min
=-
limit
,
max
=
limit
)
out_gelu
=
a_gelu
*
torch
.
sigmoid
(
alpha
*
a_gelu
)
out
=
out_gelu
*
(
a_linear
+
1
)
return
out
vllm/kvprune_legacy_save/triton_kernels/swiglu_details/__init__.py
0 → 100644
View file @
d29c39ca
vllm/kvprune_legacy_save/triton_kernels/swiglu_details/_swiglu.py
0 → 100644
View file @
d29c39ca
from
vllm.kvprune.triton_kernels.numerics_details.flexpoint
import
(
load_scale
,
float_to_flex
,
update_scale
,
)
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
clip
(
x
,
limit
,
clip_lower
:
tl
.
constexpr
):
res
=
tl
.
minimum
(
x
,
limit
)
if
clip_lower
:
res
=
tl
.
maximum
(
-
limit
,
res
)
return
res
@
triton
.
jit
def
thread_local_absmax
(
x
,
BLOCK_SIZE
:
tl
.
constexpr
,
NUM_THREADS
:
tl
.
constexpr
):
return
tl
.
max
(
tl
.
reshape
(
tl
.
abs
(
x
),
[
NUM_THREADS
,
BLOCK_SIZE
//
NUM_THREADS
],
can_reorder
=
True
),
axis
=
1
,
)
def
swiglu_repr
(
specialization
):
signature
=
specialization
.
signature
constants
=
specialization
.
constants
convert_dtype
=
lambda
dtype
:
"mxfp4"
if
"u8"
in
dtype
else
dtype
dtypes
=
"x"
.
join
([
convert_dtype
(
f
"
{
signature
[
i
][
1
:]
}
"
)
for
i
in
[
"Out"
,
"A"
]])
blocks
=
"x"
.
join
([
f
"
{
constants
[
i
]
}
"
for
i
in
[
"BLOCK_M"
,
"BLOCK_N"
]])
return
f
"_swiglu_
{
dtypes
}
_
{
blocks
}
"
def
swiglu_launch_metadata
(
grid
,
kernel
,
args
):
M
,
N
=
args
[
"M"
],
args
[
"N"
]
ret
=
dict
()
ret
[
"name"
]
=
f
"
{
kernel
.
name
}
[M =
{
M
}
, N =
{
N
}
]"
A
,
Out
=
args
[
"A"
],
args
[
"Out"
]
ret
[
"bytes"
]
=
Out
.
numel
()
*
Out
.
element_size
()
+
A
.
numel
()
*
A
.
element_size
()
return
ret
@
triton
.
jit
def
compute_swiglu
(
gelu
,
linear
,
scale
,
alpha
,
limit
):
gelu
=
gelu
.
to
(
tl
.
float32
)
*
scale
if
limit
is
not
None
:
gelu
=
clip
(
gelu
,
limit
,
clip_lower
=
False
)
linear
=
linear
.
to
(
tl
.
float32
)
*
scale
if
limit
is
not
None
:
linear
=
clip
(
linear
,
limit
,
clip_lower
=
True
)
s
=
gelu
/
(
1
+
tl
.
exp
(
-
alpha
*
gelu
))
return
tl
.
fma
(
s
,
linear
,
s
)
# (s * (linear + 1))
@
triton
.
jit
(
repr
=
lambda
_
:
"_swiglu"
)
def
_swiglu_fn
(
input
,
alpha
,
limit
):
gelu
,
linear
=
tl
.
split
(
tl
.
reshape
(
input
,
(
input
.
shape
[
0
],
input
.
shape
[
1
]
//
2
,
2
)))
return
compute_swiglu
(
gelu
,
linear
,
1.0
,
alpha
,
limit
)
@
triton
.
jit
(
repr
=
swiglu_repr
,
launch_metadata
=
swiglu_launch_metadata
)
def
_swiglu
(
Out
,
OutExpectedScale
,
OutActualScale
,
OutChecksumScale
,
A
,
AScale
,
alpha
,
M
,
N
,
stride_am
,
stride_an
,
stride_outm
,
stride_outn
,
limit
:
tl
.
constexpr
,
NTokens
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
EVEN_N
:
tl
.
constexpr
,
M_BLOCKS
,
N_BLOCKS
,
flexpoint_saturate_inf
:
tl
.
constexpr
,
):
if
NTokens
is
not
None
:
M
=
tl
.
load
(
NTokens
)
M_BLOCKS
=
(
M
+
BLOCK_M
-
1
)
//
BLOCK_M
local_max
=
tl
.
full
([
tl
.
extra
.
cuda
.
num_threads
()],
0.0
,
tl
.
float32
)
a_scale
=
load_scale
(
AScale
)
out_expected_scale
=
load_scale
(
OutExpectedScale
)
for
pid
in
tl
.
range
(
tl
.
program_id
(
0
),
M_BLOCKS
*
N_BLOCKS
,
tl
.
num_programs
(
0
),
num_stages
=
2
):
pid_m
=
pid
//
N_BLOCKS
pid_n
=
pid
%
N_BLOCKS
off_m
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
off_n
=
pid_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
mask_m
=
off_m
<
M
mask_n
=
off_n
<
N
packed_off_n
=
pid_n
*
BLOCK_N
+
tl
.
arange
(
0
,
2
*
BLOCK_N
)
//
2
packed_mask_n
=
packed_off_n
<
N
packed_mask_n
=
tl
.
max_constancy
(
packed_mask_n
,
[
16
])
# load a
packed_off_n
=
pid_n
*
2
*
BLOCK_N
+
tl
.
arange
(
0
,
2
*
BLOCK_N
)
packed_offs
=
off_m
[:,
None
]
*
stride_am
+
packed_off_n
[
None
,
:]
*
stride_an
if
EVEN_N
:
a_packed
=
tl
.
load
(
A
+
packed_offs
,
mask
=
mask_m
[:,
None
],
other
=
0.0
)
else
:
if
pid_n
*
BLOCK_N
+
BLOCK_N
<=
N
:
a_packed
=
tl
.
load
(
A
+
packed_offs
,
mask
=
mask_m
[:,
None
],
other
=
0.0
)
else
:
packed_mask
=
mask_m
[:,
None
]
&
packed_mask_n
[
None
,
:]
a_packed
=
tl
.
load
(
A
+
packed_offs
,
mask
=
packed_mask
,
other
=
0.0
)
a_gelu
,
a_linear
=
tl
.
split
(
tl
.
reshape
(
a_packed
,
(
BLOCK_M
,
BLOCK_N
,
2
)))
out
=
compute_swiglu
(
a_gelu
,
a_linear
,
a_scale
,
alpha
,
limit
)
# update flexpoint stats and divide by scale
# we don't need masking because of the `other` when loading `A`
if
OutActualScale
is
not
None
:
absmax
=
thread_local_absmax
(
out
,
out
.
numel
,
tl
.
extra
.
cuda
.
num_threads
())
local_max
=
tl
.
maximum
(
local_max
,
absmax
)
out
=
float_to_flex
(
out
,
out_expected_scale
,
None
,
# ActualScale: local absmax is tracked and updated after the loop
OutChecksumScale
,
None
,
Out
,
flexpoint_saturate_inf
,
)
mask
=
mask_m
[:,
None
]
if
EVEN_N
else
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
tl
.
store
(
Out
+
off_m
[:,
None
]
*
stride_outm
+
off_n
[
None
,
:]
*
stride_outn
,
out
,
mask
)
update_scale
(
local_max
,
OutActualScale
,
Out
)
vllm/kvprune_legacy_save/triton_kernels/target_info.py
0 → 100644
View file @
d29c39ca
import
torch
import
triton
# ``constexpr_function`` moved across Triton versions; ROCm/vendor wheels often
# only expose ``triton.constexpr_function`` (not ``triton.runtime.jit``).
def
_resolve_constexpr_function
():
fn
=
getattr
(
triton
,
"constexpr_function"
,
None
)
if
fn
is
not
None
:
return
fn
try
:
from
triton.runtime.jit
import
constexpr_function
as
_fn
return
_fn
except
ImportError
:
pass
_jit
=
getattr
(
triton
,
"jit"
,
None
)
if
_jit
is
not
None
:
fn
=
getattr
(
_jit
,
"constexpr_function"
,
None
)
if
fn
is
not
None
:
return
fn
raise
ImportError
(
"Cannot resolve Triton constexpr_function (try: pip install -U triton)"
)
constexpr_function
=
_resolve_constexpr_function
()
__all__
=
[
"cuda_capability_geq"
,
"get_cdna_version"
,
"has_tma_gather"
,
"has_native_mxfp"
,
"is_cuda"
,
"is_hip"
,
"is_hip_cdna3"
,
"is_hip_cdna4"
,
"num_sms"
,
]
try
:
from
triton.language.target_info
import
(
cuda_capability_geq
,
current_target
,
is_cuda
,
is_hip
,
is_hip_cdna3
,
is_hip_cdna4
,
)
except
ImportError
:
# Some ROCm / vendor Triton wheels omit ``triton.language.target_info``.
# Mirror upstream Triton (see triton/language/target_info.py) via runtime.
from
triton.runtime
import
driver
def
current_target
():
try
:
active_driver
=
driver
.
active
except
RuntimeError
:
return
None
return
active_driver
.
get_current_target
()
@
constexpr_function
def
is_cuda
():
target
=
current_target
()
return
target
is
not
None
and
target
.
backend
==
"cuda"
@
constexpr_function
def
is_hip
():
target
=
current_target
()
return
target
is
not
None
and
target
.
backend
==
"hip"
@
constexpr_function
def
cuda_capability_geq
(
major
,
minor
=
0
):
target
=
current_target
()
if
target
is
None
or
target
.
backend
!=
"cuda"
:
return
False
assert
isinstance
(
target
.
arch
,
int
)
return
target
.
arch
>=
major
*
10
+
minor
@
constexpr_function
def
is_hip_cdna3
():
target
=
current_target
()
return
target
is
not
None
and
target
.
arch
==
"gfx942"
@
constexpr_function
def
is_hip_cdna4
():
target
=
current_target
()
return
target
is
not
None
and
target
.
arch
==
"gfx950"
@
constexpr_function
def
get_cdna_version
():
"""
AMD CDNA generation: 3 (gfx942) or 4 (gfx950); -1 if unknown / non-HIP.
"""
target
=
current_target
()
if
target
is
None
or
target
.
backend
!=
"hip"
:
return
-
1
if
target
.
arch
==
"gfx942"
:
return
3
if
target
.
arch
==
"gfx950"
:
return
4
return
-
1
@
constexpr_function
def
has_tma_gather
():
return
cuda_capability_geq
(
10
,
0
)
@
constexpr_function
def
has_native_mxfp
():
return
cuda_capability_geq
(
10
,
0
)
def
num_sms
():
return
torch
.
cuda
.
get_device_properties
(
0
).
multi_processor_count
Prev
1
…
7
8
9
10
11
12
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment