Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d29c39ca
"vscode:/vscode.git/clone" did not exist on "5867819eaffe3c939c0920c15d5048cb7f9129f8"
Commit
d29c39ca
authored
Apr 30, 2026
by
chenzk
Browse files
vllm kvprune wo:v1.1.0
parent
f81ce56b
Changes
246
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2075 additions
and
0 deletions
+2075
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
...ernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
+158
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
...ernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
+125
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/proton_opts.py
...tor-vllm/src/compactor_vllm/triton_kernels/proton_opts.py
+19
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/__init__.py
...mpactor_vllm/triton_kernels/reduction_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/reduce_bitmatrix.py
...vllm/triton_kernels/reduction_details/reduce_bitmatrix.py
+133
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing.py
...mpactor-vllm/src/compactor_vllm/triton_kernels/routing.py
+521
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/__init__.py
...compactor_vllm/triton_kernels/routing_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_expt_data.py
...mpactor_vllm/triton_kernels/routing_details/_expt_data.py
+75
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_routing_compute.py
...r_vllm/triton_kernels/routing_details/_routing_compute.py
+241
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/specialize.py
...ctor-vllm/src/compactor_vllm/triton_kernels/specialize.py
+143
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu.py
...ompactor-vllm/src/compactor_vllm/triton_kernels/swiglu.py
+99
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/__init__.py
.../compactor_vllm/triton_kernels/swiglu_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/_swiglu.py
...c/compactor_vllm/triton_kernels/swiglu_details/_swiglu.py
+141
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/target_info.py
...tor-vllm/src/compactor_vllm/triton_kernels/target_info.py
+54
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor.py
...ompactor-vllm/src/compactor_vllm/triton_kernels/tensor.py
+227
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/__init__.py
.../compactor_vllm/triton_kernels/tensor_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout.py
...rc/compactor_vllm/triton_kernels/tensor_details/layout.py
+40
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/__init__.py
.../triton_kernels/tensor_details/layout_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/base.py
...vllm/triton_kernels/tensor_details/layout_details/base.py
+18
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_scale.py
..._kernels/tensor_details/layout_details/blackwell_scale.py
+81
-0
No files found.
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
0 → 100644
View file @
d29c39ca
import
triton
import
triton.language
as
tl
# fmt: off
MXFP_BLOCK_SIZE
=
tl
.
constexpr
(
32
)
@
triton
.
jit
def
_get_max_quant_val
(
dtype
:
tl
.
constexpr
):
if
dtype
==
tl
.
uint8
:
return
6.0
elif
dtype
==
tl
.
float8e5
:
return
57344.0
elif
dtype
==
tl
.
float8e4nv
:
return
448.0
else
:
tl
.
static_assert
(
False
,
f
"Invalid
{
dtype
=
}
"
)
@
triton
.
jit
def
_compute_quant_and_scale
(
src_tensor
,
valid_src_mask
,
mx_tensor_dtype
:
tl
.
constexpr
,
DEQUANT_SCALE_ROUNDING_MODE
:
tl
.
constexpr
=
0
):
is_fp8
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
BLOCK_SIZE_OUT_DIM
:
tl
.
constexpr
=
src_tensor
.
shape
[
0
]
BLOCK_SIZE_QUANT_DIM
:
tl
.
constexpr
=
src_tensor
.
shape
[
1
]
BLOCK_SIZE_QUANT_MX_SCALE
:
tl
.
constexpr
=
src_tensor
.
shape
[
1
]
//
MXFP_BLOCK_SIZE
# Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
f32_tensor
=
src_tensor
.
to
(
tl
.
float32
)
abs_tensor
=
tl
.
abs
(
f32_tensor
)
abs_tensor
=
tl
.
where
(
valid_src_mask
,
abs_tensor
,
-
1.0
)
# Don't consider padding tensors in scale computation
abs_tensor
=
tl
.
reshape
(
abs_tensor
,
[
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
MXFP_BLOCK_SIZE
])
max_val
=
tl
.
max
(
abs_tensor
,
axis
=
2
,
keep_dims
=
True
)
dequant_scale
=
max_val
/
_get_max_quant_val
(
mx_tensor_dtype
)
if
DEQUANT_SCALE_ROUNDING_MODE
==
0
:
# DequantScaleRoundingMode.ROUND_UP
# compute 2 ** ceil(log2(dequant_scale))
# Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
# A corner case: exponent is 0xFF that will overflow but that's already
# NaN so assume we don't care.
dequant_scale_exponent
=
(
dequant_scale
.
to
(
tl
.
uint32
,
bitcast
=
True
)
+
0x007FFFFF
)
&
0x7F800000
else
:
# DequantScaleRoundingMode.ROUND_DOWN
# compute 2 ** floor(log2(dequant_scale))
assert
DEQUANT_SCALE_ROUNDING_MODE
==
1
dequant_scale_exponent
=
dequant_scale
.
to
(
tl
.
uint32
,
bitcast
=
True
)
&
0x7F800000
dequant_scale_rounded
=
dequant_scale_exponent
.
to
(
tl
.
float32
,
bitcast
=
True
)
quant_scale
=
tl
.
where
(
dequant_scale_rounded
==
0
,
0
,
1.0
/
dequant_scale_rounded
)
f32_tensor
=
tl
.
reshape
(
f32_tensor
,
[
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
MXFP_BLOCK_SIZE
])
quant_tensor
=
f32_tensor
*
quant_scale
# Reshape the tensors after scaling
quant_tensor
=
quant_tensor
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_DIM
])
# Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
quant_tensor
=
tl
.
where
(
valid_src_mask
,
quant_tensor
,
0
)
dequant_scale_exponent
=
dequant_scale_exponent
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
])
# First, we simply extract the exponent part of the scales and store the result
dequant_scale_exponent
=
(
dequant_scale_exponent
>>
23
).
to
(
tl
.
uint8
)
# Now we must convert the tensors to the mx format.
if
is_fp8
:
out_tensor
=
quant_tensor
.
to
(
mx_tensor_dtype
)
else
:
quant_tensor
=
quant_tensor
.
to
(
tl
.
uint32
,
bitcast
=
True
)
signs
=
quant_tensor
&
0x80000000
exponents
=
(
quant_tensor
>>
23
)
&
0xFF
mantissas
=
(
quant_tensor
&
0x7FFFFF
)
# 0.25 <= x < 0.75 maps to 0.5, a denormal number
E8_BIAS
=
127
E2_BIAS
=
1
# Move implicit bit 1 at the beginning to mantissa for denormals
adjusted_exponents
=
tl
.
core
.
sub
(
E8_BIAS
,
exponents
+
1
,
sanitize_overflow
=
False
)
mantissas
=
tl
.
where
(
exponents
<
E8_BIAS
,
(
0x400000
|
(
mantissas
>>
1
))
>>
adjusted_exponents
,
mantissas
)
# For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
exponents
=
tl
.
maximum
(
exponents
,
E8_BIAS
-
E2_BIAS
)
-
(
E8_BIAS
-
E2_BIAS
)
# Combine sign, exponent, and mantissa, while saturating
# rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
e2m1_tmp
=
tl
.
minimum
((((
exponents
<<
2
)
|
(
mantissas
>>
21
))
+
1
)
>>
1
,
0x7
)
e2m1_value
=
((
signs
>>
28
)
|
e2m1_tmp
).
to
(
tl
.
uint8
)
e2m1_value
=
tl
.
reshape
(
e2m1_value
,
[
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_DIM
//
2
,
2
])
evens
,
odds
=
tl
.
split
(
e2m1_value
)
out_tensor
=
evens
|
(
odds
<<
4
)
return
out_tensor
,
dequant_scale_exponent
@
triton
.
jit
def
_downcast_to_mxfp
(
mx_tensor_ptr
,
stride_mxt_outer
,
stride_mxt_quant
:
tl
.
constexpr
,
mx_scale_ptr
,
stride_mx_scale_outer
,
stride_mx_scale_quant
,
src_ptr
,
stride_src_outer
,
stride_src_quant
,
outer_dim
,
quant_dim
,
BLOCK_SIZE_OUT_DIM
:
tl
.
constexpr
,
BLOCK_SIZE_QUANT_DIM
:
tl
.
constexpr
,
DEQUANT_SCALE_ROUNDING_MODE
:
tl
.
constexpr
):
tl
.
static_assert
(
stride_mxt_quant
==
1
,
f
"Output stride,
{
stride_mxt_quant
=
}
must be 1."
)
tl
.
static_assert
(
BLOCK_SIZE_QUANT_DIM
%
MXFP_BLOCK_SIZE
==
0
,
f
"
{
BLOCK_SIZE_QUANT_DIM
=
}
must be a multiple of 32"
)
# uint8 signifies two fp4 e2m1 values packed into a single byte
mx_tensor_dtype
:
tl
.
constexpr
=
mx_tensor_ptr
.
dtype
.
element_ty
tl
.
static_assert
(
mx_tensor_dtype
==
tl
.
uint8
or
(
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
),
f
"Invalid
{
mx_tensor_dtype
=
}
. Must be uint8 or float8."
)
src_dtype
:
tl
.
constexpr
=
src_ptr
.
dtype
.
element_ty
tl
.
static_assert
(
mx_scale_ptr
.
dtype
.
element_ty
==
tl
.
uint8
,
f
"
{
mx_scale_ptr
.
dtype
.
element_ty
=
}
must be uint8"
)
tl
.
static_assert
((
src_dtype
==
tl
.
bfloat16
)
or
(
src_dtype
==
tl
.
float16
)
or
(
src_dtype
==
tl
.
float32
),
f
"
{
src_dtype
=
}
must be bfloat16 or float16 or float32"
)
is_fp4
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
uint8
outer_block
=
tl
.
program_id
(
0
).
to
(
tl
.
int64
)
quant_block
=
tl
.
program_id
(
1
).
to
(
tl
.
int64
)
K_DIVISOR
:
tl
.
constexpr
=
2
if
is_fp4
else
1
BLOCK_SIZE_QUANT_MX_SCALE
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
MXFP_BLOCK_SIZE
BLOCK_SIZE_QUANT_MX_TENSOR
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
K_DIVISOR
start_src_quant
=
quant_block
*
BLOCK_SIZE_QUANT_DIM
start_mx_scale_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_SCALE
start_mx_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_TENSOR
start_out
=
outer_block
*
BLOCK_SIZE_OUT_DIM
src_ptr
+=
start_src_quant
*
stride_src_quant
+
start_out
*
stride_src_outer
mx_scale_ptr
+=
start_mx_scale_quant
*
stride_mx_scale_quant
+
start_out
*
stride_mx_scale_outer
mx_tensor_ptr
+=
start_mx_quant
*
stride_mxt_quant
+
start_out
*
stride_mxt_outer
offs_src_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_DIM
)[
None
,
:].
to
(
tl
.
int64
)
offs_mxt_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_TENSOR
)[
None
,
:].
to
(
tl
.
int64
)
offs_scale_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_SCALE
)[
None
,
:].
to
(
tl
.
int64
)
offs_outer
=
tl
.
arange
(
0
,
BLOCK_SIZE_OUT_DIM
)[:,
None
].
to
(
tl
.
int64
)
mask_src_quant
=
start_src_quant
+
offs_src_quant
<
quant_dim
mask_n
=
start_out
+
offs_outer
<
outer_dim
full_mask_src
=
mask_src_quant
&
mask_n
mask_mxt_quant
=
start_mx_quant
+
offs_mxt_quant
<
tl
.
cdiv
(
quant_dim
,
K_DIVISOR
)
full_mask_mxt
=
mask_mxt_quant
&
mask_n
scale_mask_k
=
start_mx_scale_quant
+
offs_scale_quant
<
tl
.
cdiv
(
quant_dim
,
MXFP_BLOCK_SIZE
)
full_scale_mask
=
scale_mask_k
&
mask_n
src_tensor_offsets
=
offs_src_quant
*
stride_src_quant
+
offs_outer
*
stride_src_outer
mx_scale_offsets
=
offs_scale_quant
*
stride_mx_scale_quant
+
offs_outer
*
stride_mx_scale_outer
mx_tensor_offsets
=
offs_mxt_quant
*
stride_mxt_quant
+
offs_outer
*
stride_mxt_outer
src_tensor
=
tl
.
load
(
src_ptr
+
src_tensor_offsets
,
mask
=
full_mask_src
)
out_tensor
,
scale_tensor
=
_compute_quant_and_scale
(
src_tensor
,
full_mask_src
,
mx_tensor_dtype
,
DEQUANT_SCALE_ROUNDING_MODE
)
tl
.
store
(
mx_scale_ptr
+
mx_scale_offsets
,
scale_tensor
,
mask
=
full_scale_mask
)
tl
.
store
(
mx_tensor_ptr
+
mx_tensor_offsets
,
out_tensor
,
mask
=
full_mask_mxt
)
@
triton
.
jit
(
repr
=
lambda
_
:
"_dequantize_mxfp8"
)
def
_quantize_mxfp8_fn
(
input
,
mask
,
pid
=
None
):
return
_compute_quant_and_scale
(
input
,
mask
,
tl
.
float8e4nv
)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
0 → 100644
View file @
d29c39ca
import
triton
import
triton.language
as
tl
from
._downcast_to_mxfp
import
MXFP_BLOCK_SIZE
# fmt: off
@
triton
.
jit
def
_upcast_from_mxfp
(
out_ptr
,
stride_o_outer
,
stride_o_quant
:
tl
.
constexpr
,
mx_scale_ptr
,
stride_scale_outer
,
stride_scale_quant
,
mx_tensor_ptr
,
stride_tensor_outer
,
stride_tensor_quant
:
tl
.
constexpr
,
outer_dim
,
quant_dim
,
BLOCK_SIZE_OUT_DIM
:
tl
.
constexpr
,
BLOCK_SIZE_QUANT_DIM
:
tl
.
constexpr
):
tl
.
static_assert
(
stride_o_quant
==
1
,
"the weight must be contiguous in the k dimension for mx"
)
tl
.
static_assert
(
BLOCK_SIZE_QUANT_DIM
%
MXFP_BLOCK_SIZE
==
0
,
"BLOCK_SIZE_K must be a multiple of 32"
)
# uint8 signifies two fp4 e2m1 values packed into a single byte
mx_tensor_dtype
:
tl
.
constexpr
=
mx_tensor_ptr
.
dtype
.
element_ty
dst_dtype
:
tl
.
constexpr
=
out_ptr
.
dtype
.
element_ty
tl
.
static_assert
(
dst_dtype
==
tl
.
float16
or
dst_dtype
==
tl
.
bfloat16
or
dst_dtype
==
tl
.
float32
)
tl
.
static_assert
(
mx_tensor_dtype
==
tl
.
uint8
or
((
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
)
or
mx_tensor_dtype
==
dst_dtype
),
"mx_tensor_ptr must be uint8 or float8 or dst_dtype"
)
tl
.
static_assert
(
mx_scale_ptr
.
dtype
.
element_ty
==
tl
.
uint8
,
"mx_scale_ptr must be uint8"
)
# Determine if we are dealing with fp8 types.
is_fp4
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
uint8
is_fp8
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
K_DIVISOR
:
tl
.
constexpr
=
2
if
is_fp4
else
1
BLOCK_SIZE_QUANT_MX_SCALE
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
MXFP_BLOCK_SIZE
BLOCK_SIZE_QUANT_MX_TENSOR
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
K_DIVISOR
# Compute starting indices for the quantized (packed) dimension and the outer dimension.
outer_block
=
tl
.
program_id
(
0
).
to
(
tl
.
int64
)
quant_block
=
tl
.
program_id
(
1
).
to
(
tl
.
int64
)
start_mxt_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_TENSOR
start_out_quant
=
quant_block
*
BLOCK_SIZE_QUANT_DIM
start_mx_scale_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_SCALE
start_out
=
outer_block
*
BLOCK_SIZE_OUT_DIM
mx_tensor_ptr
+=
start_mxt_quant
*
stride_tensor_quant
+
start_out
*
stride_tensor_outer
mx_scale_ptr
+=
start_mx_scale_quant
*
stride_scale_quant
+
start_out
*
stride_scale_outer
out_ptr
+=
start_out
*
stride_o_outer
+
start_out_quant
*
stride_o_quant
# Compute offsets and masks.
offs_src_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_TENSOR
)[
None
,
:].
to
(
tl
.
int64
)
offs_out_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_DIM
)[
None
,
:].
to
(
tl
.
int64
)
offs_outer
=
tl
.
arange
(
0
,
BLOCK_SIZE_OUT_DIM
)[:,
None
].
to
(
tl
.
int64
)
offs_scale
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_SCALE
)[
None
,
:].
to
(
tl
.
int64
)
mask_outer
=
start_out
+
offs_outer
<
outer_dim
mask_out_quant
=
start_out_quant
+
offs_out_quant
<
quant_dim
full_mask_out
=
mask_out_quant
&
mask_outer
mask_src_quant
=
start_mxt_quant
+
offs_src_quant
<
tl
.
cdiv
(
quant_dim
,
K_DIVISOR
)
full_mask_src
=
mask_src_quant
&
mask_outer
mask_scale
=
start_mx_scale_quant
+
offs_scale
<
tl
.
cdiv
(
quant_dim
,
MXFP_BLOCK_SIZE
)
full_scale_mask
=
mask_scale
&
mask_outer
tensor_offsets
=
offs_src_quant
*
stride_tensor_quant
+
offs_outer
*
stride_tensor_outer
scale_offsets
=
offs_scale
*
stride_scale_quant
+
offs_outer
*
stride_scale_outer
out_offsets
=
offs_out_quant
*
stride_o_quant
+
offs_outer
*
stride_o_outer
# Load the packed tensor and scale.
tensor
=
tl
.
load
(
mx_tensor_ptr
+
tensor_offsets
,
mask
=
full_mask_src
)
scale
=
tl
.
load
(
mx_scale_ptr
+
scale_offsets
,
mask
=
full_scale_mask
)
# Upcast the scale to the destination type.
if
dst_dtype
==
tl
.
bfloat16
:
dst_scale
=
(
scale
.
to
(
tl
.
uint16
)
<<
7
).
to
(
dst_dtype
,
bitcast
=
True
)
else
:
dst_scale
=
(
scale
.
to
(
tl
.
uint32
)
<<
23
).
to
(
tl
.
float32
,
bitcast
=
True
)
if
dst_dtype
==
tl
.
float16
:
dst_scale
=
dst_scale
.
to
(
tl
.
float16
)
# Now upcast the tensor.
intermediate_dtype
:
tl
.
constexpr
=
tl
.
bfloat16
if
dst_dtype
==
tl
.
float32
else
dst_dtype
if
is_fp8
:
dst_tensor
=
tensor
.
to
(
intermediate_dtype
)
if
tensor
.
dtype
==
tl
.
float8e5
:
from_e_bits
:
tl
.
constexpr
=
5
from_m_bits
:
tl
.
constexpr
=
2
to_e_bits
:
tl
.
constexpr
=
8
if
intermediate_dtype
==
tl
.
bfloat16
else
5
to_m_bits
:
tl
.
constexpr
=
7
if
intermediate_dtype
==
tl
.
bfloat16
else
10
# Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
non_finite_mask_src
:
tl
.
constexpr
=
((
1
<<
from_e_bits
)
-
1
)
<<
from_m_bits
non_finite_mask_dst
:
tl
.
constexpr
=
((
1
<<
to_e_bits
)
-
1
)
<<
to_m_bits
dst_tensor
=
tl
.
where
(
(
tensor
.
to
(
tl
.
uint8
,
bitcast
=
True
)
&
non_finite_mask_src
)
==
non_finite_mask_src
,
(
dst_tensor
.
to
(
tl
.
uint16
,
bitcast
=
True
)
|
non_finite_mask_dst
).
to
(
intermediate_dtype
,
bitcast
=
True
),
dst_tensor
,
)
else
:
assert
is_fp4
dst_bias
:
tl
.
constexpr
=
127
if
intermediate_dtype
==
tl
.
bfloat16
else
15
dst_0p5
:
tl
.
constexpr
=
16128
if
intermediate_dtype
==
tl
.
bfloat16
else
0x3800
dst_m_bits
:
tl
.
constexpr
=
7
if
intermediate_dtype
==
tl
.
bfloat16
else
10
# e2m1
em0
=
tensor
&
0x07
em1
=
tensor
&
0x70
x0
=
(
em0
.
to
(
tl
.
uint16
)
<<
(
dst_m_bits
-
1
))
|
((
tensor
&
0x08
).
to
(
tl
.
uint16
)
<<
12
)
x1
=
(
em1
.
to
(
tl
.
uint16
)
<<
(
dst_m_bits
-
5
))
|
((
tensor
&
0x80
).
to
(
tl
.
uint16
)
<<
8
)
# Three cases:
# 1) x is normal and non-zero: Correct bias
x0
=
tl
.
where
((
em0
&
0x06
)
!=
0
,
x0
+
((
dst_bias
-
1
)
<<
dst_m_bits
),
x0
)
x1
=
tl
.
where
((
em1
&
0x60
)
!=
0
,
x1
+
((
dst_bias
-
1
)
<<
dst_m_bits
),
x1
)
# 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
x0
=
tl
.
where
(
em0
==
0x01
,
dst_0p5
|
(
x0
&
0x8000
),
x0
)
x1
=
tl
.
where
(
em1
==
0x10
,
dst_0p5
|
(
x1
&
0x8000
),
x1
)
# 3) x is zero, do nothing
dst_tensor
=
tl
.
interleave
(
x0
,
x1
).
to
(
intermediate_dtype
,
bitcast
=
True
)
dst_tensor
=
dst_tensor
.
to
(
dst_dtype
)
# Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping.
dst_tensor
=
dst_tensor
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
MXFP_BLOCK_SIZE
])
dst_scale
=
dst_scale
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
1
])
scale
=
scale
.
reshape
(
dst_scale
.
shape
)
out_tensor
=
dst_tensor
*
dst_scale
# Correct any NaNs encoded via the scale.
out_tensor
=
tl
.
where
(
scale
==
0xFF
,
float
(
"nan"
),
out_tensor
)
out_tensor
=
out_tensor
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_DIM
])
tl
.
store
(
out_ptr
+
out_offsets
,
out_tensor
,
mask
=
full_mask_out
)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/proton_opts.py
0 → 100644
View file @
d29c39ca
# proton options
import
os
_launch_metadata_allow_sync
=
None
def
launch_metadata_allow_sync
():
global
_launch_metadata_allow_sync
if
_launch_metadata_allow_sync
is
None
:
_launch_metadata_allow_sync
=
not
(
os
.
getenv
(
"PROTON_LAUNCH_METADATA_NOSYNC"
)
==
"1"
)
return
_launch_metadata_allow_sync
def
set_launch_metadata_allow_sync
(
allow_sync
:
bool
):
global
_launch_metadata_allow_sync
_launch_metadata_allow_sync
=
allow_sync
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/__init__.py
0 → 100644
View file @
d29c39ca
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/reduce_bitmatrix.py
0 → 100644
View file @
d29c39ca
import
torch
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
vpopc
(
x
):
"""
Vertical popcount
Input x : uint32[..., N]
Output y : uint32[..., 32]
semantics : y[..., i] = sum_j((x[..., j] >> i) & 1)
credits: @apgoucher
"""
tl
.
static_assert
(
x
.
dtype
==
tl
.
uint32
,
"x should consist of 32-bit unsigned integers"
)
BLOCK_N
:
tl
.
constexpr
=
x
.
shape
[
-
1
]
# summation axis
BATCHES
:
tl
.
constexpr
=
x
.
numel
//
BLOCK_N
# number of batches
if
BLOCK_N
>=
8
:
sa1
:
tl
.
constexpr
=
8
else
:
sa1
:
tl
.
constexpr
=
BLOCK_N
# create 8-way sums in 4-bit fields:
y
=
tl
.
reshape
(
x
,
[
BATCHES
,
BLOCK_N
//
sa1
,
sa1
,
1
])
y
=
(
y
>>
tl
.
arange
(
0
,
4
)[
None
,
None
,
None
,
:])
&
0x11111111
y
=
tl
.
sum
(
y
,
2
)
# [BATCHES, BLOCK_N // sa1, 4]
if
BLOCK_N
>=
128
:
sa2
:
tl
.
constexpr
=
16
else
:
sa2
:
tl
.
constexpr
=
BLOCK_N
//
sa1
# create 128-way sums in 8-bit fields:
y
=
tl
.
reshape
(
y
,
[
BATCHES
,
BLOCK_N
//
(
sa1
*
sa2
),
sa2
,
1
,
4
])
y
=
(
y
>>
(
4
*
tl
.
arange
(
0
,
2
))[
None
,
None
,
None
,
:,
None
])
&
0x0F0F0F0F
y
=
tl
.
sum
(
y
,
2
)
# [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4]
sa3
:
tl
.
constexpr
=
BLOCK_N
//
(
sa1
*
sa2
)
# create N-way sums in 32-bit fields:
y
=
tl
.
reshape
(
y
,
[
BATCHES
,
1
,
sa3
,
8
])
y
=
(
y
>>
(
8
*
tl
.
arange
(
0
,
4
))[
None
,
:,
None
,
None
])
&
0x000000FF
y
=
tl
.
sum
(
y
,
2
)
# [BATCHES, 4, 8]
y
=
tl
.
reshape
(
y
,
x
.
shape
[:
-
1
]
+
[
32
])
return
y
@
triton
.
jit
def
_sum_bitmatrix_memset
(
Ret
,
BLOCK
:
tl
.
constexpr
):
pid
=
tl
.
program_id
(
0
)
offs
=
pid
*
BLOCK
+
tl
.
arange
(
0
,
BLOCK
)
tl
.
store
(
Ret
+
offs
,
0
)
@
triton
.
jit
def
_sum_bitmatrix_rows
(
B
,
shape_bm
,
stride_bm
:
tl
.
constexpr
,
stride_bn
:
tl
.
constexpr
,
# input bitmatrix
Ret
,
Partials
,
stride_pm
:
tl
.
constexpr
,
stride_pn
,
shape_pn
,
# outputs
BLOCK_MM
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
):
tl
.
static_assert
(
BLOCK_MM
%
BLOCK_M
==
0
)
TILE_SIZE
:
tl
.
constexpr
=
BLOCK_MM
//
BLOCK_M
if
isinstance
(
shape_bm
,
tl
.
tensor
)
and
shape_bm
.
dtype
.
is_ptr
():
shape_bm
=
tl
.
load
(
shape_bm
)
pid_m
=
tl
.
program_id
(
0
)
pid_n
=
tl
.
program_id
(
1
)
offs_m
=
pid_m
*
BLOCK_MM
+
tl
.
arange
(
0
,
BLOCK_MM
)
offs_n
=
pid_n
*
32
+
tl
.
arange
(
0
,
32
)
n_rows
=
shape_bm
bits
=
tl
.
load
(
B
+
pid_n
*
stride_bn
+
offs_m
*
stride_bm
,
mask
=
offs_m
<
n_rows
,
other
=
0
)
bits
=
tl
.
reshape
(
bits
,
[
TILE_SIZE
,
BLOCK_M
])
ret
=
vpopc
(
bits
)
# [TILE_SIZE, 32]
offs_t
=
pid_m
*
TILE_SIZE
+
tl
.
arange
(
0
,
TILE_SIZE
)
tl
.
atomic_add
(
Ret
+
offs_n
,
tl
.
sum
(
ret
,
0
),
sem
=
"relaxed"
)
tl
.
store
(
Partials
+
offs_t
[:,
None
]
*
stride_pm
+
offs_n
[
None
,
:]
*
stride_pn
,
ret
)
def
clear_sums
(
n_cols
,
device
,
MEMSET_BLOCK
=
512
):
cdiv
=
triton
.
cdiv
blocks
=
cdiv
(
n_cols
,
MEMSET_BLOCK
)
out_ret
=
torch
.
empty
((
blocks
*
MEMSET_BLOCK
,),
device
=
device
,
dtype
=
torch
.
int32
)
_sum_bitmatrix_memset
[(
blocks
,)](
out_ret
,
MEMSET_BLOCK
)
return
out_ret
def
sum_bitmatrix_rows
(
x
,
out_ret
,
partials_block_size
=
None
):
assert
partials_block_size
is
not
None
cdiv
=
triton
.
cdiv
PARTIALS_BLOCK_M
=
partials_block_size
n_rows
,
n_cols
=
x
.
shape
n_rows_max
=
x
.
shape_max
[
0
]
assert
out_ret
.
shape
==
(
n_cols
,)
TILE_SIZE
=
max
(
1
,
128
//
PARTIALS_BLOCK_M
)
BLOCK_MM
=
PARTIALS_BLOCK_M
*
TILE_SIZE
pids_x
=
cdiv
(
n_rows_max
,
BLOCK_MM
)
pids_y
=
cdiv
(
n_cols
,
32
)
out_partials
=
torch
.
empty
(
(
pids_y
*
32
,
pids_x
*
TILE_SIZE
),
device
=
out_ret
.
device
,
dtype
=
torch
.
int32
)
out_partials
=
torch
.
transpose
(
out_partials
,
0
,
1
)
# output tensors
_sum_bitmatrix_rows
[(
pids_x
,
pids_y
)](
x
.
storage
.
data
,
n_rows
,
x
.
stride
(
0
),
x
.
stride
(
1
),
# input
out_ret
,
# output [final reduction]
out_partials
,
out_partials
.
stride
(
0
),
out_partials
.
stride
(
1
),
out_partials
.
shape
[
1
],
# output [partial reductions]
BLOCK_M
=
PARTIALS_BLOCK_M
,
BLOCK_MM
=
BLOCK_MM
,
# constants
num_warps
=
8
,
)
out_partials
=
out_partials
[:
cdiv
(
n_rows_max
,
PARTIALS_BLOCK_M
),
:]
return
out_ret
,
out_partials
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing.py
0 → 100644
View file @
d29c39ca
import
torch
import
triton
from
dataclasses
import
dataclass
,
field
from
.routing_details._routing_compute
import
_combined_routing_compute
from
.routing_details._routing_compute
import
_combined_routing_memset
from
.routing_details._routing_compute
import
_routing_clear_bitmatrix
from
.routing_details._expt_data
import
_expt_data_memset
from
.routing_details._expt_data
import
_expt_data_compute
from
.target_info
import
is_hip
@
dataclass
class
GatherIndx
:
"""
Indices for an operation that performs:
Y = X[src_idx, :]
"""
# array such that `dst_idx[src_idx] = arange(0, N)`
src_indx
:
torch
.
Tensor
dst_indx
:
torch
.
Tensor
@
dataclass
class
ScatterIndx
:
"""
Indices for an operation that performs:
Y[dst_idx, :] = X
"""
# array such that `dst_idx[src_idx] = arange(0, N)`
src_indx
:
torch
.
Tensor
dst_indx
:
torch
.
Tensor
@
dataclass
class
ExptData
:
# hist[i] is the number of tokens routed to expert i
hist
:
torch
.
Tensor
# token_offs_raw[i] is the offset of the first token routed
# to expert i in an expert-sorted array
token_offs_raw
:
torch
.
Tensor
# token_offs_pad[block][i] is the offset of the first token routed
# to expert i in an expert-sorted array, assuming histogram
# rounded to the next multiple of `block`
token_offs_pad
:
dict
[
int
,
torch
.
Tensor
]
# block_id_map[block] contain one value for each `pid`` launched by
# the matrix multiplication kernel launched with BLOCK_M=block:
# - the value is -1 if the `pid` has no work to do
# - otherwise, the value is two int16 (packed as an int32) that
# correspond respectively to (1) the expert assigned to
# the tokens processed by this pid; (2) the block assigned to the
# tokens processed by this pid (think `pid_m` in a regular matmul)
# see `test_routing.py` for a reference implementation and more details
block_pid_map
:
dict
[
int
,
torch
.
Tensor
]
def
__post_init__
(
self
):
if
self
.
hist
is
not
None
:
assert
self
.
hist
.
dtype
==
torch
.
int32
if
self
.
token_offs_raw
is
not
None
:
assert
self
.
token_offs_raw
.
dtype
==
torch
.
int32
if
self
.
token_offs_pad
is
not
None
:
for
v
in
self
.
token_offs_pad
.
values
():
assert
v
.
dtype
==
torch
.
int32
if
self
.
block_pid_map
is
not
None
:
for
v
in
self
.
block_pid_map
.
values
():
assert
v
.
dtype
==
torch
.
int32
@
dataclass
class
RoutingData
:
gate_scal
:
torch
.
Tensor
=
field
()
expt_hist
:
torch
.
Tensor
=
field
()
n_expts_tot
:
int
=
field
()
n_expts_act
:
int
=
field
()
expt_data
:
ExptData
=
None
# Used to make perf annotation cleaner: when we use expert sharding, we can
# use this to tell the "expected" number of local tokens per expert, because
# the actual number can vary per each input.
expected_tokens_per_expt
:
int
=
field
(
default
=
None
)
def
n_blocks
(
self
,
n_rows
,
block_m
):
if
n_rows
<=
self
.
n_expts_tot
:
return
n_rows
else
:
return
(
triton
.
cdiv
(
max
(
n_rows
-
self
.
n_expts_tot
+
1
,
0
),
block_m
)
+
self
.
n_expts_tot
-
1
)
# --------------------------
# sort tokens by expert
# --------------------------
class
SortTokens
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
expt_scal
,
expt_indx
,
n_expts_tot
,
bitmatrix
):
HIST_BLOCK_M
=
32
INDX_OFFS_BLOCK_M
=
512
MEMSET_BLOCK
=
1024
cdiv
=
triton
.
cdiv
device
=
expt_scal
.
device
dtype
=
expt_scal
.
dtype
n_tokens_raw
,
_
=
bitmatrix
.
shape
n_tokens_pad
,
n_expts_act
=
expt_scal
.
shape
n_gates_pad
=
n_tokens_pad
*
n_expts_act
hist
,
partial_hist
=
bitmatrix
.
sum
(
partials_block_size
=
HIST_BLOCK_M
)
hist
=
hist
[:
n_expts_tot
]
assert
hist
.
dtype
==
torch
.
int32
# scratchpad
expt_offs
=
torch
.
empty
(
n_expts_tot
,
dtype
=
torch
.
int32
,
device
=
device
)
combined_indx
=
torch
.
empty
(
n_gates_pad
*
2
,
dtype
=
torch
.
int32
,
device
=
device
)
# output
topk_indx
=
combined_indx
[:
n_gates_pad
]
gate_indx
=
combined_indx
[
n_gates_pad
:]
gate_scal
=
torch
.
empty
(
n_gates_pad
,
dtype
=
dtype
,
device
=
device
)
(
token_offs_combined
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
,
blocks1a
,
blocks2a
,
MEMSET_BLOCK_A
,
HIST2_BLOCK_M
,
block_m_log2_start
,
block_m_num
,
)
=
_compute_expt_data_internal
(
hist
,
n_expts_tot
,
n_gates_pad
)
blocks1b
=
cdiv
(
n_gates_pad
*
2
,
MEMSET_BLOCK
)
+
n_expts_tot
+
1
blocks2b
=
cdiv
(
n_tokens_pad
,
HIST_BLOCK_M
)
_combined_routing_memset
[(
blocks1a
+
blocks1b
,)](
combined_indx
,
n_gates_pad
*
2
,
-
1
,
MEMSET_BLOCK
,
hist
,
#
expt_offs
,
hist
.
shape
[
0
],
n_expts_tot
,
partial_hist
,
# inputs
partial_hist
.
shape
[
0
],
partial_hist
.
stride
(
0
),
partial_hist
.
stride
(
1
),
# outputs
token_offs_combined
,
token_offs_combined
.
stride
(
0
),
#
blocks1a
,
block_pid_map
,
#
block_m_log2_start
,
SIZES
=
block_m_num
,
BLOCK_A
=
MEMSET_BLOCK_A
,
# optimization parameters
BLOCK_N
=
512
,
BLOCK_M
=
INDX_OFFS_BLOCK_M
,
# tunable parameters
)
indx_offs
=
partial_hist
_combined_routing_compute
[(
blocks2a
+
blocks2b
,)](
topk_indx
,
gate_indx
,
gate_scal
,
# outputs
expt_scal
,
expt_indx
,
indx_offs
,
indx_offs
.
stride
(
0
),
indx_offs
.
stride
(
1
),
# inputs
expt_offs
,
n_tokens_raw
,
# input shape
HIST_BLOCK_M
,
n_expts_act
,
# constants
hist
,
token_offs_pad
,
token_offs_pad
.
stride
(
0
),
block_pid_map
,
block_pid_map
.
stride
(
0
),
# outputs
block_m_log2_start
,
block_m_num
,
HIST2_BLOCK_M
,
blocks2a
,
# etc.
)
ctx
.
n_tokens_raw
=
n_tokens_raw
ctx
.
n_tokens_pad
=
n_tokens_pad
ctx
.
n_expts_act
=
n_expts_act
ctx
.
save_for_backward
(
gate_indx
)
return
(
hist
,
topk_indx
,
gate_indx
,
gate_scal
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
,
)
@
staticmethod
def
backward
(
ctx
,
_0
,
_1
,
_2
,
dgate_scal
,
_3
,
_4
,
_5
):
(
gate_indx
,)
=
ctx
.
saved_tensors
dgate_scal
=
dgate_scal
[
gate_indx
]
dgate_scal
=
dgate_scal
.
reshape
(
ctx
.
n_tokens_pad
,
ctx
.
n_expts_act
)
return
dgate_scal
,
None
,
None
,
None
def
sort_tokens
(
expt_scal
,
expt_indx
,
n_expts_tot
,
bitmatrix
):
return
SortTokens
.
apply
(
expt_scal
,
expt_indx
,
n_expts_tot
,
bitmatrix
)
# --------------------------
# prune routing
# --------------------------
class
PruneRouting
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
expt_scal
,
expt_indx
,
bitmatrix
,
n_expts_tot
,
simulated_ep
):
from
.compaction
import
compaction
n_tokens_pad
=
expt_scal
.
shape
[
0
]
assert
n_expts_tot
%
simulated_ep
==
0
_routing_clear_bitmatrix
[(
n_tokens_pad
,)](
bitmatrix
.
storage
.
data
,
bitmatrix
.
storage
.
data
.
stride
(
0
),
bitmatrix
.
storage
.
data
.
stride
(
1
),
bitmatrix
.
storage
.
data
.
shape
[
1
],
n_expts_tot
//
simulated_ep
,
BLOCK_N
=
512
,
)
# perform compaction to update expt_scal / expt_indx
expt_scal
,
expt_indx
=
compaction
(
expt_scal
,
expt_indx
,
bitmatrix
)
n_expts_tot
=
n_expts_tot
//
simulated_ep
bitmatrix
.
shape
[
-
1
]
=
n_expts_tot
return
expt_scal
,
expt_indx
,
bitmatrix
def
prune_routing
(
expt_scal
,
expt_indx
,
bitmatrix
,
n_expts_tot
,
simulated_ep
):
return
PruneRouting
.
apply
(
expt_scal
,
expt_indx
,
bitmatrix
,
n_expts_tot
,
simulated_ep
)
# --------------------------
# expt_data
# --------------------------
def
log2_power_of_two
(
x
):
assert
x
>
0
and
(
x
&
(
x
-
1
))
==
0
,
"x must be a power of two"
return
x
.
bit_length
()
-
1
block_m_log2_start
=
4
def
_compute_expt_data_internal
(
expt_hist
,
n_expts_tot
,
n_gates
):
MEMSET_BLOCK
=
512
HIST2_BLOCK_M
=
512
device
=
expt_hist
.
device
n_expts_tot
=
n_expts_tot
cdiv
=
triton
.
cdiv
# block_ms are all powers-of-two between 16 and 128 (inclusive)
block_m_log2_end
=
9
if
is_hip
()
else
8
block_m_num
=
block_m_log2_end
-
block_m_log2_start
if
n_gates
<=
n_expts_tot
:
max_n_tiles
=
n_gates
else
:
max_n_tiles
=
(
n_expts_tot
-
1
-
((
n_expts_tot
-
n_gates
-
1
)
//
2
**
block_m_log2_start
)
)
# allocate memory
pad
=
lambda
x
:
cdiv
(
x
,
MEMSET_BLOCK
)
*
MEMSET_BLOCK
dtype
=
torch
.
int32
token_offs_combined
=
torch
.
empty
(
(
block_m_num
+
1
,
pad
(
n_expts_tot
+
1
)),
dtype
=
dtype
,
device
=
device
)
token_offs_raw
=
token_offs_combined
[
0
][:
n_expts_tot
+
1
]
token_offs_pad
=
token_offs_combined
[
1
:]
block_pid_map
=
torch
.
empty
(
(
block_m_num
,
pad
(
max_n_tiles
)),
dtype
=
dtype
,
device
=
device
)
memset_grid
=
torch
.
numel
(
block_pid_map
)
//
MEMSET_BLOCK
# exact division
# compute outputs
token_offs_pad
=
token_offs_pad
[:,
:
n_expts_tot
+
1
]
block_pid_map
=
block_pid_map
[:,
:
max_n_tiles
]
blocks1
=
memset_grid
+
block_m_num
+
1
blocks2
=
n_expts_tot
*
block_m_num
return
(
token_offs_combined
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
,
blocks1
,
blocks2
,
MEMSET_BLOCK
,
HIST2_BLOCK_M
,
block_m_log2_start
,
block_m_num
,
)
def
_unpack_into_dict
(
x
):
block_m_log2_end
=
block_m_log2_start
+
x
.
shape
[
0
]
x
=
{
2
**
j
:
x
[
i
,
:]
for
i
,
j
in
enumerate
(
range
(
block_m_log2_start
,
block_m_log2_end
))
}
return
x
def
compute_expt_data
(
expt_hist
,
n_expts_tot
,
n_gates
):
if
expt_hist
is
None
:
return
ExptData
(
None
,
None
,
None
,
None
)
# this just computes the kernel arguments:
(
token_offs_combined
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
,
blocks1
,
blocks2
,
MEMSET_BLOCK
,
HIST2_BLOCK_M
,
block_m_log2_start
,
block_m_num
,
)
=
_compute_expt_data_internal
(
expt_hist
,
n_expts_tot
,
n_gates
)
_expt_data_memset
[(
blocks1
,)](
expt_hist
,
n_expts_tot
,
#
token_offs_combined
,
token_offs_combined
.
stride
(
0
),
#
block_pid_map
,
#
block_m_log2_start
,
SIZES
=
block_m_num
,
BLOCK
=
MEMSET_BLOCK
,
# optimization parameters
num_warps
=
4
,
)
_expt_data_compute
[(
blocks2
,)](
expt_hist
,
token_offs_pad
,
token_offs_pad
.
stride
(
0
),
block_pid_map
,
block_pid_map
.
stride
(
0
),
# outputs
block_m_log2_start
,
SIZES
=
block_m_num
,
BLOCK
=
HIST2_BLOCK_M
,
# optimization parameters
num_warps
=
4
,
)
token_offs_pad
=
_unpack_into_dict
(
token_offs_pad
)
block_pid_map
=
_unpack_into_dict
(
block_pid_map
)
return
ExptData
(
expt_hist
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
)
# --------------------------
# routing
# --------------------------
def
routing_from_bitmatrix
(
bitmatrix
,
expt_scal
,
expt_indx
,
n_expts_tot
,
n_expts_act
):
(
hist
,
topk_indx
,
gate_indx
,
gate_scal
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
,
)
=
sort_tokens
(
expt_scal
,
expt_indx
,
n_expts_tot
,
bitmatrix
)
token_offs_pad
=
_unpack_into_dict
(
token_offs_pad
)
block_pid_map
=
_unpack_into_dict
(
block_pid_map
)
expt_data
=
ExptData
(
hist
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
)
# pack the matmul data structure
gather_indx
=
GatherIndx
(
src_indx
=
topk_indx
,
dst_indx
=
gate_indx
)
scatter_indx
=
ScatterIndx
(
src_indx
=
gate_indx
,
dst_indx
=
topk_indx
)
return
(
RoutingData
(
gate_scal
,
hist
,
n_expts_tot
,
n_expts_act
,
expt_data
),
gather_indx
,
scatter_indx
,
)
def
routing
(
logits
,
n_expts_act
,
sm_first
=
False
,
expt_indx
=
None
,
simulated_ep
=
1
,
n_rows
=
None
):
from
.topk
import
topk
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
expt_scal
,
expt_indx
,
bitmatrix
=
topk
(
logits
,
n_expts_act
,
#
apply_softmax
=
not
sm_first
,
y_indx
=
expt_indx
,
n_rows
=
n_rows
,
)
n_expts_tot
=
logits
.
shape
[
-
1
]
//
simulated_ep
# mutate bitmatrix
if
simulated_ep
>
1
:
expt_scal
,
expt_indx
,
bitmatrix
=
prune_routing
(
expt_scal
,
expt_indx
,
bitmatrix
,
logits
.
shape
[
-
1
],
simulated_ep
)
return
routing_from_bitmatrix
(
bitmatrix
,
expt_scal
,
expt_indx
,
n_expts_tot
,
n_expts_act
)
# --------------------------
# torch reference
# --------------------------
def
compute_expt_data_torch
(
hist
,
n_expts_tot
,
n_gates
):
# offset for each experts
device
=
hist
.
device
token_offs_raw
=
torch
.
cumsum
(
hist
,
dim
=
0
)
token_offs_raw
=
torch
.
cat
((
torch
.
zeros
(
1
,
device
=
device
),
token_offs_raw
))
token_offs_raw
=
token_offs_raw
.
int
()
# maximum number of tiles for all values of `block_m` considered
block_ms
=
[
16
,
32
,
64
,
128
]
if
is_hip
():
block_ms
.
append
(
256
)
if
n_gates
<=
n_expts_tot
:
max_n_tiles
=
n_gates
else
:
# ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
# ceil_div(x, y): -(-x // y)
max_n_tiles
=
n_expts_tot
-
1
-
((
n_expts_tot
-
n_gates
-
1
)
//
min
(
block_ms
))
# fill up tile offset/infos for each block
token_offs_pad
=
dict
()
block_pid_map
=
dict
()
for
block_m
in
block_ms
:
n_tiles
=
(
hist
+
block_m
-
1
)
//
block_m
# matmul blocks needed
token_offs_pad
[
block_m
]
=
torch
.
cumsum
(
n_tiles
,
dim
=
0
)
token_offs_pad
[
block_m
]
=
torch
.
cat
(
(
torch
.
zeros
(
1
,
device
=
device
),
token_offs_pad
[
block_m
])
)
token_offs_pad
[
block_m
]
=
token_offs_pad
[
block_m
].
int
()
# compute data required to drive ragged batch matmul
block_pid_map
[
block_m
]
=
-
torch
.
ones
(
max_n_tiles
,
dtype
=
torch
.
int32
,
device
=
device
)
# for e in range(n_expts_tot):
# offset = token_offs_pad[block_m][e]
# for b in range(n_tiles[e]):
# block_pid_map[block_m][offset + b] = (b << 16) + e
col
=
torch
.
arange
(
max_n_tiles
,
device
=
device
)
map_vals
=
(
torch
.
arange
(
n_expts_tot
,
device
=
device
)[:,
None
]
+
(
col
<<
16
)[
None
,
:]
)
map_idxs
=
token_offs_pad
[
block_m
][:
-
1
,
None
]
+
col
[
None
,
:]
mask
=
col
[
None
,
:]
<
n_tiles
[:,
None
]
block_pid_map
[
block_m
].
index_put_
((
map_idxs
[
mask
],),
map_vals
.
int
()[
mask
])
return
ExptData
(
hist
,
token_offs_raw
,
token_offs_pad
,
block_pid_map
)
def
topk_torch
(
vals
,
k
,
expt_indx
,
has_user_provided_indx
=
False
):
# topk of experts
if
has_user_provided_indx
:
tk_indx
=
expt_indx
else
:
tk_indx
=
torch
.
argsort
(
-
vals
,
dim
=
1
,
stable
=
True
)[:,
:
k
]
tk_indx
=
tk_indx
.
long
()
tk_val
=
torch
.
take_along_dim
(
vals
,
tk_indx
,
dim
=
1
)
tk_indx
=
tk_indx
.
int
()
return
tk_val
,
tk_indx
def
routing_torch
(
logits
,
n_expts_act
,
sm_first
=
False
,
expt_indx
=
None
,
n_rows
=
None
):
has_user_provided_indx
=
expt_indx
is
not
None
n_gates_pad
=
logits
.
shape
[
0
]
*
n_expts_act
if
n_rows
is
not
None
:
logits
=
logits
[:
n_rows
,
:]
_
,
n_expts_tot
=
logits
.
shape
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
expt_scal
,
expt_indx
=
topk_torch
(
logits
,
n_expts_act
,
expt_indx
,
has_user_provided_indx
=
has_user_provided_indx
)
if
not
sm_first
:
expt_scal
=
torch
.
softmax
(
expt_scal
,
dim
=-
1
)
# sort each token's selections by expert
if
not
has_user_provided_indx
:
expt_indx
,
sort_indices
=
torch
.
sort
(
expt_indx
,
dim
=
1
)
expt_scal
=
torch
.
gather
(
expt_scal
,
1
,
sort_indices
)
# flatten topk data
expt_scal
=
expt_scal
.
reshape
(
-
1
)
expt_indx
=
expt_indx
.
reshape
(
-
1
).
to
(
torch
.
int32
)
# sort by expert_id so experts are contiguous for the matmul
topk_indx
=
torch
.
argsort
(
expt_indx
,
stable
=
True
)
gate_indx
=
torch
.
argsort
(
topk_indx
,
stable
=
True
)
gate_scal
=
expt_scal
[
topk_indx
]
hist
=
torch
.
histc
(
expt_indx
,
bins
=
n_expts_tot
,
max
=
n_expts_tot
-
1
).
int
()
# histogram of tokens over experts
# pack the matmul data structure
gather_indx
=
GatherIndx
(
src_indx
=
topk_indx
.
int
(),
dst_indx
=
gate_indx
.
int
())
scatter_indx
=
ScatterIndx
(
src_indx
=
gate_indx
.
int
(),
dst_indx
=
topk_indx
.
int
())
# compute expt_data
expt_data
=
compute_expt_data_torch
(
hist
,
n_expts_tot
,
n_gates_pad
)
return
(
RoutingData
(
gate_scal
,
hist
,
n_expts_tot
,
n_expts_act
,
expt_data
),
gather_indx
,
scatter_indx
,
)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/__init__.py
0 → 100644
View file @
d29c39ca
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_expt_data.py
0 → 100644
View file @
d29c39ca
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_cdiv_pow2
(
n
,
log2_k
):
return
(
n
+
((
1
<<
log2_k
)
-
1
))
>>
log2_k
@
triton
.
jit
def
_expt_data_memset
(
Hist
,
n_expts_tot
,
MDStarts
,
tile_starts_stridem
,
MDTileInfo
,
first_tile_dim_log2
,
SIZES
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
if
pid
<=
SIZES
:
MDStarts
+=
pid
*
tile_starts_stridem
x_tile
=
tl
.
zeros
([
BLOCK
],
dtype
=
MDStarts
.
dtype
.
element_ty
)
Tile_ptrs
=
MDStarts
+
tl
.
arange
(
0
,
BLOCK
)
tile_dim_log2
=
tl
.
where
(
pid
==
0
,
0
,
pid
+
first_tile_dim_log2
-
1
)
for
i
in
range
(
0
,
n_expts_tot
+
1
,
BLOCK
):
offs_n
=
tl
.
arange
(
0
,
BLOCK
)
+
i
mask_n0
=
offs_n
<
n_expts_tot
hist_tok
=
tl
.
load
(
Hist
+
offs_n
,
mask
=
mask_n0
,
other
=
0
)
hist_tile
=
_cdiv_pow2
(
hist_tok
,
tile_dim_log2
)
tile_starts
=
tl
.
cumsum
(
hist_tile
,
0
)
+
x_tile
x_tile
+=
tl
.
sum
(
hist_tile
,
0
).
to
(
MDStarts
.
dtype
.
element_ty
)
tl
.
store
(
Tile_ptrs
,
tile_starts
-
hist_tile
)
Tile_ptrs
+=
BLOCK
else
:
pid
-=
SIZES
+
1
TileInfoOut
=
MDTileInfo
+
pid
*
BLOCK
+
tl
.
arange
(
0
,
BLOCK
)
tl
.
store
(
TileInfoOut
,
0xFFFFFFFF
)
@
triton
.
jit
def
_expt_data_compute
(
Hist
,
MDTileStarts
,
tile_starts_stridem
,
MDTileInfo
,
tile_info_stridem
,
first_tile_dim_log2
,
SIZES
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
expt_id
=
pid
//
SIZES
buff_id
=
pid
%
SIZES
MDTileStarts
+=
buff_id
*
tile_starts_stridem
MDTileInfo
+=
buff_id
*
tile_info_stridem
n_tokens
=
tl
.
load
(
Hist
+
expt_id
)
tile_dim_log2
=
first_tile_dim_log2
+
buff_id
n_blocks
=
_cdiv_pow2
(
n_tokens
,
tile_dim_log2
)
tile_off
=
tl
.
load
(
MDTileStarts
+
expt_id
)
MDTileInfo
+=
tile_off
for
block_off
in
range
(
0
,
n_blocks
,
BLOCK
):
block_offs
=
block_off
+
tl
.
arange
(
0
,
BLOCK
)
data
=
(
block_offs
<<
16
)
+
expt_id
tl
.
store
(
MDTileInfo
+
block_offs
,
data
,
mask
=
block_offs
<
n_blocks
)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_routing_compute.py
0 → 100644
View file @
d29c39ca
import
triton
import
triton.language
as
tl
from
._expt_data
import
_expt_data_compute
,
_expt_data_memset
@
triton
.
jit
def
_routing_compute_expt_offs
(
ExpertHist
,
FinalExpertOffs
,
hist_size
,
# histogram
BLOCK_N
:
tl
.
constexpr
,
):
loop_iterations
=
(
hist_size
+
BLOCK_N
-
1
)
//
BLOCK_N
x
=
tl
.
zeros
([
BLOCK_N
],
ExpertHist
.
dtype
.
element_ty
)
for
i
in
range
(
loop_iterations
):
offs_n
=
i
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
mask_n
=
offs_n
<
hist_size
hist2
=
tl
.
load
(
ExpertHist
+
offs_n
,
mask
=
mask_n
)
tok_starts
=
tl
.
cumsum
(
hist2
,
0
)
-
hist2
+
x
x
+=
tl
.
sum
(
hist2
,
0
)
tl
.
store
(
FinalExpertOffs
+
offs_n
,
tok_starts
,
mask
=
mask_n
)
offs_n
+=
BLOCK_N
@
triton
.
jit
def
_routing_compute_indx_offs
(
PartialHist
,
shape_pm
,
stride_pm
,
stride_pn
,
BLOCK_M
:
tl
.
constexpr
,
expt_id
):
offs_m
=
tl
.
arange
(
0
,
BLOCK_M
)
# iterate over input data
curr_sum
=
0
for
_
in
range
(
0
,
shape_pm
,
BLOCK_M
):
offs
=
offs_m
*
stride_pm
+
expt_id
*
stride_pn
curr
=
tl
.
load
(
PartialHist
+
offs
,
mask
=
offs_m
<
shape_pm
)
out
=
tl
.
cumsum
(
curr
,
0
)
+
curr_sum
curr_sum
+=
tl
.
sum
(
curr
,
0
)
tl
.
store
(
PartialHist
+
offs
,
out
-
curr
,
mask
=
offs_m
<
shape_pm
)
offs_m
+=
BLOCK_M
@
triton
.
jit
def
_keyed_add
(
x
,
y
):
# we keep the key in the upper 16 bits of a uint32:
key_mask
:
tl
.
constexpr
=
0xFFFF0000
kx
=
x
&
key_mask
ky
=
y
&
key_mask
z
=
tl
.
where
(
kx
==
ky
,
x
+
y
-
kx
,
y
)
return
z
@
triton
.
jit
def
_routing_compute_indx
(
pid_m
,
GatherIndx
,
ScatterIndx
,
GateScal
,
ExptScal
,
ExptIndx
,
PartialOffs
,
stride_pm
,
stride_pn
,
TokensStart
,
n_tokens
,
BLOCK_M
:
tl
.
constexpr
,
N_EXPTS_ACT
:
tl
.
constexpr
,
):
if
isinstance
(
n_tokens
,
tl
.
tensor
)
and
n_tokens
.
dtype
.
is_ptr
():
n_tokens
=
tl
.
load
(
n_tokens
)
n_gates
=
n_tokens
*
N_EXPTS_ACT
tl
.
static_assert
(
N_EXPTS_ACT
*
BLOCK_M
<=
32768
)
local_offs
=
tl
.
arange
(
0
,
N_EXPTS_ACT
*
BLOCK_M
)
offs
=
pid_m
*
BLOCK_M
*
N_EXPTS_ACT
+
local_offs
expert
=
tl
.
load
(
ExptIndx
+
offs
,
mask
=
(
offs
<
n_gates
),
other
=-
1
).
to
(
tl
.
uint32
)
# stable-sort by expert ID:
kv_pairs
=
((
expert
<<
16
)
|
local_offs
).
to
(
tl
.
uint32
)
kv_pairs
=
tl
.
sort
(
kv_pairs
,
0
)
expert
=
kv_pairs
>>
16
offs
=
pid_m
*
BLOCK_M
*
N_EXPTS_ACT
+
(
kv_pairs
&
0xFFFF
)
mask
=
expert
!=
0xFFFF
gate_scal
=
tl
.
load
(
ExptScal
+
offs
,
mask
=
mask
)
# compute run lengths in expert-sorted order:
x
=
kv_pairs
&
0xFFFF0000
|
0x00000001
expts_and_inclusive_run_lengths
=
tl
.
associative_scan
(
x
,
0
,
_keyed_add
)
exclusive_run_lengths
=
(
expts_and_inclusive_run_lengths
-
1
)
&
0xFFFF
gates
=
tl
.
load
(
PartialOffs
+
pid_m
*
stride_pm
+
expert
*
stride_pn
,
mask
=
mask
)
gates
+=
tl
.
load
(
TokensStart
+
expert
,
mask
=
mask
)
gates
+=
exclusive_run_lengths
tl
.
store
(
ScatterIndx
+
offs
,
gates
,
mask
=
mask
)
tl
.
store
(
GatherIndx
+
gates
,
offs
,
mask
=
mask
)
tl
.
store
(
GateScal
+
gates
,
gate_scal
,
mask
=
mask
)
@
triton
.
jit
def
_combined_routing_compute
(
GatherIndx
,
ScatterIndx
,
GateScal
,
ExptScal
,
ExptIndx
,
PartialOffs
,
stride_pm
,
stride_pn
,
TokensStart
,
n_tokens
,
BLOCK_M
:
tl
.
constexpr
,
N_EXPTS_ACT
:
tl
.
constexpr
,
Hist
,
MDTileStarts
,
tile_starts_stridem
,
MDTileInfo
,
tile_info_stridem
,
first_tile_dim_log2
,
SIZES
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
blocks2a
,
):
pid
=
tl
.
program_id
(
0
)
if
pid
<
blocks2a
:
_expt_data_compute
(
Hist
,
MDTileStarts
,
tile_starts_stridem
,
MDTileInfo
,
tile_info_stridem
,
first_tile_dim_log2
,
SIZES
,
BLOCK
,
)
else
:
pid
-=
blocks2a
_routing_compute_indx
(
pid
,
GatherIndx
,
ScatterIndx
,
GateScal
,
ExptScal
,
ExptIndx
,
PartialOffs
,
stride_pm
,
stride_pn
,
TokensStart
,
n_tokens
,
BLOCK_M
,
N_EXPTS_ACT
,
)
@
triton
.
jit
def
_routing_clear_bitmatrix
(
Bitmatrix
,
stride_bm
,
stride_bn
,
shape_bn
,
cutoff
,
BLOCK_N
:
tl
.
constexpr
):
pid_m
=
tl
.
program_id
(
0
)
cutoff_word
=
cutoff
//
32
cutoff_bit
=
cutoff
%
32
cutoff_mask
=
(
1
<<
(
cutoff_bit
))
-
1
for
start_n
in
range
(
0
,
shape_bn
,
BLOCK_N
):
offs_n
=
start_n
+
tl
.
arange
(
0
,
BLOCK_N
)
values
=
tl
.
load
(
Bitmatrix
+
pid_m
*
stride_bm
+
offs_n
*
stride_bn
,
mask
=
offs_n
<
shape_bn
)
values
=
tl
.
where
(
offs_n
==
cutoff_word
,
values
&
cutoff_mask
,
values
)
values
=
tl
.
where
(
offs_n
>
cutoff_word
,
0
,
values
)
tl
.
store
(
Bitmatrix
+
pid_m
*
stride_bm
+
offs_n
*
stride_bn
,
values
,
mask
=
offs_n
<
shape_bn
,
)
@
triton
.
jit
def
_combined_routing_memset
(
Indx
,
size
,
sentinel
,
BLOCK
:
tl
.
constexpr
,
ExpertHist
,
FinalExpertOffs
,
hist_size
,
n_expts_tot
,
PartialHist
,
shape_pm
,
stride_pm
,
stride_pn
,
MDStarts
,
tile_starts_stridem
,
blocks1a
,
MDTileInfo
,
first_tile_dim_log2
,
SIZES
:
tl
.
constexpr
,
BLOCK_A
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
):
"""
This kernel essentially combines 6 different pieces of functionality,
statically branching on the value of tl.program_id(0) to decide which
codepath to take.
pid == 0: create the token cumsum
1 <= pid <= SIZES: create a tile cumsum
SIZES < pid < blocks1a: initialise MDTileInfo to 0xffffffff
blocks1a <= pid < blocks1a + n_expts_tot: compute_indx_offs
pid == blocks1a + n_expts_tot: compute_expt_offs
pid > blocks1a + n_expts_tot: initialise Indx to sentinel
As each of these is a relatively trivial workload, launching them from
this single trampoline is beneficial as they can execute on different
streaming multiprocesses in parallel.
"""
pid
=
tl
.
program_id
(
0
)
if
pid
<
blocks1a
:
_expt_data_memset
(
ExpertHist
,
n_expts_tot
,
MDStarts
,
tile_starts_stridem
,
MDTileInfo
,
first_tile_dim_log2
,
SIZES
,
BLOCK_A
,
)
elif
pid
==
n_expts_tot
+
blocks1a
:
_routing_compute_expt_offs
(
ExpertHist
,
FinalExpertOffs
,
hist_size
,
BLOCK_N
)
elif
pid
<
n_expts_tot
+
blocks1a
:
_routing_compute_indx_offs
(
PartialHist
,
shape_pm
,
stride_pm
,
stride_pn
,
BLOCK_M
,
pid
-
blocks1a
)
else
:
offs
=
(
pid
-
n_expts_tot
-
blocks1a
-
1
)
*
BLOCK
+
tl
.
arange
(
0
,
BLOCK
)
mask
=
offs
<
size
tl
.
store
(
Indx
+
offs
,
sentinel
,
mask
=
mask
)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/specialize.py
0 → 100644
View file @
d29c39ca
import
inspect
import
re
import
textwrap
import
types
import
triton
def
cacheable
(
f
):
"""
A decorator that allow you to write something of the form:
@cacheable
def my_kernel(): return (expression dynamically defining a kernel)
such that it interacts gracefully with triton cache and preload.
"""
g
=
f
()
g
.
fn
.
__name__
=
f
.
__name__
g
.
fn
.
__module__
=
f
.
__module__
g
.
fn
.
__qualname__
=
f
.
__qualname__
g
.
__name__
=
f
.
__name__
g
.
__module__
=
f
.
__module__
g
.
__qualname__
=
f
.
__qualname__
g
.
_fn_name
=
f
"
{
f
.
__module__
}
.
{
f
.
__qualname__
}
"
return
g
def
define_kernel
(
src
,
module
,
attrs
=
None
,
**
extra_globals
):
"""
Dynamically create a Triton function or kernel from a src string,
linking any symbols in the kernel to objects specified by extra_globals.
"""
# create templace function
def
_empty_fn
():
pass
gdict
=
dict
(
**
(
_empty_fn
.
__globals__
))
gdict
.
update
(
extra_globals
)
f
=
types
.
FunctionType
(
_empty_fn
.
__code__
,
gdict
)
f
.
__module__
=
module
.
__name__
src
=
textwrap
.
dedent
(
src
)
src
=
src
[
src
.
find
(
"def "
)
:]
stored_functions
=
[]
function_name
=
src
[
4
:].
split
(
"("
)[
0
].
strip
()
exec_globals
=
gdict
exec_globals
.
update
({
"stored_functions"
:
stored_functions
})
exec
(
src
+
"
\n\n
stored_functions.append("
+
function_name
+
")
\n
"
,
exec_globals
)
f
.
__signature__
=
inspect
.
signature
(
stored_functions
[
0
])
f
.
__name__
=
function_name
f
.
__doc__
=
stored_functions
[
0
].
__doc__
if
attrs
is
None
:
attrs
=
dict
()
f
=
triton
.
JITFunction
(
f
,
**
attrs
)
f
.
_unsafe_update_src
(
src
)
return
f
def
specialize
(
fn
,
module
,
constants
,
tuples
,
name
=
None
,
do_not_specialize
=
tuple
()):
assert
isinstance
(
fn
,
triton
.
runtime
.
jit
.
JITFunction
)
if
name
is
None
:
name
=
f
"
{
fn
.
__name__
}
"
# Get original source code
src
=
inspect
.
getsource
(
fn
.
fn
)
src
=
textwrap
.
dedent
(
src
)
lines
=
src
.
split
(
"
\n
"
)
# Skip decorator and def line
def_idx
=
next
(
i
for
i
,
line
in
enumerate
(
lines
)
if
line
.
strip
().
startswith
(
"def"
))
# separate header vs body LOC
header_end
=
def_idx
while
not
lines
[
header_end
].
rstrip
().
endswith
(
":"
):
header_end
+=
1
body_lines
=
lines
[
header_end
+
1
:]
header_lines
=
lines
[
def_idx
:
header_end
+
1
]
# clean-up header
header_clean
=
[
l
.
split
(
"#"
,
1
)[
0
].
strip
()
# keep code, discard comment
for
l
in
header_lines
if
l
.
split
(
"#"
,
1
)[
0
].
strip
()
# skip blank‑after‑comment lines
]
# decompose arguments
header_src
=
" "
.
join
(
header_clean
)
# turn it into a single line
m
=
re
.
search
(
r
"\((.*)\)\s*:"
,
header_src
)
if
not
m
:
raise
ValueError
(
"Could not parse function header"
)
args_str
=
m
.
group
(
1
)
args
=
[
arg
.
strip
()
for
arg
in
args_str
.
split
(
","
)
if
arg
.
strip
()]
non_specialized_args
=
[]
for
arg
in
args
:
arg_key
=
arg
.
split
(
":"
)[
0
].
split
(
"="
)[
0
].
strip
()
new_args
=
tuples
.
get
(
arg_key
,
[
arg
])
if
arg_key
not
in
constants
:
non_specialized_args
+=
new_args
# add global symbols
spec_fns
=
{
v
.
__name__
:
v
for
k
,
v
in
constants
.
items
()
if
isinstance
(
v
,
triton
.
runtime
.
jit
.
JITFunction
)
}
globals
=
spec_fns
|
fn
.
get_capture_scope
()
# build new source code and define kernel dynamically
new_signature
=
f
"def
{
name
}
(
{
', '
.
join
(
non_specialized_args
)
}
):"
constexpr_lines
=
[
f
"
{
key
}
: tl.constexpr =
{
value
.
__name__
if
callable
(
value
)
else
value
}
"
for
key
,
value
in
constants
.
items
()
]
tuple_lines
=
[
f
"
{
key
}
=
{
'('
+
','
.
join
(
value
)
+
(
','
if
len
(
value
)
>=
1
else
''
)
+
')'
}
"
for
key
,
value
in
tuples
.
items
()
]
new_src
=
"
\n
"
.
join
(
[
"@triton.jit"
,
new_signature
]
+
constexpr_lines
+
tuple_lines
+
body_lines
)
# find function parameters
sig
=
inspect
.
signature
(
triton
.
runtime
.
jit
.
JITFunction
.
__init__
)
params
=
list
(
sig
.
parameters
.
values
())[
2
:]
attrs
=
{
param
.
name
:
getattr
(
fn
,
param
.
name
,
param
.
default
)
for
param
in
params
}
# make a new repr which appends the repr of the specialized functions.
base_repr
=
attrs
[
"repr"
]
def
new_repr
(
specialization
):
ret
=
base_repr
(
specialization
)
for
spec_fn
in
spec_fns
.
values
():
spec_repr
=
spec_fn
.
repr
(
None
)
if
spec_repr
:
spec_repr
=
spec_repr
.
strip
(
"_"
)
if
spec_repr
:
ret
+=
f
"_
{
spec_repr
}
"
return
ret
attrs
[
"repr"
]
=
new_repr
if
do_not_specialize
:
attrs
[
"do_not_specialize"
]
=
do_not_specialize
ret
=
define_kernel
(
new_src
,
module
,
attrs
,
**
globals
)
return
ret
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu.py
0 → 100644
View file @
d29c39ca
from
dataclasses
import
dataclass
from
compactor_vllm.triton_kernels.numerics
import
InFlexData
,
OutFlexData
import
torch
import
triton
from
.swiglu_details._swiglu
import
_swiglu
,
_swiglu_fn
from
compactor_vllm.triton_kernels
import
target_info
@
dataclass
(
frozen
=
True
)
class
FlexCtx
:
out_data
:
OutFlexData
=
OutFlexData
()
inp_data
:
InFlexData
=
InFlexData
()
saturate_inf
:
bool
=
False
@
dataclass
(
frozen
=
True
)
class
PrecisionConfig
:
limit
:
float
flex_ctx
:
FlexCtx
=
FlexCtx
()
swiglu_fn
=
_swiglu_fn
class
SwiGLU
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
a
,
alpha
,
precision_config
,
routing_data
):
N
=
a
.
shape
[
-
1
]
M
=
a
.
numel
()
//
N
assert
a
.
stride
()[
-
1
]
==
1
assert
a
.
shape
[
-
1
]
%
2
==
0
out
=
torch
.
empty
(
size
=
(
M
,
N
//
2
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
flex_ctx
=
precision_config
.
flex_ctx
# optimization hyperparameters
BLOCK_M
,
BLOCK_N
=
32
//
a
.
itemsize
,
128
num_warps
=
4
kwargs
=
{
"maxnreg"
:
64
}
if
not
target_info
.
is_hip
()
else
{}
# launch semi-persistent kernel
N_BLOCKS
=
triton
.
cdiv
(
N
//
2
,
BLOCK_N
)
num_sms
=
target_info
.
num_sms
()
if
routing_data
is
not
None
:
waves_per_sm
=
32
if
target_info
.
is_hip
()
else
128
num_pid
=
num_sms
*
(
waves_per_sm
//
num_warps
)
M_BLOCKS
=
max
(
1
,
triton
.
cdiv
(
num_pid
,
N_BLOCKS
))
grid
=
(
min
(
M_BLOCKS
*
N_BLOCKS
,
4
*
num_sms
),)
else
:
M_BLOCKS
=
triton
.
cdiv
(
M
,
BLOCK_M
)
if
M_BLOCKS
*
N_BLOCKS
>=
8
*
num_sms
:
grid
=
(
8
*
num_sms
,)
else
:
grid
=
(
min
(
M_BLOCKS
*
N_BLOCKS
,
4
*
num_sms
),)
n_tokens
=
None
if
routing_data
is
not
None
:
n_tokens
=
routing_data
.
expt_data
.
token_offs_raw
[
routing_data
.
n_expts_tot
]
_swiglu
[
grid
](
flex_ctx
.
out_data
.
reinterpret
(
out
),
flex_ctx
.
out_data
.
expected_scale
,
flex_ctx
.
out_data
.
actual_scale
,
flex_ctx
.
out_data
.
checksum_scale
,
flex_ctx
.
inp_data
.
reinterpret
(
a
),
flex_ctx
.
inp_data
.
scale
,
alpha
,
M
,
N
//
2
,
a
.
shape
[
-
1
],
1
,
out
.
shape
[
-
1
],
1
,
precision_config
.
limit
,
n_tokens
,
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
EVEN_N
=
(
N
//
2
)
%
BLOCK_N
==
0
,
M_BLOCKS
=
M_BLOCKS
,
N_BLOCKS
=
N_BLOCKS
,
flexpoint_saturate_inf
=
flex_ctx
.
saturate_inf
,
num_warps
=
num_warps
,
**
kwargs
,
)
out
=
out
.
view
(
a
.
shape
[:
-
1
]
+
out
.
shape
[
-
1
:])
return
out
def
swiglu
(
a
,
alpha
,
precision_config
,
routing_data
=
None
):
return
SwiGLU
.
apply
(
a
,
alpha
,
precision_config
,
routing_data
)
def
swiglu_torch
(
a
,
alpha
,
precision_config
):
limit
=
precision_config
.
limit
a_gelu
=
a
[...,
::
2
]
if
limit
is
not
None
:
a_gelu
=
a_gelu
.
clamp
(
max
=
limit
)
a_linear
=
a
[...,
1
::
2
]
if
limit
is
not
None
:
a_linear
=
a_linear
.
clamp
(
min
=-
limit
,
max
=
limit
)
out_gelu
=
a_gelu
*
torch
.
sigmoid
(
alpha
*
a_gelu
)
out
=
out_gelu
*
(
a_linear
+
1
)
return
out
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/__init__.py
0 → 100644
View file @
d29c39ca
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/_swiglu.py
0 → 100644
View file @
d29c39ca
from
compactor_vllm.triton_kernels.numerics_details.flexpoint
import
(
load_scale
,
float_to_flex
,
update_scale
,
)
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
clip
(
x
,
limit
,
clip_lower
:
tl
.
constexpr
):
res
=
tl
.
minimum
(
x
,
limit
)
if
clip_lower
:
res
=
tl
.
maximum
(
-
limit
,
res
)
return
res
@
triton
.
jit
def
thread_local_absmax
(
x
,
BLOCK_SIZE
:
tl
.
constexpr
,
NUM_THREADS
:
tl
.
constexpr
):
return
tl
.
max
(
tl
.
reshape
(
tl
.
abs
(
x
),
[
NUM_THREADS
,
BLOCK_SIZE
//
NUM_THREADS
],
can_reorder
=
True
),
axis
=
1
,
)
def
swiglu_repr
(
specialization
):
signature
=
specialization
.
signature
constants
=
specialization
.
constants
convert_dtype
=
lambda
dtype
:
"mxfp4"
if
"u8"
in
dtype
else
dtype
dtypes
=
"x"
.
join
([
convert_dtype
(
f
"
{
signature
[
i
][
1
:]
}
"
)
for
i
in
[
"Out"
,
"A"
]])
blocks
=
"x"
.
join
([
f
"
{
constants
[
i
]
}
"
for
i
in
[
"BLOCK_M"
,
"BLOCK_N"
]])
return
f
"_swiglu_
{
dtypes
}
_
{
blocks
}
"
def
swiglu_launch_metadata
(
grid
,
kernel
,
args
):
M
,
N
=
args
[
"M"
],
args
[
"N"
]
ret
=
dict
()
ret
[
"name"
]
=
f
"
{
kernel
.
name
}
[M =
{
M
}
, N =
{
N
}
]"
A
,
Out
=
args
[
"A"
],
args
[
"Out"
]
ret
[
"bytes"
]
=
Out
.
numel
()
*
Out
.
element_size
()
+
A
.
numel
()
*
A
.
element_size
()
return
ret
@
triton
.
jit
def
compute_swiglu
(
gelu
,
linear
,
scale
,
alpha
,
limit
):
gelu
=
gelu
.
to
(
tl
.
float32
)
*
scale
if
limit
is
not
None
:
gelu
=
clip
(
gelu
,
limit
,
clip_lower
=
False
)
linear
=
linear
.
to
(
tl
.
float32
)
*
scale
if
limit
is
not
None
:
linear
=
clip
(
linear
,
limit
,
clip_lower
=
True
)
s
=
gelu
/
(
1
+
tl
.
exp
(
-
alpha
*
gelu
))
return
tl
.
fma
(
s
,
linear
,
s
)
# (s * (linear + 1))
@
triton
.
jit
(
repr
=
lambda
_
:
"_swiglu"
)
def
_swiglu_fn
(
input
,
alpha
,
limit
):
gelu
,
linear
=
tl
.
split
(
tl
.
reshape
(
input
,
(
input
.
shape
[
0
],
input
.
shape
[
1
]
//
2
,
2
)))
return
compute_swiglu
(
gelu
,
linear
,
1.0
,
alpha
,
limit
)
@
triton
.
jit
(
repr
=
swiglu_repr
,
launch_metadata
=
swiglu_launch_metadata
)
def
_swiglu
(
Out
,
OutExpectedScale
,
OutActualScale
,
OutChecksumScale
,
A
,
AScale
,
alpha
,
M
,
N
,
stride_am
,
stride_an
,
stride_outm
,
stride_outn
,
limit
:
tl
.
constexpr
,
NTokens
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
EVEN_N
:
tl
.
constexpr
,
M_BLOCKS
,
N_BLOCKS
,
flexpoint_saturate_inf
:
tl
.
constexpr
,
):
if
NTokens
is
not
None
:
M
=
tl
.
load
(
NTokens
)
M_BLOCKS
=
(
M
+
BLOCK_M
-
1
)
//
BLOCK_M
local_max
=
tl
.
full
([
tl
.
extra
.
cuda
.
num_threads
()],
0.0
,
tl
.
float32
)
a_scale
=
load_scale
(
AScale
)
out_expected_scale
=
load_scale
(
OutExpectedScale
)
for
pid
in
tl
.
range
(
tl
.
program_id
(
0
),
M_BLOCKS
*
N_BLOCKS
,
tl
.
num_programs
(
0
),
num_stages
=
2
):
pid_m
=
pid
//
N_BLOCKS
pid_n
=
pid
%
N_BLOCKS
off_m
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
off_n
=
pid_n
*
BLOCK_N
+
tl
.
arange
(
0
,
BLOCK_N
)
mask_m
=
off_m
<
M
mask_n
=
off_n
<
N
packed_off_n
=
pid_n
*
BLOCK_N
+
tl
.
arange
(
0
,
2
*
BLOCK_N
)
//
2
packed_mask_n
=
packed_off_n
<
N
packed_mask_n
=
tl
.
max_constancy
(
packed_mask_n
,
[
16
])
# load a
packed_off_n
=
pid_n
*
2
*
BLOCK_N
+
tl
.
arange
(
0
,
2
*
BLOCK_N
)
packed_offs
=
off_m
[:,
None
]
*
stride_am
+
packed_off_n
[
None
,
:]
*
stride_an
if
EVEN_N
:
a_packed
=
tl
.
load
(
A
+
packed_offs
,
mask
=
mask_m
[:,
None
],
other
=
0.0
)
else
:
if
pid_n
*
BLOCK_N
+
BLOCK_N
<=
N
:
a_packed
=
tl
.
load
(
A
+
packed_offs
,
mask
=
mask_m
[:,
None
],
other
=
0.0
)
else
:
packed_mask
=
mask_m
[:,
None
]
&
packed_mask_n
[
None
,
:]
a_packed
=
tl
.
load
(
A
+
packed_offs
,
mask
=
packed_mask
,
other
=
0.0
)
a_gelu
,
a_linear
=
tl
.
split
(
tl
.
reshape
(
a_packed
,
(
BLOCK_M
,
BLOCK_N
,
2
)))
out
=
compute_swiglu
(
a_gelu
,
a_linear
,
a_scale
,
alpha
,
limit
)
# update flexpoint stats and divide by scale
# we don't need masking because of the `other` when loading `A`
if
OutActualScale
is
not
None
:
absmax
=
thread_local_absmax
(
out
,
out
.
numel
,
tl
.
extra
.
cuda
.
num_threads
())
local_max
=
tl
.
maximum
(
local_max
,
absmax
)
out
=
float_to_flex
(
out
,
out_expected_scale
,
None
,
# ActualScale: local absmax is tracked and updated after the loop
OutChecksumScale
,
None
,
Out
,
flexpoint_saturate_inf
,
)
mask
=
mask_m
[:,
None
]
if
EVEN_N
else
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
tl
.
store
(
Out
+
off_m
[:,
None
]
*
stride_outm
+
off_n
[
None
,
:]
*
stride_outn
,
out
,
mask
)
update_scale
(
local_max
,
OutActualScale
,
Out
)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/target_info.py
0 → 100644
View file @
d29c39ca
import
torch
import
triton
import
triton.language
as
tl
from
triton.language.target_info
import
(
cuda_capability_geq
,
is_cuda
,
is_hip
,
is_hip_cdna3
,
is_hip_cdna4
,
)
__all__
=
[
"cuda_capability_geq"
,
"get_cdna_version"
,
"has_tma_gather"
,
"has_native_mxfp"
,
"is_cuda"
,
"is_hip"
,
"is_hip_cdna3"
,
"is_hip_cdna4"
,
"num_sms"
,
]
@
triton
.
constexpr_function
def
get_cdna_version
():
"""
Gets the AMD architecture version, i.e. CDNA3 or CDNA4, currently
only supports 3 (gfx942) or 4 (gfx950). Returns -1 if it is not AMD
hardware or unsupported architecture
"""
target
=
tl
.
target_info
.
current_target
()
if
target
.
backend
!=
"hip"
:
return
-
1
if
target
.
arch
==
"gfx942"
:
return
3
if
target
.
arch
==
"gfx950"
:
return
4
return
-
1
@
triton
.
constexpr_function
def
has_tma_gather
():
return
cuda_capability_geq
(
10
,
0
)
@
triton
.
constexpr_function
def
has_native_mxfp
():
return
cuda_capability_geq
(
10
,
0
)
def
num_sms
():
return
torch
.
cuda
.
get_device_properties
(
0
).
multi_processor_count
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor.py
0 → 100644
View file @
d29c39ca
from
dataclasses
import
dataclass
,
fields
from
typing
import
Type
import
torch
from
triton.tools.tensor_descriptor
import
TensorDescriptor
from
triton.tools.ragged_tma
import
create_ragged_descriptor
from
.reduction_details.reduce_bitmatrix
import
clear_sums
,
sum_bitmatrix_rows
from
.target_info
import
cuda_capability_geq
from
.tensor_details.layout
import
Layout
,
StridedLayout
@
dataclass
class
Storage
:
data
:
torch
.
Tensor
layout
:
Layout
=
None
def
__post_init__
(
self
):
assert
isinstance
(
self
.
data
,
torch
.
Tensor
)
if
self
.
layout
is
None
:
self
.
layout
=
StridedLayout
(
self
.
data
.
shape
)
@
property
def
device
(
self
):
return
self
.
data
.
device
def
is_tma_compliant
(
self
):
# TMAs didn't exist until Hopper
if
not
cuda_capability_geq
(
9
,
0
):
return
False
# TMAs only exist for 2D, 3D, 5D inputs
if
len
(
self
.
data
.
shape
)
not
in
[
2
,
3
,
5
]:
return
False
# TMAs need at most one stride equal to 1
# and all other strides divisble by 16
strides
=
list
(
self
.
data
.
stride
())
try
:
major_dim
=
strides
.
index
(
1
)
except
ValueError
:
major_dim
=
-
1
ndim
=
self
.
data
.
ndim
bitwidth
=
4
if
self
.
data
.
dtype
==
torch
.
uint8
else
self
.
data
.
element_size
()
*
8
compliant
=
[
strides
[
i
]
*
bitwidth
%
128
==
0
for
i
in
range
(
ndim
)
if
i
!=
major_dim
]
return
all
(
compliant
)
def
make_dense_tma
(
self
,
block_shape
,
transpose
=
False
):
strides
=
list
(
self
.
data
.
stride
())
shape
=
list
(
self
.
data
.
shape
)
transpose
=
self
.
data
.
stride
()[
-
1
]
!=
1
if
transpose
:
block_shape
=
block_shape
[:
-
2
]
+
[
block_shape
[
-
1
],
block_shape
[
-
2
]]
shape
=
shape
[:
-
2
]
+
[
shape
[
-
1
],
shape
[
-
2
]]
strides
=
strides
[:
-
2
]
+
[
strides
[
-
1
],
strides
[
-
2
]]
if
self
.
data
.
dtype
==
torch
.
uint8
and
self
.
layout
.
name
==
"BLACKWELL_VALUE"
:
indx
=
strides
.
index
(
1
)
block_shape
[
indx
]
=
block_shape
[
indx
]
//
2
if
shape
[
-
1
]
%
128
!=
0
:
raise
ValueError
(
"inner shape need to be multiple of 128 for "
"mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs."
)
block_shape
=
self
.
layout
.
swizzle_block_shape
(
block_shape
)
return
TensorDescriptor
(
self
.
data
,
shape
,
strides
,
block_shape
)
def
make_tma
(
self
,
block_shape
,
mode
,
transpose
=
False
):
if
mode
in
[
"dense"
,
"gather"
,
"scatter"
]:
return
self
.
make_dense_tma
(
block_shape
,
transpose
)
assert
mode
==
"ragged"
ragged_dim
=
len
(
self
.
data
.
shape
)
-
2
return
create_ragged_descriptor
(
self
.
data
,
block_shape
,
ragged_dim
=
ragged_dim
)
@
dataclass
class
IntegerType
:
bitwidth
:
int
@
dataclass
class
FloatType
:
bitwidth_exponent
:
int
bitwidth_mantissa
:
int
is_signed
:
bool
def
__post_init__
(
self
):
self
.
bitwidth
=
(
int
(
self
.
is_signed
)
+
self
.
bitwidth_exponent
+
self
.
bitwidth_mantissa
)
BIT
=
IntegerType
(
1
)
FP4
=
FloatType
(
bitwidth_exponent
=
2
,
bitwidth_mantissa
=
1
,
is_signed
=
True
)
def
bitwidth
(
type
:
IntegerType
|
FloatType
|
torch
.
dtype
):
if
isinstance
(
type
,
torch
.
dtype
):
return
type
.
itemsize
*
8
return
type
.
bitwidth
@
dataclass
class
Tensor
:
storage
:
Storage
|
torch
.
Tensor
dtype
:
IntegerType
|
FloatType
|
torch
.
dtype
=
None
shape
:
list
[
int
]
|
None
=
None
shape_max
:
list
[
int
]
|
None
=
None
def
__post_init__
(
self
):
# set storage
if
isinstance
(
self
.
storage
,
torch
.
Tensor
):
self
.
storage
=
Storage
(
self
.
storage
)
# initialize dtype
if
self
.
dtype
is
None
:
self
.
dtype
=
self
.
storage
.
data
.
dtype
if
bitwidth
(
self
.
dtype
)
<
8
and
self
.
shape
is
None
:
raise
ValueError
(
"shape must be provided for sub-byte types"
)
# initialize shape
if
self
.
shape
is
None
:
self
.
shape
=
list
(
self
.
storage
.
data
.
shape
)
# validate shape: all elements must be `int` or numel-1 `torch.Tensor`
is_int
=
lambda
s
:
isinstance
(
s
,
int
)
is_item
=
lambda
s
:
hasattr
(
s
,
"numel"
)
and
s
.
numel
()
==
1
assert
all
(
map
(
lambda
s
:
is_int
(
s
)
or
is_item
(
s
),
self
.
shape
))
# initialize shape_max
if
self
.
shape_max
is
None
:
self
.
shape_max
=
[
None
]
*
len
(
self
.
shape
)
for
i
,
(
s
,
smax
)
in
enumerate
(
zip
(
self
.
shape
,
self
.
shape_max
)):
if
smax
is
not
None
and
not
is_int
(
smax
):
raise
ValueError
(
f
"shape_max[
{
i
}
] must be `int` or `None`; got
{
type
(
smax
)
}
"
)
if
smax
is
None
:
self
.
shape_max
[
i
]
=
s
# validate shape_max: all elements must be `int`
assert
all
(
map
(
is_int
,
self
.
shape_max
))
# torch compatibility layer
@
property
def
ndim
(
self
):
return
len
(
self
.
shape
)
@
property
def
device
(
self
):
return
self
.
storage
.
device
def
stride
(
self
,
i
=
None
):
return
self
.
storage
.
data
.
stride
()
if
i
is
None
else
self
.
storage
.
data
.
stride
(
i
)
def
data_ptr
(
self
):
return
self
.
storage
.
data
.
data_ptr
()
def
numel
(
self
):
return
self
.
storage
.
data
.
numel
()
def
element_size
(
self
):
return
bitwidth
(
self
.
dtype
)
//
8
@
property
def
data
(
self
):
t
=
self
.
storage
return
t
.
data
if
isinstance
(
t
,
Storage
)
else
t
def
dim
(
self
):
return
self
.
ndim
def
size
(
self
,
i
=
None
):
if
i
is
None
:
return
self
.
shape
return
self
.
shape
[
i
]
@
dataclass
class
Bitmatrix
(
Tensor
):
"""
Represents a boolean matrix in a packed format where each element occupies
a single bit of memory.
_scratchpad is either None or an all-zero array of size >= shape[-1]; we pass it along
with the actual bitmatrix to avoid having to launch a separate memset
kernel when we call Bitmatrix::sum().
"""
scratchpad
:
torch
.
Tensor
=
None
def
__init__
(
self
,
storage
,
shape
,
shape_max
=
None
,
scratchpad
=
None
):
super
().
__init__
(
storage
,
dtype
=
BIT
,
shape
=
shape
,
shape_max
=
shape_max
)
self
.
scratchpad
=
scratchpad
def
sum
(
self
,
partials_block_size
):
_
,
n_cols
=
self
.
shape
dev
=
self
.
device
if
self
.
scratchpad
is
None
:
self
.
scratchpad
=
clear_sums
(
n_cols
,
dev
)
out_ret
=
self
.
scratchpad
[:
n_cols
]
self
.
scratchpad
=
None
# throw error if we try to sum again
return
sum_bitmatrix_rows
(
self
,
out_ret
,
partials_block_size
)
def
get_layout
(
tensor
:
torch
.
Tensor
|
Tensor
|
None
):
if
tensor
is
None
:
return
None
if
isinstance
(
tensor
,
Tensor
):
return
tensor
.
storage
.
layout
return
StridedLayout
def
wrap_torch_tensor
(
torch_tensor
,
dtype
=
None
):
if
dtype
is
None
:
dtype
=
torch_tensor
.
dtype
shape
=
list
(
torch_tensor
.
shape
)
shape
[
torch_tensor
.
stride
().
index
(
1
)]
*=
bitwidth
(
torch_tensor
.
dtype
)
//
bitwidth
(
dtype
)
return
Tensor
(
Storage
(
torch_tensor
),
dtype
=
dtype
,
shape
=
shape
)
def
convert_layout
(
tensor
:
Tensor
,
layout_cls
:
Type
[
Layout
],
**
layout_kwargs
):
assert
isinstance
(
tensor
,
Tensor
)
old_storage
=
tensor
.
storage
old_data
=
old_storage
.
layout
.
unswizzle_data
(
old_storage
.
data
)
new_layout
=
layout_cls
(
old_data
.
shape
,
**
layout_kwargs
)
new_data
=
new_layout
.
swizzle_data
(
old_data
)
attrs
=
{
k
.
name
:
getattr
(
tensor
,
k
.
name
)
for
k
in
fields
(
tensor
)
if
k
.
name
!=
"storage"
}
return
Tensor
(
Storage
(
new_data
,
new_layout
),
**
attrs
)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/__init__.py
0 → 100644
View file @
d29c39ca
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout.py
0 → 100644
View file @
d29c39ca
from
.layout_details.base
import
Layout
from
.layout_details.blackwell_scale
import
BlackwellMXScaleLayout
from
.layout_details.blackwell_value
import
BlackwellMXValueLayout
from
.layout_details.hopper_scale
import
HopperMXScaleLayout
from
.layout_details.hopper_value
import
HopperMXValueLayout
from
.layout_details.cdna4_scale
import
CDNA4MXScaleLayout
from
.layout_details.strided
import
StridedLayout
from
..target_info
import
cuda_capability_geq
,
is_hip_cdna4
__all__
=
[
"Layout"
,
"BlackwellMXValueLayout"
,
"BlackwellMXScaleLayout"
,
"HopperMXScaleLayout"
,
"HopperMXValueLayout"
,
"CDNA4MXScaleLayout"
,
"StridedLayout"
,
]
def
make_default_matmul_mxfp4_w_layout
(
mx_axis
:
int
):
if
cuda_capability_geq
(
10
):
# return StridedLayout, dict()
return
BlackwellMXValueLayout
,
dict
()
elif
cuda_capability_geq
(
9
):
return
HopperMXValueLayout
,
{
"mx_axis"
:
mx_axis
}
else
:
return
StridedLayout
,
dict
()
def
make_default_matmul_mxfp4_w_scale_layout
(
mx_axis
:
int
,
num_warps
:
int
=
8
):
if
is_hip_cdna4
():
return
CDNA4MXScaleLayout
,
dict
()
else
:
if
cuda_capability_geq
(
10
):
return
BlackwellMXScaleLayout
,
dict
()
elif
cuda_capability_geq
(
9
):
return
HopperMXScaleLayout
,
{
"mx_axis"
:
mx_axis
,
"num_warps"
:
num_warps
}
return
StridedLayout
,
dict
()
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/__init__.py
0 → 100644
View file @
d29c39ca
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/base.py
0 → 100644
View file @
d29c39ca
from
abc
import
ABC
,
abstractmethod
class
Layout
(
ABC
):
def
__init__
(
self
,
shape
)
->
None
:
self
.
initial_shape
=
shape
@
abstractmethod
def
swizzle_data
(
self
,
data
):
pass
@
abstractmethod
def
unswizzle_data
(
self
,
data
):
pass
@
abstractmethod
def
swizzle_block_shape
(
self
,
block_shape
):
pass
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_scale.py
0 → 100644
View file @
d29c39ca
import
math
import
triton
import
triton.language
as
tl
import
torch
from
.base
import
Layout
SWIZZLE_ALIGN_INNER
=
8
SWIZZLE_SIZE_INNER
=
4
SWIZZLE_SIZE_OUTER
=
128
class
BlackwellMXScaleLayout
(
Layout
):
name
:
str
=
"BLACKWELL_SCALE"
def
__init__
(
self
,
shape
)
->
None
:
super
().
__init__
(
shape
)
(
*
self
.
leading_shape
,
self
.
K
,
self
.
N
,
)
=
shape
self
.
B
=
math
.
prod
(
self
.
leading_shape
)
self
.
ALIGN_K
=
8
self
.
ALIGN_N
=
128
self
.
SWIZZLE_K
=
4
self
.
K_pad
=
(
self
.
K
+
self
.
ALIGN_K
-
1
)
//
self
.
ALIGN_K
*
self
.
ALIGN_K
self
.
N_pad
=
(
self
.
N
+
self
.
ALIGN_N
-
1
)
//
self
.
ALIGN_N
*
self
.
ALIGN_N
def
swizzle_data
(
self
,
data
):
data
=
torch
.
nn
.
functional
.
pad
(
data
,
(
0
,
self
.
N_pad
-
self
.
N
,
0
,
self
.
K_pad
-
self
.
K
)
)
data
=
data
.
transpose
(
-
1
,
-
2
).
contiguous
()
data
=
data
.
reshape
(
self
.
B
,
self
.
N_pad
//
self
.
ALIGN_N
,
self
.
ALIGN_N
//
32
,
32
,
self
.
K_pad
//
self
.
SWIZZLE_K
,
self
.
SWIZZLE_K
,
)
data
=
data
.
transpose
(
2
,
4
).
contiguous
()
data
=
data
.
view
(
1
,
self
.
B
*
self
.
N_pad
//
128
,
self
.
K_pad
//
4
,
2
,
256
)
return
data
def
unswizzle_data
(
self
,
data
):
data
=
data
.
reshape
(
self
.
B
,
self
.
N_pad
//
self
.
ALIGN_N
,
self
.
K_pad
//
self
.
SWIZZLE_K
,
32
,
self
.
ALIGN_N
//
32
,
self
.
SWIZZLE_K
,
)
data
=
data
.
transpose
(
2
,
4
)
data
=
data
.
reshape
(
*
self
.
leading_shape
,
self
.
N_pad
,
self
.
K_pad
)
data
=
data
.
transpose
(
-
1
,
-
2
)
return
data
[...,
:
self
.
K
,
:
self
.
N
]
def
swizzle_block_shape
(
self
,
block_shape
):
MX_PACK_DIVISOR
=
32
MX_SCALE_BLOCK_K
=
block_shape
[
1
]
//
MX_PACK_DIVISOR
return
[
1
,
block_shape
[
0
]
//
128
,
MX_SCALE_BLOCK_K
//
4
,
2
,
256
]
@
triton
.
jit
def
unswizzle_mx_scale_bw
(
x
,
SIZE_OUTER
:
tl
.
constexpr
=
SWIZZLE_SIZE_OUTER
,
SIZE_INNER
:
tl
.
constexpr
=
SWIZZLE_SIZE_INNER
,
ALIGN_INNER
:
tl
.
constexpr
=
SWIZZLE_ALIGN_INNER
,
):
shape_0
:
tl
.
constexpr
=
x
.
shape
[
0
]
shape_1
:
tl
.
constexpr
=
x
.
shape
[
1
]
tl
.
static_assert
(
shape_1
%
SIZE_OUTER
==
0
)
tl
.
static_assert
(
shape_1
//
SIZE_OUTER
<=
ALIGN_INNER
)
x
=
x
.
reshape
(
shape_0
,
(
shape_1
//
SIZE_OUTER
)
//
SIZE_INNER
,
32
,
SIZE_OUTER
//
32
,
SIZE_INNER
)
x
=
x
.
trans
(
0
,
3
,
2
,
1
,
4
).
reshape
(
shape_0
*
SIZE_OUTER
,
shape_1
//
SIZE_OUTER
)
return
x
Prev
1
2
3
4
5
6
7
8
9
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment