Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f81ce56b
Commit
f81ce56b
authored
Apr 23, 2026
by
chenzk
Browse files
vllm kvprune:v1.0.1
parent
2b7160c6
Changes
237
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
1879 deletions
+0
-1879
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/proton_opts.py
...tor-vllm/src/compactor_vllm/triton_kernels/proton_opts.py
+0
-19
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/__init__.py
...mpactor_vllm/triton_kernels/reduction_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/reduce_bitmatrix.py
...vllm/triton_kernels/reduction_details/reduce_bitmatrix.py
+0
-133
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing.py
...mpactor-vllm/src/compactor_vllm/triton_kernels/routing.py
+0
-521
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/__init__.py
...compactor_vllm/triton_kernels/routing_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_expt_data.py
...mpactor_vllm/triton_kernels/routing_details/_expt_data.py
+0
-75
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_routing_compute.py
...r_vllm/triton_kernels/routing_details/_routing_compute.py
+0
-241
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/specialize.py
...ctor-vllm/src/compactor_vllm/triton_kernels/specialize.py
+0
-143
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu.py
...ompactor-vllm/src/compactor_vllm/triton_kernels/swiglu.py
+0
-99
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/__init__.py
.../compactor_vllm/triton_kernels/swiglu_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/_swiglu.py
...c/compactor_vllm/triton_kernels/swiglu_details/_swiglu.py
+0
-141
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/target_info.py
...tor-vllm/src/compactor_vllm/triton_kernels/target_info.py
+0
-54
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor.py
...ompactor-vllm/src/compactor_vllm/triton_kernels/tensor.py
+0
-227
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/__init__.py
.../compactor_vllm/triton_kernels/tensor_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout.py
...rc/compactor_vllm/triton_kernels/tensor_details/layout.py
+0
-40
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/__init__.py
.../triton_kernels/tensor_details/layout_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/base.py
...vllm/triton_kernels/tensor_details/layout_details/base.py
+0
-18
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_scale.py
..._kernels/tensor_details/layout_details/blackwell_scale.py
+0
-81
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_value.py
..._kernels/tensor_details/layout_details/blackwell_value.py
+0
-37
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/cdna4_scale.py
...iton_kernels/tensor_details/layout_details/cdna4_scale.py
+0
-50
No files found.
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/proton_opts.py
deleted
100644 → 0
View file @
2b7160c6
# proton options
import
os
# Cached result of launch_metadata_allow_sync(); None means "not resolved yet".
_launch_metadata_allow_sync = None


def launch_metadata_allow_sync():
    """Return the cached allow-sync flag, resolving it on first use.

    While the module-level cache is still None, the flag is derived from the
    PROTON_LAUNCH_METADATA_NOSYNC environment variable: it is False only when
    that variable is exactly "1", and True otherwise (including when unset).
    """
    global _launch_metadata_allow_sync
    if _launch_metadata_allow_sync is not None:
        return _launch_metadata_allow_sync
    nosync_setting = os.getenv("PROTON_LAUNCH_METADATA_NOSYNC")
    _launch_metadata_allow_sync = nosync_setting != "1"
    return _launch_metadata_allow_sync
def set_launch_metadata_allow_sync(allow_sync: bool):
    """Overwrite the module-level allow-sync cache with ``allow_sync``.

    The value is stored as-is in ``_launch_metadata_allow_sync``. Because
    launch_metadata_allow_sync() only consults the environment while the cache
    is None, storing True or False here pins its result; storing None makes
    the next call re-read the environment.
    """
    global _launch_metadata_allow_sync
    _launch_metadata_allow_sync = allow_sync
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/reduction_details/reduce_bitmatrix.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
@triton.jit
def vpopc(x):
    """
    Vertical popcount
    Input  x : uint32[..., N]
    Output y : uint32[..., 32]
    semantics : y[..., i] = sum_j((x[..., j] >> i) & 1)
    credits: @apgoucher

    The reduction runs in three stages that widen the per-bit counters from
    4-bit to 8-bit to 32-bit fields, so that several counters can be summed
    at once inside a single 32-bit word without one field carrying into the
    next.
    """

    # Compile-time check: the bit tricks below assume 32-bit unsigned words.
    tl.static_assert(x.dtype == tl.uint32, "x should consist of 32-bit unsigned integers")

    BLOCK_N: tl.constexpr = x.shape[-1]  # summation axis
    BATCHES: tl.constexpr = x.numel // BLOCK_N  # number of batches
    # sa1 = how many inputs are folded together in stage 1 (at most 8).
    if BLOCK_N >= 8:
        sa1: tl.constexpr = 8
    else:
        sa1: tl.constexpr = BLOCK_N
    # create 8-way sums in 4-bit fields:
    # a trailing axis of size 1 is broadcast against shifts 0..3; the mask
    # 0x11111111 keeps the lowest bit of every nibble, so each nibble holds a
    # 0/1 value and summing up to 8 of them still fits in 4 bits.
    y = tl.reshape(x, [BATCHES, BLOCK_N // sa1, sa1, 1])
    y = (y >> tl.arange(0, 4)[None, None, None, :]) & 0x11111111
    y = tl.sum(y, 2)  # [BATCHES, BLOCK_N // sa1, 4]
    # sa2 = how many stage-1 results are folded together in stage 2 (at most 16).
    if BLOCK_N >= 128:
        sa2: tl.constexpr = 16
    else:
        sa2: tl.constexpr = BLOCK_N // sa1
    # create 128-way sums in 8-bit fields:
    # shifting by 0 or 4 and masking with 0x0F0F0F0F spreads the nibble
    # counters into byte-wide fields before summing over the sa2 axis.
    y = tl.reshape(y, [BATCHES, BLOCK_N // (sa1 * sa2), sa2, 1, 4])
    y = (y >> (4 * tl.arange(0, 2))[None, None, None, :, None]) & 0x0F0F0F0F
    y = tl.sum(y, 2)  # [BATCHES, BLOCK_N // (sa1 * sa2), 2, 4]
    # sa3 = number of stage-2 results left to fold in the final stage.
    sa3: tl.constexpr = BLOCK_N // (sa1 * sa2)
    # create N-way sums in 32-bit fields:
    # shifting by 0/8/16/24 and masking with 0xFF isolates one byte counter
    # per 32-bit element, which is then summed over the sa3 axis.
    y = tl.reshape(y, [BATCHES, 1, sa3, 8])
    y = (y >> (8 * tl.arange(0, 4))[None, :, None, None]) & 0x000000FF
    y = tl.sum(y, 2)  # [BATCHES, 4, 8]
    # Restore the leading dimensions of x, with the last axis replaced by 32.
    y = tl.reshape(y, x.shape[:-1] + [32])
    return y
@triton.jit
def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr):
    """Write zeros to the BLOCK consecutive elements of Ret owned by this program."""
    block_start = tl.program_id(0) * BLOCK
    lanes = tl.arange(0, BLOCK)
    # The store is unmasked, so Ret must hold a whole number of BLOCK-sized chunks.
    tl.store(Ret + (block_start + lanes), 0)
@triton.jit
def _sum_bitmatrix_rows(
    B,
    shape_bm,
    stride_bm: tl.constexpr,
    stride_bn: tl.constexpr,
    # input bitmatrix
    Ret,
    Partials,
    stride_pm: tl.constexpr,
    stride_pn,
    shape_pn,
    # outputs
    BLOCK_MM: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    """Count set bits per bit-column of B for one (row-block, word-column) tile.

    Program (pid_m, pid_n) reads BLOCK_MM rows of word-column pid_n of B,
    splits them into TILE_SIZE groups of BLOCK_M rows, and uses vpopc to get a
    [TILE_SIZE, 32] table of per-group counts for the 32 bit positions of that
    word. The per-group counts are stored to Partials and their total is
    atomically added into Ret.

    NOTE(review): shape_pn is accepted but not referenced in this body.
    """
    tl.static_assert(BLOCK_MM % BLOCK_M == 0)
    TILE_SIZE: tl.constexpr = BLOCK_MM // BLOCK_M
    # The row count may arrive either as a value or as a pointer to it.
    if isinstance(shape_bm, tl.tensor) and shape_bm.dtype.is_ptr():
        shape_bm = tl.load(shape_bm)
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_MM + tl.arange(0, BLOCK_MM)
    # Each word of B covers 32 output columns.
    offs_n = pid_n * 32 + tl.arange(0, 32)
    n_rows = shape_bm
    # Rows at or beyond n_rows are read as 0 so they contribute no bits.
    bits = tl.load(B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < n_rows, other=0)
    bits = tl.reshape(bits, [TILE_SIZE, BLOCK_M])
    ret = vpopc(bits)
    # [TILE_SIZE, 32]
    offs_t = pid_m * TILE_SIZE + tl.arange(0, TILE_SIZE)
    # Final reduction: many programs add into the same 32 entries of Ret, hence the atomic.
    tl.atomic_add(Ret + offs_n, tl.sum(ret, 0), sem="relaxed")
    # Partial reduction: unmasked store, so Partials must cover every (offs_t, offs_n) of the grid.
    tl.store(Partials + offs_t[:, None] * stride_pm + offs_n[None, :] * stride_pn, ret)
def clear_sums(n_cols, device, MEMSET_BLOCK=512):
    """Allocate and zero an int32 accumulator large enough for ``n_cols`` sums.

    The buffer length is ``n_cols`` rounded up to a multiple of
    ``MEMSET_BLOCK`` so that the zeroing kernel, which writes whole blocks
    without a mask, stays in bounds.

    Returns:
        A 1-D int32 tensor on ``device`` of length
        ``cdiv(n_cols, MEMSET_BLOCK) * MEMSET_BLOCK``.
    """
    n_programs = triton.cdiv(n_cols, MEMSET_BLOCK)
    padded_len = n_programs * MEMSET_BLOCK
    out_ret = torch.empty((padded_len,), dtype=torch.int32, device=device)
    grid = (n_programs,)
    _sum_bitmatrix_memset[grid](out_ret, MEMSET_BLOCK)
    return out_ret
def sum_bitmatrix_rows(x, out_ret, partials_block_size=None):
    """Launch the row-sum kernel over bitmatrix ``x``.

    Args:
        x: bitmatrix-like object; this function reads ``x.shape``,
            ``x.shape_max``, ``x.stride(...)`` and ``x.storage.data``
            (project type — semantics of those attributes are not visible here).
        out_ret: 1-D accumulator of shape ``(n_cols,)`` that the kernel adds
            the column totals into.
        partials_block_size: number of rows summarized by each row of the
            partial-sum output; despite the ``None`` default it is required.

    Returns:
        ``(out_ret, out_partials)`` where ``out_partials`` is an int32 view
        with ``cdiv(n_rows_max, partials_block_size)`` rows.
    """
    assert partials_block_size is not None
    rows_per_partial = partials_block_size
    n_rows, n_cols = x.shape
    n_rows_max = x.shape_max[0]
    assert out_ret.shape == (n_cols,)

    # Each program covers at least 128 rows when the partial block is smaller than that.
    partials_per_program = max(1, 128 // rows_per_partial)
    rows_per_program = rows_per_partial * partials_per_program

    grid_m = triton.cdiv(n_rows_max, rows_per_program)
    grid_n = triton.cdiv(n_cols, 32)

    # Allocated column-major-like (columns first) and then viewed transposed.
    partials_storage = torch.empty(
        (grid_n * 32, grid_m * partials_per_program),
        device=out_ret.device,
        dtype=torch.int32,
    )
    out_partials = torch.transpose(partials_storage, 0, 1)

    launch_grid = (grid_m, grid_n)
    _sum_bitmatrix_rows[launch_grid](
        # input
        x.storage.data,
        n_rows,
        x.stride(0),
        x.stride(1),
        # output [final reduction]
        out_ret,
        # output [partial reductions]
        out_partials,
        out_partials.stride(0),
        out_partials.stride(1),
        out_partials.shape[1],
        # constants
        BLOCK_M=rows_per_partial,
        BLOCK_MM=rows_per_program,
        num_warps=8,
    )

    # Drop the rows that exist only because the grid was rounded up.
    n_partial_rows = triton.cdiv(n_rows_max, rows_per_partial)
    return out_ret, out_partials[:n_partial_rows, :]
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
triton
from
dataclasses
import
dataclass
,
field
from
.routing_details._routing_compute
import
_combined_routing_compute
from
.routing_details._routing_compute
import
_combined_routing_memset
from
.routing_details._routing_compute
import
_routing_clear_bitmatrix
from
.routing_details._expt_data
import
_expt_data_memset
from
.routing_details._expt_data
import
_expt_data_compute
from
.target_info
import
is_hip
@dataclass
class GatherIndx:
    """
    Indices for an operation that performs:
    Y = X[src_indx, :]
    """
    # Paired index arrays such that `dst_indx[src_indx] = arange(0, N)`,
    # i.e. each array is the inverse permutation of the other.
    src_indx: torch.Tensor
    dst_indx: torch.Tensor
@dataclass
class ScatterIndx:
    """
    Indices for an operation that performs:
    Y[dst_indx, :] = X
    """
    # Paired index arrays such that `dst_indx[src_indx] = arange(0, N)`,
    # i.e. each array is the inverse permutation of the other.
    src_indx: torch.Tensor
    dst_indx: torch.Tensor
@dataclass
class ExptData:
    # hist[i]: how many tokens were routed to expert i.
    hist: torch.Tensor
    # token_offs_raw[i]: position of expert i's first token in an
    # expert-sorted array.
    token_offs_raw: torch.Tensor
    # token_offs_pad[block][i]: same offset as above, but computed as if each
    # histogram entry had been rounded up to the next multiple of `block`.
    token_offs_pad: dict[int, torch.Tensor]
    # block_pid_map[block]: one int32 per `pid` launched by the matrix
    # multiplication kernel when BLOCK_M=block:
    #   - -1 when that `pid` has no work to do;
    #   - otherwise two int16 packed into one int32: (1) the expert whose
    #     tokens this pid processes and (2) the block index within that
    #     expert (think `pid_m` in a regular matmul).
    # See `test_routing.py` for a reference implementation and more details.
    block_pid_map: dict[int, torch.Tensor]

    def __post_init__(self):
        """Assert that every tensor that was supplied has dtype int32."""
        for single in (self.hist, self.token_offs_raw):
            if single is not None:
                assert single.dtype == torch.int32
        for per_block in (self.token_offs_pad, self.block_pid_map):
            if per_block is not None:
                for tensor in per_block.values():
                    assert tensor.dtype == torch.int32
@dataclass
class RoutingData:
    gate_scal: torch.Tensor = field()
    expt_hist: torch.Tensor = field()
    n_expts_tot: int = field()
    n_expts_act: int = field()
    expt_data: ExptData = None

    # Only used to tidy up perf annotations: with expert sharding the real
    # number of local tokens per expert changes from input to input, so this
    # carries the "expected" count instead.
    expected_tokens_per_expt: int = field(default=None)

    def n_blocks(self, n_rows, block_m):
        """Return the block count used for ``n_rows`` rows tiled by ``block_m``.

        With no more rows than experts the answer is ``n_rows`` itself;
        otherwise it is ``cdiv(n_rows - n_expts_tot + 1, block_m)`` plus
        ``n_expts_tot - 1``.
        """
        n_experts = self.n_expts_tot
        if n_rows <= n_experts:
            return n_rows
        surplus_rows = max(n_rows - n_experts + 1, 0)
        return triton.cdiv(surplus_rows, block_m) + n_experts - 1
# --------------------------
# sort tokens by expert
# --------------------------
class SortTokens(torch.autograd.Function):
    """Autograd wrapper around the two fused Triton launches that sort gates by expert.

    forward() allocates the output/scratch buffers, runs
    `_combined_routing_memset` followed by `_combined_routing_compute`, and
    returns the histogram, both index arrays, the reordered gate values and
    the per-block-size metadata. backward() only propagates a gradient to
    `expt_scal`.
    """

    @staticmethod
    def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):
        # Block sizes handed to the kernels below.
        HIST_BLOCK_M = 32
        INDX_OFFS_BLOCK_M = 512
        MEMSET_BLOCK = 1024
        cdiv = triton.cdiv
        device = expt_scal.device
        dtype = expt_scal.dtype
        n_tokens_raw, _ = bitmatrix.shape
        n_tokens_pad, n_expts_act = expt_scal.shape
        n_gates_pad = n_tokens_pad * n_expts_act
        # Full histogram plus per-HIST_BLOCK_M-row partial histograms.
        # NOTE(review): `bitmatrix.sum` is a project API not shown here.
        hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M)
        hist = hist[:n_expts_tot]
        assert hist.dtype == torch.int32
        # scratchpad
        expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
        combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)
        # output
        # topk_indx and gate_indx are the two halves of one buffer so that a
        # single memset pass can fill both with the sentinel.
        topk_indx = combined_indx[:n_gates_pad]
        gate_indx = combined_indx[n_gates_pad:]
        gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
        (
            token_offs_combined,
            token_offs_raw,
            token_offs_pad,
            block_pid_map,
            blocks1a,
            blocks2a,
            MEMSET_BLOCK_A,
            HIST2_BLOCK_M,
            block_m_log2_start,
            block_m_num,
        ) = _compute_expt_data_internal(hist, n_expts_tot, n_gates_pad)
        # Grid sizes of the routing-specific parts of each fused launch; the
        # "a" parts come from _compute_expt_data_internal.
        blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
        blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
        # First launch: fills combined_indx with the sentinel -1, builds the
        # offset tables and initialises block_pid_map.
        _combined_routing_memset[(blocks1a + blocks1b,)](
            combined_indx,
            n_gates_pad * 2,
            -1,
            MEMSET_BLOCK,
            hist,
            #
            expt_offs,
            hist.shape[0],
            n_expts_tot,
            partial_hist,
            # inputs
            partial_hist.shape[0],
            partial_hist.stride(0),
            partial_hist.stride(1),
            # outputs
            token_offs_combined,
            token_offs_combined.stride(0),
            #
            blocks1a,
            block_pid_map,
            #
            block_m_log2_start,
            SIZES=block_m_num,
            BLOCK_A=MEMSET_BLOCK_A,
            # optimization parameters
            BLOCK_N=512,
            BLOCK_M=INDX_OFFS_BLOCK_M,
            # tunable parameters
        )
        # partial_hist was rewritten in place by the launch above; it is
        # passed on under the name the second kernel uses for it.
        indx_offs = partial_hist
        # Second launch: fills block_pid_map and writes topk_indx, gate_indx
        # and gate_scal.
        _combined_routing_compute[(blocks2a + blocks2b,)](
            topk_indx,
            gate_indx,
            gate_scal,
            # outputs
            expt_scal,
            expt_indx,
            indx_offs,
            indx_offs.stride(0),
            indx_offs.stride(1),
            # inputs
            expt_offs,
            n_tokens_raw,
            # input shape
            HIST_BLOCK_M,
            n_expts_act,
            # constants
            hist,
            token_offs_pad,
            token_offs_pad.stride(0),
            block_pid_map,
            block_pid_map.stride(0),
            # outputs
            block_m_log2_start,
            block_m_num,
            HIST2_BLOCK_M,
            blocks2a,
            # etc.
        )
        # Saved for backward(): shapes needed for the reshape, and gate_indx
        # to index the incoming gradient.
        ctx.n_tokens_raw = n_tokens_raw
        ctx.n_tokens_pad = n_tokens_pad
        ctx.n_expts_act = n_expts_act
        ctx.save_for_backward(gate_indx)
        return (
            hist,
            topk_indx,
            gate_indx,
            gate_scal,
            token_offs_raw,
            token_offs_pad,
            block_pid_map,
        )

    @staticmethod
    def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):
        # Of the seven forward outputs only gate_scal (4th) carries a gradient.
        (gate_indx,) = ctx.saved_tensors
        # Index the gradient by gate_indx, then restore the 2-D shape of expt_scal.
        dgate_scal = dgate_scal[gate_indx]
        dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)
        # One entry per forward input: expt_scal, expt_indx, n_expts_tot, bitmatrix.
        return dgate_scal, None, None, None
def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix):
    """Functional entry point: forward all four arguments to SortTokens.apply.

    Returns the 7-tuple produced by SortTokens.forward.
    """
    return SortTokens.apply(expt_scal, expt_indx, n_expts_tot, bitmatrix)
# --------------------------
# prune routing
# --------------------------
class PruneRouting(torch.autograd.Function):
    """Restrict routing to the first ``n_expts_tot // simulated_ep`` experts.

    NOTE(review): only forward() is defined in this class as shown; there is
    no backward().
    """

    @staticmethod
    def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
        # Imported lazily, inside the call, rather than at module import time.
        from .compaction import compaction

        n_tokens_pad = expt_scal.shape[0]
        assert n_expts_tot % simulated_ep == 0
        # One program per token row: clears every bit at or beyond the cutoff
        # column n_expts_tot // simulated_ep, in place.
        _routing_clear_bitmatrix[(n_tokens_pad,)](
            bitmatrix.storage.data,
            bitmatrix.storage.data.stride(0),
            bitmatrix.storage.data.stride(1),
            bitmatrix.storage.data.shape[1],
            n_expts_tot // simulated_ep,
            BLOCK_N=512,
        )
        # perform compaction to update expt_scal / expt_indx
        expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix)
        n_expts_tot = n_expts_tot // simulated_ep
        # Mutates the caller's bitmatrix: its logical column count shrinks to
        # the reduced expert count.
        bitmatrix.shape[-1] = n_expts_tot
        return expt_scal, expt_indx, bitmatrix
def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep):
    """Functional entry point: forward all five arguments to PruneRouting.apply.

    Returns the ``(expt_scal, expt_indx, bitmatrix)`` triple produced by
    PruneRouting.forward.
    """
    return PruneRouting.apply(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep)
# --------------------------
# expt_data
# --------------------------
def log2_power_of_two(x):
    """Return the exponent ``e`` such that ``x == 2**e``.

    Raises:
        AssertionError: if ``x`` is not a positive power of two.
    """
    # A positive power of two has exactly one bit set, so clearing the lowest
    # set bit (x & (x - 1)) must leave zero.
    is_power_of_two = x > 0 and (x & (x - 1)) == 0
    assert is_power_of_two, "x must be a power of two"
    exponent = x.bit_length() - 1
    return exponent
# log2 of the smallest block_m for which per-block routing metadata is built
# (2**4 == 16). _unpack_into_dict keys its result by 2**j for j starting here.
block_m_log2_start = 4
def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
    """Allocate the expert-metadata buffers and derive the launch parameters.

    Nothing is computed on the device here; this only sizes and allocates the
    (uninitialised) int32 buffers and works out the grid sizes for the
    memset/compute kernels.

    Returns:
        A 10-tuple ``(token_offs_combined, token_offs_raw, token_offs_pad,
        block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M,
        block_m_log2_start, block_m_num)``. ``token_offs_raw`` and
        ``token_offs_pad`` are views into ``token_offs_combined``.
    """
    MEMSET_BLOCK = 512
    HIST2_BLOCK_M = 512
    int32 = torch.int32
    target_device = expt_hist.device

    def round_up_to_memset_block(count):
        # Buffers are padded so the memset kernel can write whole blocks.
        return triton.cdiv(count, MEMSET_BLOCK) * MEMSET_BLOCK

    # block_m ranges over the powers of two 16..128, plus 256 on HIP.
    block_m_log2_end = 9 if is_hip() else 8
    block_m_num = block_m_log2_end - block_m_log2_start

    # Upper bound on the number of tiles, taken at the smallest block_m.
    if n_gates <= n_expts_tot:
        max_n_tiles = n_gates
    else:
        smallest_block_m = 2**block_m_log2_start
        max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // smallest_block_m)

    # Row 0 holds the raw token offsets, rows 1.. hold one table per block_m.
    n_offsets = n_expts_tot + 1
    token_offs_combined = torch.empty(
        (block_m_num + 1, round_up_to_memset_block(n_offsets)),
        dtype=int32,
        device=target_device,
    )
    token_offs_raw = token_offs_combined[0][:n_offsets]
    token_offs_pad = token_offs_combined[1:][:, :n_offsets]

    block_pid_map_padded = torch.empty(
        (block_m_num, round_up_to_memset_block(max_n_tiles)),
        dtype=int32,
        device=target_device,
    )
    # Exact division: the padded width is a multiple of MEMSET_BLOCK.
    memset_grid = torch.numel(block_pid_map_padded) // MEMSET_BLOCK
    block_pid_map = block_pid_map_padded[:, :max_n_tiles]

    blocks1 = memset_grid + block_m_num + 1
    blocks2 = n_expts_tot * block_m_num

    return (
        token_offs_combined,
        token_offs_raw,
        token_offs_pad,
        block_pid_map,
        blocks1,
        blocks2,
        MEMSET_BLOCK,
        HIST2_BLOCK_M,
        block_m_log2_start,
        block_m_num,
    )
def _unpack_into_dict(x):
    """Split the rows of ``x`` into a dict keyed by block size.

    Row ``i`` of ``x`` is stored under the key
    ``2 ** (block_m_log2_start + i)``.
    """
    unpacked = {}
    log2_block_ms = range(block_m_log2_start, block_m_log2_start + x.shape[0])
    for row, log2_block_m in enumerate(log2_block_ms):
        unpacked[2**log2_block_m] = x[row, :]
    return unpacked
def compute_expt_data(expt_hist, n_expts_tot, n_gates):
    """Build an ExptData from an expert histogram using the Triton kernels.

    Returns an ExptData whose four fields are all None when ``expt_hist`` is
    None. Otherwise the buffers are allocated, initialised by
    ``_expt_data_memset``, filled by ``_expt_data_compute`` and the per-block
    tables are unpacked into dicts keyed by block size.
    """
    if expt_hist is None:
        return ExptData(None, None, None, None)

    # This call only prepares buffers and kernel arguments.
    (
        offs_combined,
        offs_raw,
        offs_pad,
        pid_map,
        memset_grid,
        compute_grid,
        memset_block,
        compute_block,
        log2_start,
        n_sizes,
    ) = _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates)

    _expt_data_memset[(memset_grid,)](
        expt_hist,
        n_expts_tot,
        offs_combined,
        offs_combined.stride(0),
        pid_map,
        log2_start,
        SIZES=n_sizes,
        BLOCK=memset_block,
        num_warps=4,
    )
    _expt_data_compute[(compute_grid,)](
        expt_hist,
        offs_pad,
        offs_pad.stride(0),
        pid_map,
        pid_map.stride(0),
        log2_start,
        SIZES=n_sizes,
        BLOCK=compute_block,
        num_warps=4,
    )

    offs_pad_by_block = _unpack_into_dict(offs_pad)
    pid_map_by_block = _unpack_into_dict(pid_map)
    return ExptData(expt_hist, offs_raw, offs_pad_by_block, pid_map_by_block)
# --------------------------
# routing
# --------------------------
def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act):
    """Turn top-k results plus their bitmatrix into routing structures.

    Returns:
        ``(RoutingData, GatherIndx, ScatterIndx)``. The gather and scatter
        index objects hold the same two tensors with source and destination
        roles swapped.
    """
    sorted_outputs = sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix)
    (
        hist,
        topk_indx,
        gate_indx,
        gate_scal,
        token_offs_raw,
        packed_offs_pad,
        packed_pid_map,
    ) = sorted_outputs

    # The per-block tables come back stacked; expose them keyed by block size.
    offs_pad_by_block = _unpack_into_dict(packed_offs_pad)
    pid_map_by_block = _unpack_into_dict(packed_pid_map)
    expt_data = ExptData(hist, token_offs_raw, offs_pad_by_block, pid_map_by_block)

    # Pack the structures consumed by the matmul.
    gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
    scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
    routing_data = RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data)
    return routing_data, gather_indx, scatter_indx
def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None):
    """Compute routing structures from router logits.

    Args:
        logits: router scores; the last dimension indexes experts.
        n_expts_act: number of experts selected per token.
        sm_first: when True, softmax is applied over all experts before the
            top-k; otherwise the top-k call is asked to apply it.
        expt_indx: optional pre-selected expert indices, forwarded to topk as
            ``y_indx``.
        simulated_ep: when > 1, routing is pruned to the first
            ``logits.shape[-1] // simulated_ep`` experts.
        n_rows: forwarded to topk unchanged.

    Returns:
        Whatever ``routing_from_bitmatrix`` returns:
        ``(RoutingData, GatherIndx, ScatterIndx)``.
    """
    from .topk import topk

    scores = torch.softmax(logits, dim=-1) if sm_first else logits
    expt_scal, expt_indx, bitmatrix = topk(
        scores,
        n_expts_act,
        apply_softmax=not sm_first,
        y_indx=expt_indx,
        n_rows=n_rows,
    )
    n_expts_global = scores.shape[-1]
    n_expts_tot = n_expts_global // simulated_ep
    if simulated_ep > 1:
        # prune_routing mutates bitmatrix in place.
        pruned = prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_global, simulated_ep)
        expt_scal, expt_indx, bitmatrix = pruned
    return routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act)
# --------------------------
# torch reference
# --------------------------
def compute_expt_data_torch(hist, n_expts_tot, n_gates):
    """Pure-PyTorch reference construction of ExptData from an expert histogram.

    Args:
        hist: 1-D tensor where ``hist[e]`` is the number of tokens routed to
            expert ``e``. It is broadcast against ``arange(n_expts_tot)``, so
            it is expected to have ``n_expts_tot`` entries.
        n_expts_tot: number of experts.
        n_gates: number of (token, expert) slots; only used to bound the
            number of tiles.

    Returns:
        ``ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)`` with
        int32 offset tables, and with the last two being dicts keyed by
        ``block_m``.
    """
    device = hist.device

    def exclusive_offsets(counts):
        # Prefix sum with a leading zero. The zero is created with the same
        # dtype as the cumulative sum so the offsets never pass through
        # float32, which cannot represent every integer above 2**24.
        running = torch.cumsum(counts, dim=0)
        return torch.cat((running.new_zeros(1), running)).int()

    # offset of each expert's first token
    token_offs_raw = exclusive_offsets(hist)

    # maximum number of tiles for all values of `block_m` considered
    block_ms = [16, 32, 64, 128]
    if is_hip():
        block_ms.append(256)
    if n_gates <= n_expts_tot:
        max_n_tiles = n_gates
    else:
        # ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
        # ceil_div(x, y): -(-x // y)
        max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms))

    # These do not depend on block_m, so build them once:
    # map_vals[e, b] = (b << 16) + e, i.e. block index in the high half and
    # expert id in the low half.
    col = torch.arange(max_n_tiles, device=device)
    expert_ids = torch.arange(n_expts_tot, device=device)
    map_vals = (expert_ids[:, None] + (col << 16)[None, :]).int()

    # fill up tile offset/infos for each block
    token_offs_pad = {}
    block_pid_map = {}
    for block_m in block_ms:
        n_tiles = (hist + block_m - 1) // block_m  # matmul blocks needed
        tile_offs = exclusive_offsets(n_tiles)
        token_offs_pad[block_m] = tile_offs

        # compute data required to drive ragged batch matmul; -1 marks a
        # tile slot with no work.
        pid_map = torch.full((max_n_tiles,), -1, dtype=torch.int32, device=device)
        # Vectorised form of:
        # for e in range(n_expts_tot):
        #     offset = token_offs_pad[block_m][e]
        #     for b in range(n_tiles[e]):
        #         block_pid_map[block_m][offset + b] = (b << 16) + e
        map_idxs = tile_offs[:-1, None] + col[None, :]
        mask = col[None, :] < n_tiles[:, None]
        pid_map.index_put_((map_idxs[mask],), map_vals[mask])
        block_pid_map[block_m] = pid_map
    return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
def topk_torch(vals, k, expt_indx, has_user_provided_indx=False):
    """Select ``k`` entries per row of ``vals`` and return their values and indices.

    Args:
        vals: 2-D tensor of scores (rows are gathered along dim 1).
        k: number of columns to keep per row; ignored when indices are given.
        expt_indx: indices to use when ``has_user_provided_indx`` is True.
        has_user_provided_indx: use ``expt_indx`` as-is instead of taking the
            ``k`` largest entries (ties keep the lower column first, via a
            stable sort).

    Returns:
        ``(values, indices)`` where ``indices`` is int32.
    """
    if not has_user_provided_indx:
        descending_order = torch.argsort(-vals, dim=1, stable=True)
        expt_indx = descending_order[:, :k]
    # take_along_dim needs int64 indices; callers get int32 back.
    gather_indx = expt_indx.long()
    selected_vals = torch.take_along_dim(vals, gather_indx, dim=1)
    return selected_vals, gather_indx.int()
def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None):
    """Pure-PyTorch reference implementation of routing.

    Args:
        logits: 2-D router scores, one row per token and one column per expert.
        n_expts_act: number of experts selected per token.
        sm_first: softmax over all experts before the top-k when True,
            otherwise softmax over the selected values afterwards.
        expt_indx: optional user-provided expert indices; when given they are
            used as-is and not re-sorted per token.
        n_rows: when given, only the first ``n_rows`` rows of ``logits`` are
            routed (the padded gate count still uses the original row count).

    Returns:
        ``(RoutingData, GatherIndx, ScatterIndx)``.
    """
    user_indx_given = expt_indx is not None
    # Taken before any row slicing, so it reflects the padded token count.
    n_gates_pad = logits.shape[0] * n_expts_act

    if n_rows is not None:
        logits = logits[:n_rows, :]
    _, n_expts_tot = logits.shape

    if sm_first:
        logits = torch.softmax(logits, dim=-1)
    expt_scal, expt_indx = topk_torch(logits, n_expts_act, expt_indx, has_user_provided_indx=user_indx_given)
    if not sm_first:
        expt_scal = torch.softmax(expt_scal, dim=-1)

    # Within each token, order the selections by expert id.
    if not user_indx_given:
        expt_indx, per_token_order = torch.sort(expt_indx, dim=1)
        expt_scal = torch.gather(expt_scal, 1, per_token_order)

    # Flatten the top-k data.
    flat_scal = expt_scal.reshape(-1)
    flat_indx = expt_indx.reshape(-1).to(torch.int32)

    # Sort by expert id so each expert's gates are contiguous for the matmul;
    # gate_indx is the argsort of topk_indx.
    topk_indx = torch.argsort(flat_indx, stable=True)
    gate_indx = torch.argsort(topk_indx, stable=True)
    gate_scal = flat_scal[topk_indx]

    # Histogram of tokens over experts.
    hist = torch.histc(flat_indx, bins=n_expts_tot, max=n_expts_tot - 1).int()

    # Pack the structures consumed by the matmul.
    gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
    scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
    expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad)
    routing_data = RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data)
    return routing_data, gather_indx, scatter_indx
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_expt_data.py
deleted
100644 → 0
View file @
2b7160c6
import
triton
import
triton.language
as
tl
@triton.jit
def _cdiv_pow2(n, log2_k):
    """Ceiling division of n by 2**log2_k, done with shifts."""
    # Adding (2**log2_k - 1) before shifting rounds the quotient up.
    round_up_bias = (1 << log2_k) - 1
    return (n + round_up_bias) >> log2_k
@triton.jit
def _expt_data_memset(
    Hist,
    n_expts_tot,
    MDStarts,
    tile_starts_stridem,
    MDTileInfo,
    first_tile_dim_log2,
    SIZES: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Build the offset tables and initialise the tile-info buffer.

    Programs 0..SIZES each fill one row of MDStarts with an exclusive prefix
    sum over the histogram: row 0 uses the raw counts (tile size 2**0), row
    p >= 1 uses counts rounded up to tiles of size
    2**(first_tile_dim_log2 + p - 1). The remaining programs each fill one
    BLOCK-sized chunk of MDTileInfo with 0xFFFFFFFF.
    """
    pid = tl.program_id(0)
    if pid <= SIZES:
        # Select this program's row of the offset table.
        MDStarts += pid * tile_starts_stridem
        # Running total carried across BLOCK-sized chunks of the histogram.
        x_tile = tl.zeros([BLOCK], dtype=MDStarts.dtype.element_ty)
        Tile_ptrs = MDStarts + tl.arange(0, BLOCK)
        tile_dim_log2 = tl.where(pid == 0, 0, pid + first_tile_dim_log2 - 1)
        # n_expts_tot + 1 entries are produced: one start per expert plus the total.
        for i in range(0, n_expts_tot + 1, BLOCK):
            offs_n = tl.arange(0, BLOCK) + i
            mask_n0 = offs_n < n_expts_tot
            hist_tok = tl.load(Hist + offs_n, mask=mask_n0, other=0)
            hist_tile = _cdiv_pow2(hist_tok, tile_dim_log2)
            tile_starts = tl.cumsum(hist_tile, 0) + x_tile
            x_tile += tl.sum(hist_tile, 0).to(MDStarts.dtype.element_ty)
            # inclusive cumsum minus the element itself == exclusive cumsum.
            # NOTE(review): this store is unmasked, so each row of MDStarts
            # must be padded to a multiple of BLOCK — confirm against the host allocation.
            tl.store(Tile_ptrs, tile_starts - hist_tile)
            Tile_ptrs += BLOCK
    else:
        # Re-base pid so the memset programs are numbered from 0.
        pid -= SIZES + 1
        TileInfoOut = MDTileInfo + pid * BLOCK + tl.arange(0, BLOCK)
        # All bits set; read back as -1 when the buffer is int32.
        tl.store(TileInfoOut, 0xFFFFFFFF)
@triton.jit
def _expt_data_compute(
    Hist,
    MDTileStarts,
    tile_starts_stridem,
    MDTileInfo,
    tile_info_stridem,
    first_tile_dim_log2,
    SIZES: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Fill the tile-info entries of one (expert, tile size) pair.

    Program ``pid`` handles expert ``pid // SIZES`` for tile size
    ``2**(first_tile_dim_log2 + pid % SIZES)`` and writes, for each of that
    expert's tiles ``b``, the value ``(b << 16) + expert`` starting at the
    expert's tile offset.
    """
    pid = tl.program_id(0)
    expt_id = pid // SIZES
    buff_id = pid % SIZES
    # Select the rows belonging to this tile size.
    MDTileStarts += buff_id * tile_starts_stridem
    MDTileInfo += buff_id * tile_info_stridem
    n_tokens = tl.load(Hist + expt_id)
    tile_dim_log2 = first_tile_dim_log2 + buff_id
    # Number of tiles this expert needs at this tile size.
    n_blocks = _cdiv_pow2(n_tokens, tile_dim_log2)
    # Where this expert's tiles begin inside the tile-info row.
    tile_off = tl.load(MDTileStarts + expt_id)
    MDTileInfo += tile_off
    for block_off in range(0, n_blocks, BLOCK):
        block_offs = block_off + tl.arange(0, BLOCK)
        # Pack: tile index in the upper 16 bits, expert id in the lower bits.
        data = (block_offs << 16) + expt_id
        tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/routing_details/_routing_compute.py
deleted
100644 → 0
View file @
2b7160c6
import
triton
import
triton.language
as
tl
from
._expt_data
import
_expt_data_compute
,
_expt_data_memset
@triton.jit
def _routing_compute_expt_offs(
    ExpertHist,
    FinalExpertOffs,
    hist_size,
    # histogram
    BLOCK_N: tl.constexpr,
):
    """Write the exclusive prefix sum of ExpertHist into FinalExpertOffs.

    The histogram is processed in chunks of BLOCK_N entries by a single
    program; ``x`` carries the running total from one chunk to the next.
    Only the first ``hist_size`` entries are read and written.
    """
    loop_iterations = (hist_size + BLOCK_N - 1) // BLOCK_N
    x = tl.zeros([BLOCK_N], ExpertHist.dtype.element_ty)
    for i in range(loop_iterations):
        offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N)
        mask_n = offs_n < hist_size
        # other=0 makes the out-of-range lanes a defined zero, so the cumsum
        # and the running total never consume unspecified values.
        hist2 = tl.load(ExpertHist + offs_n, mask=mask_n, other=0)
        # inclusive cumsum minus the element itself == exclusive cumsum.
        tok_starts = tl.cumsum(hist2, 0) - hist2 + x
        x += tl.sum(hist2, 0)
        tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n)
@triton.jit
def _routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr, expt_id):
    """Replace column ``expt_id`` of PartialHist by its exclusive prefix sum, in place.

    The column is walked along its first dimension in chunks of BLOCK_M rows;
    ``curr_sum`` carries the running total between chunks.
    """
    offs_m = tl.arange(0, BLOCK_M)
    # iterate over input data
    curr_sum = 0
    for _ in range(0, shape_pm, BLOCK_M):
        offs = offs_m * stride_pm + expt_id * stride_pn
        # NOTE(review): no `other` is given for the masked lanes; they can
        # only occur in the last chunk, where they sit after every valid lane.
        curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm)
        out = tl.cumsum(curr, 0) + curr_sum
        curr_sum += tl.sum(curr, 0)
        # inclusive cumsum minus the element itself == exclusive cumsum.
        tl.store(PartialHist + offs, out - curr, mask=offs_m < shape_pm)
        # Advance to the next chunk of rows.
        offs_m += BLOCK_M
@triton.jit
def _keyed_add(x, y):
    """Combine step for a scan that sums values within runs of equal keys.

    Each operand carries a key in its upper 16 bits and a value in its lower
    16 bits. When the keys match, the values are added and the key is kept
    once; otherwise the right operand starts a new run.
    """
    # The key occupies the upper 16 bits of a uint32.
    KEY_BITS: tl.constexpr = 0xFFFF0000
    key_x = x & KEY_BITS
    key_y = y & KEY_BITS
    same_key = key_x == key_y
    # x + y counts the shared key twice, so subtract it once.
    return tl.where(same_key, x + y - key_x, y)
@triton.jit
def _routing_compute_indx(
    pid_m,
    GatherIndx,
    ScatterIndx,
    GateScal,
    ExptScal,
    ExptIndx,
    PartialOffs,
    stride_pm,
    stride_pn,
    TokensStart,
    n_tokens,
    BLOCK_M: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
):
    """Compute the destination slot of every gate in one block of BLOCK_M tokens.

    For each valid gate the destination is
    ``TokensStart[expert] + PartialOffs[pid_m, expert] + rank``, where
    ``rank`` is the gate's position among this block's gates for the same
    expert. The slot is written to ScatterIndx, its inverse to GatherIndx,
    and the gate value is copied to GateScal at the destination.
    """
    # The token count may arrive either as a value or as a pointer to it.
    if isinstance(n_tokens, tl.tensor) and n_tokens.dtype.is_ptr():
        n_tokens = tl.load(n_tokens)
    n_gates = n_tokens * N_EXPTS_ACT
    # Local offsets are packed into the low 16 bits of a key/value word below.
    tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)
    local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
    offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs
    # Out-of-range gates load -1, which becomes all-ones as uint32.
    expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)
    # stable-sort by expert ID:
    # the expert id is the high half of the key and the original local offset
    # the low half, so equal experts keep their original order.
    kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)
    kv_pairs = tl.sort(kv_pairs, 0)
    expert = kv_pairs >> 16
    offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xFFFF)
    # 0xFFFF in the expert half marks the out-of-range gates loaded as -1.
    mask = expert != 0xFFFF
    gate_scal = tl.load(ExptScal + offs, mask=mask)
    # compute run lengths in expert-sorted order:
    # each element becomes (expert key, count 1); the keyed scan then yields
    # the inclusive position of each gate within its expert's run.
    x = kv_pairs & 0xFFFF0000 | 0x00000001
    expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
    exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xFFFF
    # Offset contributed by earlier token blocks for the same expert ...
    gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=mask)
    # ... plus where that expert's segment starts ...
    gates += tl.load(TokensStart + expert, mask=mask)
    # ... plus the rank inside this block.
    gates += exclusive_run_lengths
    tl.store(ScatterIndx + offs, gates, mask=mask)
    tl.store(GatherIndx + gates, offs, mask=mask)
    tl.store(GateScal + gates, gate_scal, mask=mask)
@triton.jit
def _combined_routing_compute(
    GatherIndx,
    ScatterIndx,
    GateScal,
    ExptScal,
    ExptIndx,
    PartialOffs,
    stride_pm,
    stride_pn,
    TokensStart,
    n_tokens,
    BLOCK_M: tl.constexpr,
    N_EXPTS_ACT: tl.constexpr,
    Hist,
    MDTileStarts,
    tile_starts_stridem,
    MDTileInfo,
    tile_info_stridem,
    first_tile_dim_log2,
    SIZES: tl.constexpr,
    BLOCK: tl.constexpr,
    blocks2a,
):
    """Single-launch trampoline for two independent pieces of work.

    Programs with ``pid < blocks2a`` run `_expt_data_compute`; the remaining
    programs run `_routing_compute_indx` with ``pid`` re-based to start at 0.
    """
    pid = tl.program_id(0)
    if pid < blocks2a:
        # _expt_data_compute reads tl.program_id(0) itself, so pid is not passed.
        _expt_data_compute(
            Hist,
            MDTileStarts,
            tile_starts_stridem,
            MDTileInfo,
            tile_info_stridem,
            first_tile_dim_log2,
            SIZES,
            BLOCK,
        )
    else:
        pid -= blocks2a
        _routing_compute_indx(
            pid,
            GatherIndx,
            ScatterIndx,
            GateScal,
            ExptScal,
            ExptIndx,
            PartialOffs,
            stride_pm,
            stride_pn,
            TokensStart,
            n_tokens,
            BLOCK_M,
            N_EXPTS_ACT,
        )
@triton.jit
def _routing_clear_bitmatrix(Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr):
    """Clear, in row ``pid_m`` of Bitmatrix, every bit at position ``cutoff`` or above.

    Bits are addressed as 32 per word: the word containing ``cutoff`` keeps
    only its low ``cutoff % 32`` bits, and every later word is set to 0.
    Earlier words are left unchanged.
    """
    pid_m = tl.program_id(0)
    cutoff_word = cutoff // 32
    cutoff_bit = cutoff % 32
    # Ones in the low cutoff_bit positions; 0 when the cutoff is word-aligned.
    cutoff_mask = (1 << (cutoff_bit)) - 1
    for start_n in range(0, shape_bn, BLOCK_N):
        offs_n = start_n + tl.arange(0, BLOCK_N)
        values = tl.load(Bitmatrix + pid_m * stride_bm + offs_n * stride_bn, mask=offs_n < shape_bn)
        # Partially clear the word that straddles the cutoff ...
        values = tl.where(offs_n == cutoff_word, values & cutoff_mask, values)
        # ... and fully clear everything after it.
        values = tl.where(offs_n > cutoff_word, 0, values)
        tl.store(
            Bitmatrix + pid_m * stride_bm + offs_n * stride_bn,
            values,
            mask=offs_n < shape_bn,
        )
@triton.jit
def _combined_routing_memset(
    Indx,
    size,
    sentinel,
    BLOCK: tl.constexpr,
    ExpertHist,
    FinalExpertOffs,
    hist_size,
    n_expts_tot,
    PartialHist,
    shape_pm,
    stride_pm,
    stride_pn,
    MDStarts,
    tile_starts_stridem,
    blocks1a,
    MDTileInfo,
    first_tile_dim_log2,
    SIZES: tl.constexpr,
    BLOCK_A: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    """
    This kernel essentially combines 6 different pieces of functionality,
    branching on the value of tl.program_id(0) to decide which
    codepath to take.
    pid == 0: create the token cumsum
    1 <= pid <= SIZES: create a tile cumsum
    SIZES < pid < blocks1a: initialise MDTileInfo to 0xffffffff
    blocks1a <= pid < blocks1a + n_expts_tot: compute_indx_offs
    pid == blocks1a + n_expts_tot: compute_expt_offs
    pid > blocks1a + n_expts_tot: initialise Indx to sentinel
    As each of these is a relatively trivial workload, launching them from
    this single trampoline is beneficial as they can execute on different
    streaming multiprocessors in parallel.
    """
    pid = tl.program_id(0)
    if pid < blocks1a:
        # First three cases of the docstring; _expt_data_memset reads
        # tl.program_id(0) itself to tell them apart.
        _expt_data_memset(
            ExpertHist,
            n_expts_tot,
            MDStarts,
            tile_starts_stridem,
            MDTileInfo,
            first_tile_dim_log2,
            SIZES,
            BLOCK_A,
        )
    elif pid == n_expts_tot + blocks1a:
        # A single program turns the expert histogram into expert offsets.
        _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
    elif pid < n_expts_tot + blocks1a:
        # One program per expert; the expert id is pid re-based past blocks1a.
        _routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M, pid - blocks1a)
    else:
        # Remaining programs each fill one BLOCK-sized chunk of Indx with the
        # sentinel; the mask keeps the last chunk within `size`.
        offs = (pid - n_expts_tot - blocks1a - 1) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < size
        tl.store(Indx + offs, sentinel, mask=mask)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/specialize.py
deleted
100644 → 0
View file @
2b7160c6
import
inspect
import
re
import
textwrap
import
types
import
triton
def cacheable(f):
    """
    Decorator for factories that build a kernel dynamically:

        @cacheable
        def my_kernel(): return (expression dynamically defining a kernel)

    The factory is called once, at decoration time, and the object it returns
    (together with that object's ``fn``) is renamed after the factory so that
    it interacts gracefully with triton cache and preload.
    """
    kernel = f()
    identity_attrs = ("__name__", "__module__", "__qualname__")
    # Rename the wrapped Python function first, then the kernel object itself.
    for attr in identity_attrs:
        setattr(kernel.fn, attr, getattr(f, attr))
    for attr in identity_attrs:
        setattr(kernel, attr, getattr(f, attr))
    kernel._fn_name = f"{f.__module__}.{f.__qualname__}"
    return kernel
def define_kernel(src, module, attrs=None, **extra_globals):
    """
    Dynamically create a Triton function or kernel from a src string,
    linking any symbols in the kernel to objects specified by extra_globals.

    Args:
        src: Source text; everything before the first ``"def "`` is dropped.
        module: Object whose ``__name__`` becomes the new function's
            ``__module__``.
        attrs: Optional keyword arguments forwarded to ``triton.JITFunction``.
        **extra_globals: Extra names made visible to the generated code.

    Returns:
        The ``triton.JITFunction`` built from ``src``.
    """

    # Template function: only its code object and globals are reused.
    def _empty_fn():
        pass

    namespace = {**_empty_fn.__globals__, **extra_globals}
    template = types.FunctionType(_empty_fn.__code__, namespace)
    template.__module__ = module.__name__

    src = textwrap.dedent(src)
    src = src[src.find("def "):]
    function_name = src[4:].split("(")[0].strip()

    # Execute the source in the shared namespace and capture the resulting
    # function object through a list visible to the executed code.
    stored_functions = []
    namespace["stored_functions"] = stored_functions
    exec(src + "\n\nstored_functions.append(" + function_name + ")\n", namespace)
    compiled = stored_functions[0]

    template.__signature__ = inspect.signature(compiled)
    template.__name__ = function_name
    template.__doc__ = compiled.__doc__

    jit_kwargs = dict() if attrs is None else attrs
    kernel = triton.JITFunction(template, **jit_kwargs)
    kernel._unsafe_update_src(src)
    return kernel
def specialize(fn, module, constants, tuples, name=None, do_not_specialize=tuple()):
    """
    Build a new JIT kernel from ``fn`` with some arguments baked in.

    Args:
        fn: The ``triton.runtime.jit.JITFunction`` to specialize.
        module: Passed through to ``define_kernel``.
        constants: Mapping of argument name -> value. These arguments are
            removed from the signature and re-emitted as ``tl.constexpr``
            assignments at the top of the body.
        tuples: Mapping of argument name -> list of replacement argument
            strings. The named argument is replaced in the signature by the
            listed arguments and rebuilt as a tuple inside the body.
        name: Name of the generated kernel; defaults to ``fn.__name__``.
        do_not_specialize: Forwarded as the JIT ``do_not_specialize``
            attribute when non-empty.

    Returns:
        The kernel returned by ``define_kernel``.

    Raises:
        ValueError: If the function header cannot be parsed.
    """
    assert isinstance(fn, triton.runtime.jit.JITFunction)
    if name is None:
        name = f"{fn.__name__}"

    # Fetch the original source and split it into lines.
    source_lines = textwrap.dedent(inspect.getsource(fn.fn)).split("\n")

    # Locate the `def` line (skipping decorators) and the end of the header.
    def_idx = next(idx for idx, text in enumerate(source_lines) if text.strip().startswith("def"))
    header_end = def_idx
    while not source_lines[header_end].rstrip().endswith(":"):
        header_end += 1
    body_lines = source_lines[header_end + 1:]

    # Strip comments from the header and drop lines that end up empty.
    header_clean = []
    for text in source_lines[def_idx:header_end + 1]:
        code = text.split("#", 1)[0].strip()
        if code:
            header_clean.append(code)

    # Collapse the header to one line and pull out the argument list.
    header_src = " ".join(header_clean)
    header_match = re.search(r"\((.*)\)\s*:", header_src)
    if not header_match:
        raise ValueError("Could not parse function header")
    args = [piece.strip() for piece in header_match.group(1).split(",") if piece.strip()]

    non_specialized_args = []
    for arg in args:
        arg_key = arg.split(":")[0].split("=")[0].strip()
        if arg_key in constants:
            continue
        non_specialized_args += tuples.get(arg_key, [arg])

    # JIT functions passed as constants must be resolvable by name in the
    # generated kernel, so they are exposed as globals.
    spec_fns = {
        value.__name__: value
        for value in constants.values()
        if isinstance(value, triton.runtime.jit.JITFunction)
    }
    kernel_globals = spec_fns | fn.get_capture_scope()

    # Build the new source and define the kernel dynamically.
    new_signature = f"def {name}({', '.join(non_specialized_args)}):"
    constexpr_lines = [
        f"    {key}: tl.constexpr = {value.__name__ if callable(value) else value}"
        for key, value in constants.items()
    ]
    tuple_lines = [
        f"    {key} = {'(' + ','.join(value) + (',' if len(value) >= 1 else '') + ')'}"
        for key, value in tuples.items()
    ]
    new_src = "\n".join(["@triton.jit", new_signature] + constexpr_lines + tuple_lines + body_lines)

    # Carry over the JITFunction constructor options from `fn` (skipping the
    # first two constructor parameters).
    ctor_params = list(inspect.signature(triton.runtime.jit.JITFunction.__init__).parameters.values())[2:]
    attrs = {param.name: getattr(fn, param.name, param.default) for param in ctor_params}

    # Make a new repr which appends the repr of the specialized functions.
    base_repr = attrs["repr"]

    def new_repr(specialization):
        label = base_repr(specialization)
        for spec_fn in spec_fns.values():
            suffix = spec_fn.repr(None)
            if suffix:
                suffix = suffix.strip("_")
            if suffix:
                label += f"_{suffix}"
        return label

    attrs["repr"] = new_repr

    if do_not_specialize:
        attrs["do_not_specialize"] = do_not_specialize
    return define_kernel(new_src, module, attrs, **kernel_globals)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu.py
deleted
100644 → 0
View file @
2b7160c6
from
dataclasses
import
dataclass
from
compactor_vllm.triton_kernels.numerics
import
InFlexData
,
OutFlexData
import
torch
import
triton
from
.swiglu_details._swiglu
import
_swiglu
,
_swiglu_fn
from
compactor_vllm.triton_kernels
import
target_info
@dataclass(frozen=True)
class FlexCtx:
    """Immutable bundle of the flex-point settings read by `SwiGLU.forward`."""

    # Output-side flex data; `SwiGLU.forward` reads its `reinterpret`,
    # `expected_scale`, `actual_scale` and `checksum_scale` members.
    out_data: OutFlexData = OutFlexData()
    # Input-side flex data; `SwiGLU.forward` reads its `reinterpret` and
    # `scale` members.
    inp_data: InFlexData = InFlexData()
    # Passed to the kernel as `flexpoint_saturate_inf`.
    saturate_inf: bool = False
@dataclass(frozen=True)
class PrecisionConfig:
    """Immutable precision settings accepted by `swiglu` and `swiglu_torch`."""

    # Clamp bound applied to the activations; `swiglu_torch` skips clamping
    # when this is None.
    limit: float
    # Flex-point context read by `SwiGLU.forward`.
    flex_ctx: FlexCtx = FlexCtx()
# Public alias for the `_swiglu_fn` helper imported from `.swiglu_details._swiglu`.
swiglu_fn = _swiglu_fn
class SwiGLU(torch.autograd.Function):
    """Autograd function that launches the `_swiglu` Triton kernel.

    Only `forward` is defined in this class.
    """

    @staticmethod
    def forward(ctx, a, alpha, precision_config, routing_data):
        """Run the kernel on `a` and return a tensor whose last dim is halved.

        Args:
            ctx: Autograd context (not used here).
            a: Input tensor; its last dimension must be even and have stride 1.
            alpha: Scalar forwarded to the kernel.
            precision_config: Supplies `flex_ctx` and `limit`.
            routing_data: Optional; when given, the row count is read on
                device from `routing_data.expt_data.token_offs_raw`.

        Returns:
            Tensor of shape `a.shape[:-1] + (a.shape[-1] // 2,)`.
        """
        # Treat `a` as a 2D (M, N) matrix: N is the last dim, M is everything else.
        N = a.shape[-1]
        M = a.numel() // N
        assert a.stride()[-1] == 1
        assert a.shape[-1] % 2 == 0
        out = torch.empty(size=(M, N // 2), dtype=a.dtype, device=a.device)
        flex_ctx = precision_config.flex_ctx
        # optimization hyperparameters
        BLOCK_M, BLOCK_N = 32 // a.itemsize, 128
        num_warps = 4
        # `maxnreg` is only passed when not running on HIP.
        kwargs = {"maxnreg": 64} if not target_info.is_hip() else {}
        # launch semi-persistent kernel
        N_BLOCKS = triton.cdiv(N // 2, BLOCK_N)
        num_sms = target_info.num_sms()
        if routing_data is not None:
            # The row count is not known on the host in this case, so the
            # grid is sized from the device occupancy instead of from M.
            waves_per_sm = 32 if target_info.is_hip() else 128
            num_pid = num_sms * (waves_per_sm // num_warps)
            M_BLOCKS = max(1, triton.cdiv(num_pid, N_BLOCKS))
            grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
        else:
            M_BLOCKS = triton.cdiv(M, BLOCK_M)
            if M_BLOCKS * N_BLOCKS >= 8 * num_sms:
                grid = (8 * num_sms, )
            else:
                grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), )
        n_tokens = None
        if routing_data is not None:
            # Indexing keeps this a tensor; the kernel loads the row count from it.
            n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot]
        # NOTE(review): the row strides passed below are `a.shape[-1]` and
        # `out.shape[-1]`, i.e. rows are assumed densely packed; only the
        # last-dim stride is asserted above — confirm callers pass contiguous `a`.
        _swiglu[grid](
            flex_ctx.out_data.reinterpret(out),
            flex_ctx.out_data.expected_scale,
            flex_ctx.out_data.actual_scale,
            flex_ctx.out_data.checksum_scale,
            flex_ctx.inp_data.reinterpret(a),
            flex_ctx.inp_data.scale,
            alpha,
            M,
            N // 2,
            a.shape[-1],
            1,
            out.shape[-1],
            1,
            precision_config.limit,
            n_tokens,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            EVEN_N=(N // 2) % BLOCK_N == 0,
            M_BLOCKS=M_BLOCKS,
            N_BLOCKS=N_BLOCKS,
            flexpoint_saturate_inf=flex_ctx.saturate_inf,
            num_warps=num_warps,
            **kwargs,
        )
        # Restore the leading dimensions of `a` on the 2D result.
        out = out.view(a.shape[:-1] + out.shape[-1:])
        return out
def swiglu(a, alpha, precision_config, routing_data=None):
    """Apply `SwiGLU` to `a`; `routing_data` is optional and defaults to None."""
    return SwiGLU.apply(a, alpha, precision_config, routing_data)
def swiglu_torch(a, alpha, precision_config):
    """Pure-PyTorch SwiGLU over interleaved gate/linear columns.

    Even-indexed entries of the last dimension form the gated branch and
    odd-indexed entries the linear branch. When ``precision_config.limit`` is
    not None, the gated branch is clamped from above by ``limit`` and the
    linear branch is clamped to ``[-limit, limit]``.

    Returns:
        ``g * sigmoid(alpha * g) * (l + 1)`` where ``g`` and ``l`` are the
        (optionally clamped) gated and linear branches.
    """
    limit = precision_config.limit
    gate, linear = a[..., ::2], a[..., 1::2]
    if limit is not None:
        gate = gate.clamp(max=limit)
        linear = linear.clamp(min=-limit, max=limit)
    gated = gate * torch.sigmoid(alpha * gate)
    return gated * (linear + 1)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/swiglu_details/_swiglu.py
deleted
100644 → 0
View file @
2b7160c6
from
compactor_vllm.triton_kernels.numerics_details.flexpoint
import
(
load_scale
,
float_to_flex
,
update_scale
,
)
import
triton
import
triton.language
as
tl
@triton.jit
def clip(x, limit, clip_lower: tl.constexpr):
    # Clamp `x` from above by `limit`; when `clip_lower` is set, also clamp
    # from below by `-limit`.
    res = tl.minimum(x, limit)
    if clip_lower:
        res = tl.maximum(-limit, res)
    return res
@triton.jit
def thread_local_absmax(x, BLOCK_SIZE: tl.constexpr, NUM_THREADS: tl.constexpr):
    # Reshape |x| to [NUM_THREADS, BLOCK_SIZE // NUM_THREADS] (element order
    # may be changed, see `can_reorder=True`) and take the max along axis 1,
    # yielding NUM_THREADS partial maxima.
    return tl.max(
        tl.reshape(tl.abs(x), [NUM_THREADS, BLOCK_SIZE // NUM_THREADS], can_reorder=True),
        axis=1,
    )
def swiglu_repr(specialization):
    """Build the kernel display name from a specialization.

    The name is ``_swiglu_<Out dtype>x<A dtype>_<BLOCK_M>x<BLOCK_N>``. Dtypes
    come from ``specialization.signature`` with their first character dropped;
    any dtype string containing ``"u8"`` is shown as ``"mxfp4"``.
    """
    signature = specialization.signature
    constants = specialization.constants

    def _dtype_tag(arg_name):
        dtype = f"{signature[arg_name][1:]}"
        return "mxfp4" if "u8" in dtype else dtype

    dtypes = "x".join(_dtype_tag(arg_name) for arg_name in ("Out", "A"))
    blocks = "x".join(f"{constants[key]}" for key in ("BLOCK_M", "BLOCK_N"))
    return f"_swiglu_{dtypes}_{blocks}"
def swiglu_launch_metadata(grid, kernel, args):
    """Return launch metadata (``name`` and ``bytes``) for a kernel launch.

    ``name`` is the kernel name annotated with the ``M`` and ``N`` arguments;
    ``bytes`` is the combined storage size of the ``Out`` and ``A`` arguments
    (``numel() * element_size()`` each). ``grid`` is not used.
    """
    M, N = args["M"], args["N"]
    A, Out = args["A"], args["Out"]

    def _nbytes(tensor):
        return tensor.numel() * tensor.element_size()

    return {
        "name": f"{kernel.name} [M = {M}, N = {N}]",
        "bytes": _nbytes(Out) + _nbytes(A),
    }
@triton.jit
def compute_swiglu(gelu, linear, scale, alpha, limit):
    # Both branches are promoted to float32 and multiplied by `scale`.
    # When `limit` is given, `gelu` is clipped from above only and `linear`
    # is clipped to [-limit, limit].
    gelu = gelu.to(tl.float32) * scale
    if limit is not None:
        gelu = clip(gelu, limit, clip_lower=False)
    linear = linear.to(tl.float32) * scale
    if limit is not None:
        linear = clip(linear, limit, clip_lower=True)
    # s = gelu * sigmoid(alpha * gelu), written with an explicit exp.
    s = gelu / (1 + tl.exp(-alpha * gelu))
    return tl.fma(s, linear, s)  # (s * (linear + 1))
@triton.jit(repr=lambda _: "_swiglu")
def _swiglu_fn(input, alpha, limit):
    # View the 2D `input` as (rows, cols // 2, 2) and split the trailing pair
    # into the gated and linear halves, then apply SwiGLU with unit scale.
    gelu, linear = tl.split(tl.reshape(input, (input.shape[0], input.shape[1] // 2, 2)))
    return compute_swiglu(gelu, linear, 1.0, alpha, limit)
@triton.jit(repr=swiglu_repr, launch_metadata=swiglu_launch_metadata)
def _swiglu(
    Out,
    OutExpectedScale,
    OutActualScale,
    OutChecksumScale,
    A,
    AScale,
    alpha,
    M,
    N,
    stride_am,
    stride_an,
    stride_outm,
    stride_outn,
    limit: tl.constexpr,
    NTokens,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    EVEN_N: tl.constexpr,
    M_BLOCKS,
    N_BLOCKS,
    flexpoint_saturate_inf: tl.constexpr,
):
    # Kernel computing SwiGLU over A, which holds 2 * N interleaved columns
    # per row, and writing N columns per row to Out. Each program walks the
    # (M_BLOCKS * N_BLOCKS) tile space starting at its program id with a
    # stride of the number of launched programs.

    # When a token count is supplied, the real row count is read on device
    # and the number of row blocks is recomputed from it.
    if NTokens is not None:
        M = tl.load(NTokens)
        M_BLOCKS = (M + BLOCK_M - 1) // BLOCK_M

    # Running per-thread absolute maximum of the outputs.
    local_max = tl.full([tl.extra.cuda.num_threads()], 0.0, tl.float32)

    a_scale = load_scale(AScale)
    out_expected_scale = load_scale(OutExpectedScale)

    for pid in tl.range(tl.program_id(0), M_BLOCKS * N_BLOCKS, tl.num_programs(0), num_stages=2):
        pid_m = pid // N_BLOCKS
        pid_n = pid % N_BLOCKS
        off_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        mask_m = off_m < M
        mask_n = off_n < N
        # Output-column index of each of the 2 * BLOCK_N packed input
        # columns; used only to build the column mask for the packed load.
        packed_off_n = pid_n * BLOCK_N + tl.arange(0, 2 * BLOCK_N) // 2
        packed_mask_n = packed_off_n < N
        packed_mask_n = tl.max_constancy(packed_mask_n, [16])
        # load a
        packed_off_n = pid_n * 2 * BLOCK_N + tl.arange(0, 2 * BLOCK_N)
        packed_offs = off_m[:, None] * stride_am + packed_off_n[None, :] * stride_an
        if EVEN_N:
            a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.0)
        else:
            # Only a tile that crosses the right edge needs the column mask.
            if pid_n * BLOCK_N + BLOCK_N <= N:
                a_packed = tl.load(A + packed_offs, mask=mask_m[:, None], other=0.0)
            else:
                packed_mask = mask_m[:, None] & packed_mask_n[None, :]
                a_packed = tl.load(A + packed_offs, mask=packed_mask, other=0.0)
        # De-interleave adjacent column pairs into the two SwiGLU operands.
        a_gelu, a_linear = tl.split(tl.reshape(a_packed, (BLOCK_M, BLOCK_N, 2)))
        out = compute_swiglu(a_gelu, a_linear, a_scale, alpha, limit)
        # update flexpoint stats and divide by scale
        # we don't need masking because of the `other` when loading `A`
        if OutActualScale is not None:
            absmax = thread_local_absmax(out, out.numel, tl.extra.cuda.num_threads())
            local_max = tl.maximum(local_max, absmax)
        out = float_to_flex(
            out,
            out_expected_scale,
            None,  # ActualScale: local absmax is tracked and updated after the loop
            OutChecksumScale,
            None,
            Out,
            flexpoint_saturate_inf,
        )
        mask = mask_m[:, None] if EVEN_N else mask_m[:, None] & mask_n[None, :]
        tl.store(Out + off_m[:, None] * stride_outm + off_n[None, :] * stride_outn, out, mask)

    # Publish the accumulated absolute maximum once, after all tiles are done.
    update_scale(local_max, OutActualScale, Out)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/target_info.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
from
triton.language.target_info
import
(
cuda_capability_geq
,
is_cuda
,
is_hip
,
is_hip_cdna3
,
is_hip_cdna4
,
)
# Public names of this module: the helpers re-exported from
# `triton.language.target_info` plus the functions defined below.
__all__ = [
    "cuda_capability_geq",
    "get_cdna_version",
    "has_tma_gather",
    "has_native_mxfp",
    "is_cuda",
    "is_hip",
    "is_hip_cdna3",
    "is_hip_cdna4",
    "num_sms",
]
@triton.constexpr_function
def get_cdna_version():
    """
    Gets the AMD architecture version, i.e. CDNA3 or CDNA4. Currently only
    3 (gfx942) and 4 (gfx950) are supported. Returns -1 if the target is not
    AMD hardware or is an unsupported architecture.
    """
    target = tl.target_info.current_target()
    if target.backend != "hip":
        return -1
    # Architecture string -> CDNA generation; anything else is unsupported.
    return {"gfx942": 3, "gfx950": 4}.get(target.arch, -1)
@triton.constexpr_function
def has_tma_gather():
    """Return the result of `cuda_capability_geq(10, 0)`."""
    return cuda_capability_geq(10, 0)
@triton.constexpr_function
def has_native_mxfp():
    """Return the result of `cuda_capability_geq(10, 0)`."""
    return cuda_capability_geq(10, 0)
def num_sms(device=0):
    """Return the multiprocessor count reported for a CUDA device.

    Args:
        device: Device to query, forwarded unchanged to
            ``torch.cuda.get_device_properties``. Defaults to ``0``, the
            index that was previously hard-coded, so existing zero-argument
            callers behave exactly as before.

    Returns:
        The ``multi_processor_count`` field of the device properties.
    """
    return torch.cuda.get_device_properties(device).multi_processor_count
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor.py
deleted
100644 → 0
View file @
2b7160c6
from
dataclasses
import
dataclass
,
fields
from
typing
import
Type
import
torch
from
triton.tools.tensor_descriptor
import
TensorDescriptor
from
triton.tools.ragged_tma
import
create_ragged_descriptor
from
.reduction_details.reduce_bitmatrix
import
clear_sums
,
sum_bitmatrix_rows
from
.target_info
import
cuda_capability_geq
from
.tensor_details.layout
import
Layout
,
StridedLayout
@dataclass
class Storage:
    """A torch tensor paired with the layout describing how it is stored."""

    # Backing tensor; must be a `torch.Tensor`.
    data: torch.Tensor
    # Storage layout; defaults to `StridedLayout(data.shape)` when omitted.
    layout: Layout = None

    def __post_init__(self):
        assert isinstance(self.data, torch.Tensor)
        if self.layout is None:
            self.layout = StridedLayout(self.data.shape)

    @property
    def device(self):
        """Device of the backing tensor."""
        return self.data.device

    def is_tma_compliant(self):
        """Return True when the backing tensor passes the TMA checks below."""
        # TMAs didn't exist until Hopper
        if not cuda_capability_geq(9, 0):
            return False
        # TMAs only exist for 2D, 3D, 5D inputs
        if len(self.data.shape) not in [2, 3, 5]:
            return False
        # TMAs need at most one stride equal to 1
        # and all other strides divisible by 16 bytes (128 bits)
        strides = list(self.data.stride())
        try:
            major_dim = strides.index(1)
        except ValueError:
            major_dim = -1
        ndim = self.data.ndim
        # uint8 storage is counted as 4 bits per element here.
        elem_bits = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
        return all(strides[i] * elem_bits % 128 == 0 for i in range(ndim) if i != major_dim)

    def make_dense_tma(self, block_shape, transpose=False):
        """Build a `TensorDescriptor` over the backing tensor.

        Args:
            block_shape: Block shape of the descriptor. It is copied before
                use, so the caller's sequence is never modified.
            transpose: Accepted for interface compatibility.
                NOTE(review): the value is recomputed below from the data
                strides, so the argument has no effect — confirm intent.

        Raises:
            ValueError: For uint8 data with the "BLACKWELL_VALUE" layout when
                the innermost extent is not a multiple of 128.
        """
        # Work on a private copy: the packed-dimension halving below would
        # otherwise write into the list object owned by the caller.
        block_shape = list(block_shape)
        strides = list(self.data.stride())
        shape = list(self.data.shape)
        transpose = self.data.stride()[-1] != 1
        if transpose:
            # Swap the last two dims so the unit-stride dim comes last.
            block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
            shape = shape[:-2] + [shape[-1], shape[-2]]
            strides = strides[:-2] + [strides[-1], strides[-2]]
        if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE":
            indx = strides.index(1)
            block_shape[indx] = block_shape[indx] // 2
            if shape[-1] % 128 != 0:
                raise ValueError("inner shape need to be multiple of 128 for "
                                 "mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs.")
        block_shape = self.layout.swizzle_block_shape(block_shape)
        return TensorDescriptor(self.data, shape, strides, block_shape)

    def make_tma(self, block_shape, mode, transpose=False):
        """Build a descriptor for `mode`.

        "dense", "gather" and "scatter" all use `make_dense_tma`; the only
        other accepted mode is "ragged", which is ragged along the
        second-to-last dimension.
        """
        if mode in ["dense", "gather", "scatter"]:
            return self.make_dense_tma(block_shape, transpose)
        assert mode == "ragged"
        ragged_dim = len(self.data.shape) - 2
        return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim)
@dataclass
class IntegerType:
    """Integer element type described only by its width."""

    # Width of one element, in bits.
    bitwidth: int
@dataclass
class FloatType:
    """Floating-point element type described by its bit fields.

    After construction, ``bitwidth`` holds the total width: the exponent
    bits, the mantissa bits and one extra bit when the type is signed.
    """

    bitwidth_exponent: int
    bitwidth_mantissa: int
    is_signed: bool

    def __post_init__(self):
        sign_bits = int(self.is_signed)
        self.bitwidth = sign_bits + self.bitwidth_exponent + self.bitwidth_mantissa
# Predefined sub-byte element types: a 1-bit integer, and a signed float
# with 2 exponent bits and 1 mantissa bit.
BIT = IntegerType(1)
FP4 = FloatType(bitwidth_exponent=2, bitwidth_mantissa=1, is_signed=True)
def
bitwidth
(
type
:
IntegerType
|
FloatType
|
torch
.
dtype
):
if
isinstance
(
type
,
torch
.
dtype
):
return
type
.
itemsize
*
8
return
type
.
bitwidth
@dataclass
class Tensor:
    """Logical tensor: a `Storage` plus a logical dtype and shape.

    A raw `torch.Tensor` passed as `storage` is wrapped in a `Storage`.
    `dtype` defaults to the dtype of the stored data and `shape` to its
    shape; `shape` must be given explicitly for dtypes narrower than 8 bits.
    Entries of `shape_max` that are missing or None are filled from `shape`.
    """

    storage: Storage | torch.Tensor
    dtype: IntegerType | FloatType | torch.dtype = None
    shape: list[int] | None = None
    shape_max: list[int] | None = None

    def __post_init__(self):
        def _is_int(value):
            return isinstance(value, int)

        def _is_single_item(value):
            return hasattr(value, "numel") and value.numel() == 1

        # Normalise storage.
        if isinstance(self.storage, torch.Tensor):
            self.storage = Storage(self.storage)
        # Default dtype comes from the stored data.
        if self.dtype is None:
            self.dtype = self.storage.data.dtype
        if bitwidth(self.dtype) < 8 and self.shape is None:
            raise ValueError("shape must be provided for sub-byte types")
        # Default shape comes from the stored data.
        if self.shape is None:
            self.shape = list(self.storage.data.shape)
        # Every extent must be an `int` or an object holding exactly one element.
        assert all(_is_int(extent) or _is_single_item(extent) for extent in self.shape)
        # Fill in missing upper bounds from the actual extents.
        if self.shape_max is None:
            self.shape_max = [None] * len(self.shape)
        for axis, (extent, bound) in enumerate(zip(self.shape, self.shape_max)):
            if bound is None:
                self.shape_max[axis] = extent
            elif not _is_int(bound):
                raise ValueError(f"shape_max[{axis}] must be `int` or `None`; got {type(bound)}")
        # After filling, every upper bound must be an `int`.
        assert all(_is_int(bound) for bound in self.shape_max)

    # torch compatibility layer
    @property
    def ndim(self):
        return len(self.shape)

    @property
    def device(self):
        return self.storage.device

    def stride(self, i=None):
        backing = self.storage.data
        if i is None:
            return backing.stride()
        return backing.stride(i)

    def data_ptr(self):
        return self.storage.data.data_ptr()

    def numel(self):
        return self.storage.data.numel()

    def element_size(self):
        # Whole bytes per element (integer division of the bit width by 8).
        return bitwidth(self.dtype) // 8

    @property
    def data(self):
        backing = self.storage
        if isinstance(backing, Storage):
            return backing.data
        return backing

    def dim(self):
        return self.ndim

    def size(self, i=None):
        return self.shape if i is None else self.shape[i]
@dataclass
class Bitmatrix(Tensor):
    """
    Represents a boolean matrix in a packed format where each element occupies
    a single bit of memory.

    scratchpad is either None or an all-zero array of size >= shape[-1]; we pass it along
    with the actual bitmatrix to avoid having to launch a separate memset
    kernel when we call Bitmatrix::sum().
    """

    scratchpad: torch.Tensor = None

    def __init__(self, storage, shape, shape_max=None, scratchpad=None):
        # The element type is always `BIT`.
        super().__init__(storage, dtype=BIT, shape=shape, shape_max=shape_max)
        self.scratchpad = scratchpad

    def sum(self, partials_block_size):
        """Return `sum_bitmatrix_rows` of this matrix, using the scratchpad as output."""
        _, n_cols = self.shape
        dev = self.device
        # No scratchpad available: obtain one from `clear_sums`.
        if self.scratchpad is None:
            self.scratchpad = clear_sums(n_cols, dev)
        out_ret = self.scratchpad[:n_cols]
        # Drop the reference so the buffer handed out as output is not reused;
        # a later call gets a new buffer from `clear_sums` above.
        self.scratchpad = None
        return sum_bitmatrix_rows(self, out_ret, partials_block_size)
def get_layout(tensor: torch.Tensor | Tensor | None):
    """Return the layout associated with ``tensor``.

    A wrapped ``Tensor`` yields its storage's layout object, ``None`` yields
    ``None``, and anything else yields the ``StridedLayout`` class itself
    (not an instance).
    """
    if isinstance(tensor, Tensor):
        return tensor.storage.layout
    return None if tensor is None else StridedLayout
def wrap_torch_tensor(torch_tensor, dtype=None):
    """Wrap a ``torch.Tensor`` in a ``Tensor`` with logical dtype ``dtype``.

    ``dtype`` defaults to the tensor's own dtype. The extent of the
    unit-stride dimension is multiplied by the ratio of the storage bit width
    to the logical bit width, so a narrower logical dtype gives a longer
    logical shape.
    """
    if dtype is None:
        dtype = torch_tensor.dtype
    logical_shape = list(torch_tensor.shape)
    packed_dim = torch_tensor.stride().index(1)
    elems_per_storage_elem = bitwidth(torch_tensor.dtype) // bitwidth(dtype)
    logical_shape[packed_dim] *= elems_per_storage_elem
    return Tensor(Storage(torch_tensor), dtype=dtype, shape=logical_shape)
def convert_layout(tensor: Tensor, layout_cls: Type[Layout], **layout_kwargs):
    """Return a new ``Tensor`` holding ``tensor``'s data in another layout.

    The data is first unswizzled with the current layout, then swizzled with
    a new ``layout_cls`` instance constructed from the unswizzled shape and
    ``layout_kwargs``. All dataclass fields other than ``storage`` are copied
    from ``tensor``.
    """
    assert isinstance(tensor, Tensor)
    source = tensor.storage
    plain_data = source.layout.unswizzle_data(source.data)
    target_layout = layout_cls(plain_data.shape, **layout_kwargs)
    target_storage = Storage(target_layout.swizzle_data(plain_data), target_layout)
    carried = {}
    for field in fields(tensor):
        if field.name != "storage":
            carried[field.name] = getattr(tensor, field.name)
    return Tensor(target_storage, **carried)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout.py
deleted
100644 → 0
View file @
2b7160c6
from
.layout_details.base
import
Layout
from
.layout_details.blackwell_scale
import
BlackwellMXScaleLayout
from
.layout_details.blackwell_value
import
BlackwellMXValueLayout
from
.layout_details.hopper_scale
import
HopperMXScaleLayout
from
.layout_details.hopper_value
import
HopperMXValueLayout
from
.layout_details.cdna4_scale
import
CDNA4MXScaleLayout
from
.layout_details.strided
import
StridedLayout
from
..target_info
import
cuda_capability_geq
,
is_hip_cdna4
# Layout classes re-exported from the `layout_details` subpackage.
__all__ = [
    "Layout",
    "BlackwellMXValueLayout",
    "BlackwellMXScaleLayout",
    "HopperMXScaleLayout",
    "HopperMXValueLayout",
    "CDNA4MXScaleLayout",
    "StridedLayout",
]
def make_default_matmul_mxfp4_w_layout(mx_axis: int):
    """Pick the default mxfp4 weight layout for the current target.

    Returns:
        A ``(layout_class, constructor_kwargs)`` pair, chosen by checking
        ``cuda_capability_geq(10)`` first, then ``cuda_capability_geq(9)``,
        and falling back to ``StridedLayout``. ``mx_axis`` is only passed on
        for the Hopper layout.
    """
    if cuda_capability_geq(10):
        return BlackwellMXValueLayout, {}
    if cuda_capability_geq(9):
        return HopperMXValueLayout, {"mx_axis": mx_axis}
    return StridedLayout, {}
def make_default_matmul_mxfp4_w_scale_layout(mx_axis: int, num_warps: int = 8):
    """Pick the default mxfp4 weight-scale layout for the current target.

    Returns:
        A ``(layout_class, constructor_kwargs)`` pair. ``is_hip_cdna4()`` is
        checked first, then ``cuda_capability_geq(10)``, then
        ``cuda_capability_geq(9)``; ``StridedLayout`` is the fallback.
        ``mx_axis`` and ``num_warps`` are only passed on for the Hopper layout.
    """
    if is_hip_cdna4():
        return CDNA4MXScaleLayout, {}
    if cuda_capability_geq(10):
        return BlackwellMXScaleLayout, {}
    if cuda_capability_geq(9):
        return HopperMXScaleLayout, {"mx_axis": mx_axis, "num_warps": num_warps}
    return StridedLayout, {}
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/base.py
deleted
100644 → 0
View file @
2b7160c6
from
abc
import
ABC
,
abstractmethod
class Layout(ABC):
    """Abstract base for storage layouts.

    Subclasses must implement all three abstract methods; the shape given at
    construction is kept as ``initial_shape``.
    """

    def __init__(self, shape) -> None:
        self.initial_shape = shape

    @abstractmethod
    def swizzle_data(self, data):
        """Convert ``data`` into this layout."""

    @abstractmethod
    def unswizzle_data(self, data):
        """Convert ``data`` out of this layout."""

    @abstractmethod
    def swizzle_block_shape(self, block_shape):
        """Translate ``block_shape`` into this layout."""
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_scale.py
deleted
100644 → 0
View file @
2b7160c6
import
math
import
triton
import
triton.language
as
tl
import
torch
from
.base
import
Layout
# Default values of the constexpr parameters of `unswizzle_mx_scale_bw`
# (ALIGN_INNER, SIZE_INNER and SIZE_OUTER respectively).
SWIZZLE_ALIGN_INNER = 8
SWIZZLE_SIZE_INNER = 4
SWIZZLE_SIZE_OUTER = 128
class BlackwellMXScaleLayout(Layout):
    """Scale layout that pads and re-tiles the last two dims of a (..., K, N) tensor."""

    name: str = "BLACKWELL_SCALE"

    def __init__(self, shape) -> None:
        super().__init__(shape)
        # Split the shape into leading (batch) dims and the trailing K, N dims.
        (
            *self.leading_shape,
            self.K,
            self.N,
        ) = shape
        # Product of the leading dims.
        self.B = math.prod(self.leading_shape)
        self.ALIGN_K = 8
        self.ALIGN_N = 128
        self.SWIZZLE_K = 4
        # K and N rounded up to multiples of ALIGN_K / ALIGN_N.
        self.K_pad = (self.K + self.ALIGN_K - 1) // self.ALIGN_K * self.ALIGN_K
        self.N_pad = (self.N + self.ALIGN_N - 1) // self.ALIGN_N * self.ALIGN_N

    def swizzle_data(self, data):
        # Pad the last dim (N) and the second-to-last dim (K) up to N_pad / K_pad.
        data = torch.nn.functional.pad(data, (0, self.N_pad - self.N, 0, self.K_pad - self.K))
        # Make N the second-to-last dim, then tile N as (N_pad/ALIGN_N, ALIGN_N/32, 32)
        # and K as (K_pad/SWIZZLE_K, SWIZZLE_K).
        data = data.transpose(-1, -2).contiguous()
        data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.ALIGN_N // 32, 32, self.K_pad // self.SWIZZLE_K,
                            self.SWIZZLE_K)
        data = data.transpose(2, 4).contiguous()
        # NOTE(review): the literals 128 and 4 below equal ALIGN_N and
        # SWIZZLE_K as set in __init__ — presumably intentional duplicates.
        data = data.view(1, self.B * self.N_pad // 128, self.K_pad // 4, 2, 256)
        return data

    def unswizzle_data(self, data):
        # Undo the tiling of `swizzle_data`, then slice off the padding.
        data = data.reshape(self.B, self.N_pad // self.ALIGN_N, self.K_pad // self.SWIZZLE_K, 32, self.ALIGN_N // 32,
                            self.SWIZZLE_K)
        data = data.transpose(2, 4)
        data = data.reshape(*self.leading_shape, self.N_pad, self.K_pad)
        data = data.transpose(-1, -2)
        return data[..., :self.K, :self.N]

    def swizzle_block_shape(self, block_shape):
        # Only block_shape[0] and block_shape[1] are read; the result always
        # has five entries.
        MX_PACK_DIVISOR = 32
        MX_SCALE_BLOCK_K = block_shape[1] // MX_PACK_DIVISOR
        return [1, block_shape[0] // 128, MX_SCALE_BLOCK_K // 4, 2, 256]
@triton.jit
def unswizzle_mx_scale_bw(
    x,
    SIZE_OUTER: tl.constexpr = SWIZZLE_SIZE_OUTER,
    SIZE_INNER: tl.constexpr = SWIZZLE_SIZE_INNER,
    ALIGN_INNER: tl.constexpr = SWIZZLE_ALIGN_INNER,
):
    # Rearrange a 2D block `x` of shape (shape_0, shape_1) into shape
    # (shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER).
    shape_0: tl.constexpr = x.shape[0]
    shape_1: tl.constexpr = x.shape[1]
    # shape_1 must be a whole number of SIZE_OUTER groups, at most ALIGN_INNER of them.
    tl.static_assert(shape_1 % SIZE_OUTER == 0)
    tl.static_assert(shape_1 // SIZE_OUTER <= ALIGN_INNER)
    x = x.reshape(shape_0, (shape_1 // SIZE_OUTER) // SIZE_INNER, 32, SIZE_OUTER // 32, SIZE_INNER)
    x = x.trans(0, 3, 2, 1, 4).reshape(shape_0 * SIZE_OUTER, shape_1 // SIZE_OUTER)
    return x
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/blackwell_value.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
from
.base
import
Layout
class BlackwellMXValueLayout(Layout):
    """Value layout that pads the unit-stride dimension to a multiple of 128."""

    name: str = "BLACKWELL_VALUE"

    def __init__(self, shape) -> None:
        super().__init__(shape)
        self.shape = shape

    def swizzle_data(self, data):
        """Pad the unit-stride dimension of ``data`` up to a multiple of 128."""
        rank = data.ndim
        # Dims ordered by decreasing (stride, index): permuting by this order
        # puts the smallest-stride dim last.
        to_row_major = sorted(range(rank), key=lambda dim: (data.stride(dim), dim), reverse=True)
        # Inverse permutation, used to restore the original dim order.
        inverse = [0] * rank
        for position, dim in enumerate(to_row_major):
            inverse[dim] = position
        # The unit-stride dimension is the one that gets padded.
        major_dim = data.stride().index(1)
        extent = data.shape[major_dim]
        padded_extent = (extent + 128 - 1) // 128 * 128
        reordered = data.permute(to_row_major)
        padded = torch.nn.functional.pad(reordered, (0, padded_extent - extent))
        return padded.permute(inverse)

    def unswizzle_data(self, data: torch.Tensor):
        """Trim every dim of ``data`` back to the shape recorded at init."""
        assert data.ndim == len(self.shape), ("Rank mismatch between data and recorded shape")
        window = tuple(slice(0, min(current, recorded)) for current, recorded in zip(data.shape, self.shape))
        return data[window]

    def swizzle_block_shape(self, block_shape):
        """Block shapes are unchanged by this layout."""
        return block_shape
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/tensor_details/layout_details/cdna4_scale.py
deleted
100644 → 0
View file @
2b7160c6
import
triton
import
triton.language
as
tl
from
.base
import
Layout
# Group size along N used by `CDNA4MXScaleLayout.swizzle_data`, and the
# default of `N_PRESHUFFLE_FACTOR` in `unswizzle_mx_scale_cdna4`.
NON_K_PRESHUFFLE_BLOCK_SIZE = 32
class CDNA4MXScaleLayout(Layout):
    """Scale layout that pre-shuffles the last two dims of a 2D or 3D tensor."""

    name: str = "CDNA4_SCALE"

    def __init__(self, shape) -> None:
        super().__init__(shape)

    def swizzle_data(self, data):
        # `data` is (..., SCALE_K, N); only 2D and 3D inputs are accepted.
        block_shape = data.shape
        SCALE_K = block_shape[-2]
        N = block_shape[-1]
        data = data.transpose(-1, -2)
        # Split N into (N/32, 2, 16) and SCALE_K into (SCALE_K/8, 2, 4), then
        # interleave the pieces with the permutation below.
        data = data.view(-1, N // NON_K_PRESHUFFLE_BLOCK_SIZE, 2, 16, SCALE_K // 8, 2, 4, 1)
        data = data.permute(0, 1, 4, 6, 3, 5, 2, 7).contiguous()
        if len(block_shape) == 3:
            E = block_shape[0]
            data = data.reshape(E, N // 32, SCALE_K * 32)
        else:
            assert len(block_shape) == 2
            data = data.reshape(N // 32, SCALE_K * 32)
        return data.transpose(-1, -2)

    def unswizzle_data(self, data):
        # The inverse transformation is not provided for this layout.
        raise NotImplementedError()

    def swizzle_block_shape(self, block_shape):
        # The last two entries (SCALE_K, N) become (N // 32, SCALE_K * 32);
        # leading entries are kept. `block_shape` must be a list for the `+`.
        SCALE_K = block_shape[-2]
        N = block_shape[-1]
        return block_shape[:-2] + [N // 32, SCALE_K * 32]
@triton.jit
def unswizzle_mx_scale_cdna4(
    x,
    BLOCK_N: tl.constexpr,
    MX_SCALE_BLOCK_K: tl.constexpr,
    N_PRESHUFFLE_FACTOR: tl.constexpr = NON_K_PRESHUFFLE_BLOCK_SIZE,
):
    # Rearrange block `x` into shape (BLOCK_N, MX_SCALE_BLOCK_K) by splitting
    # it into seven factors, permuting them, and flattening again.
    x = x.reshape(BLOCK_N // N_PRESHUFFLE_FACTOR, MX_SCALE_BLOCK_K // 8, 4, 16, 2, 2, 1)
    x = x.permute(0, 5, 3, 1, 4, 2, 6)
    x = x.reshape(BLOCK_N, MX_SCALE_BLOCK_K)
    return x
Prev
1
2
3
4
5
6
7
8
9
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment