Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f81ce56b
"tests/kernels/quantization/untest_block_fp8.py" did not exist on "6116ca8cd79b642c64f4ae6f050a6bc12b96d037"
Commit
f81ce56b
authored
Apr 23, 2026
by
chenzk
Browse files
vllm kvprune:v1.0.1
parent
2b7160c6
Changes
237
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
3203 deletions
+0
-3203
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction.py
...ctor-vllm/src/compactor_vllm/triton_kernels/compaction.py
+0
-76
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction_details/__init__.py
...pactor_vllm/triton_kernels/compaction_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction_details/_masked_compaction.py
...m/triton_kernels/compaction_details/_masked_compaction.py
+0
-22
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs.py
...ctor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs.py
+0
-609
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/__init__.py
...pactor_vllm/triton_kernels/matmul_ogs_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_common.py
...mpactor_vllm/triton_kernels/matmul_ogs_details/_common.py
+0
-179
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_matmul_ogs.py
...tor_vllm/triton_kernels/matmul_ogs_details/_matmul_ogs.py
+0
-429
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
...r_vllm/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
+0
-471
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_reduce_grouped.py
...vllm/triton_kernels/matmul_ogs_details/_reduce_grouped.py
+0
-126
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags.py
...actor_vllm/triton_kernels/matmul_ogs_details/opt_flags.py
+0
-303
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/__init__.py
..._kernels/matmul_ogs_details/opt_flags_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
...els/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
+0
-37
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
.../matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
+0
-119
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics.py
...pactor-vllm/src/compactor_vllm/triton_kernels/numerics.py
+0
-42
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/__init__.py
...ompactor_vllm/triton_kernels/numerics_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/flexpoint.py
...mpactor_vllm/triton_kernels/numerics_details/flexpoint.py
+0
-204
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp.py
...rc/compactor_vllm/triton_kernels/numerics_details/mxfp.py
+0
-303
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/__init__.py
.../triton_kernels/numerics_details/mxfp_details/__init__.py
+0
-0
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
...ernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
+0
-158
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
...ernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
+0
-125
No files found.
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
from
.compaction_details._masked_compaction
import
_masked_compaction
from
.tensor
import
Bitmatrix
def
compaction
(
yv
,
yi
,
bitmask
,
sentinel
=-
1
):
"""
Return compacted copies of *yv* and *yi* based on a per-row bitmask.
Only the elements whose index appears among the active bits of *bitmask*
are kept; the rest are replaced by *sentinel*. Kept elements preserve
their original left-to-right order.
Parameters
----------
yv : torch.Tensor, shape (B, K)
Values tensor.
yi : torch.Tensor, shape (B, K), dtype torch.long
Integer indices (0 ≤ index < 32) associated with *yv*.
bitmask : torch.Tensor, shape (B,) **or** (B, 32)
Per-row mask of active indices. See the in-place version for details.
sentinel : int, default -1
Value written into dropped positions of the returned tensors.
Returns
-------
(yv_out, yi_out) : Tuple[torch.Tensor, torch.Tensor], each shape (B, K)
New tensors with the same dtype/device as the inputs.
"""
n_rows
,
n_cols
=
yi
.
shape
ret_yv
=
torch
.
empty_like
(
yv
)
ret_yi
=
torch
.
empty_like
(
yi
)
if
isinstance
(
bitmask
,
Bitmatrix
):
bitmask
=
bitmask
.
storage
.
data
_masked_compaction
[(
n_rows
,)](
yv
,
yi
,
bitmask
,
bitmask
.
stride
(
0
),
bitmask
.
stride
(
1
),
# inputs
ret_yv
,
ret_yi
,
# outputs
sentinel
,
# sentinel
K
=
n_cols
,
# constants
)
return
ret_yv
,
ret_yi
def
compaction_torch
(
yv
:
torch
.
Tensor
,
yi
:
torch
.
Tensor
,
bitmask
:
torch
.
Tensor
,
sentinel
=-
1
):
"""
reference implementation of `masked_compact`
"""
B
,
K
=
yi
.
shape
device
=
yi
.
device
# Expand bitmask to a boolean matrix of active bits (B, 32)
w
=
1
<<
torch
.
arange
(
32
,
device
=
device
,
dtype
=
bitmask
.
dtype
)
bits
=
(
bitmask
.
unsqueeze
(
-
1
)
&
w
)
!=
0
mask
=
bits
.
flatten
(
start_dim
=-
2
)
# or bits.reshape(B, -1)
# For every yi element decide whether it should be kept
keep
=
mask
.
gather
(
1
,
yi
.
long
())
# Build a stable permutation that brings all "keep" items forward
# False→0, True→1 ==> invert so kept==0, dropped==1, then argsort
order
=
(
~
keep
).
to
(
torch
.
int
).
argsort
(
dim
=
1
,
stable
=
True
)
# Re‑order tensors according to above permutation
yi_sorted
=
yi
.
gather
(
1
,
order
)
yv_sorted
=
yv
.
gather
(
1
,
order
)
# fill relevant positions with sentinel
keep_sorted
=
keep
.
gather
(
1
,
order
)
yi_sorted
[
~
keep_sorted
]
=
sentinel
yv_sorted
[
~
keep_sorted
]
=
sentinel
return
yv_sorted
,
yi_sorted
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction_details/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/compaction_details/_masked_compaction.py
deleted
100644 → 0
View file @
2b7160c6
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_masked_compaction
(
Yv
,
Yi
,
BitMask
,
stride_bm
,
stride_bn
,
RetYv
,
RetYi
,
sentinel
,
K
:
tl
.
constexpr
):
pid_m
=
tl
.
program_id
(
0
)
yv
=
tl
.
load
(
Yv
+
pid_m
*
K
+
tl
.
arange
(
0
,
K
))
yi
=
tl
.
load
(
Yi
+
pid_m
*
K
+
tl
.
arange
(
0
,
K
))
div
=
yi
//
32
rem
=
yi
%
32
active_bits
=
(
tl
.
load
(
BitMask
+
pid_m
*
stride_bm
+
div
*
stride_bn
)
>>
rem
)
&
1
exc_cumsum
=
tl
.
cumsum
(
active_bits
,
0
)
-
active_bits
active_flags
=
active_bits
.
to
(
tl
.
int1
)
rev_arange
=
tl
.
where
(
active_flags
,
0
,
K
-
1
-
tl
.
arange
(
0
,
K
))
write_indx
=
exc_cumsum
+
rev_arange
yv
=
tl
.
where
(
active_flags
,
yv
,
sentinel
)
yi
=
tl
.
where
(
active_flags
,
yi
,
sentinel
)
tl
.
store
(
RetYv
+
pid_m
*
K
+
write_indx
,
yv
)
tl
.
store
(
RetYi
+
pid_m
*
K
+
write_indx
,
yi
)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs.py
deleted
100644 → 0
View file @
2b7160c6
# isort: off
# fmt: off
from
dataclasses
import
dataclass
import
itertools
import
sys
import
torch
import
triton
from
enum
import
Enum
,
auto
import
math
# utilities
from
compactor_vllm.triton_kernels
import
target_info
from
compactor_vllm.triton_kernels.numerics
import
InFlexData
,
OutFlexData
from
compactor_vllm.triton_kernels.routing
import
GatherIndx
,
RoutingData
,
ScatterIndx
from
compactor_vllm.triton_kernels.target_info
import
is_cuda
# details
from
.matmul_ogs_details._matmul_ogs
import
_matmul_ogs
from
.matmul_ogs_details._p_matmul_ogs
import
_p_matmul_ogs
,
get_per_device_per_stream_alloc_fn
from
.matmul_ogs_details._reduce_grouped
import
_reduce_grouped
from
.numerics_details.mxfp
import
MXFP_BLOCK_SIZE
from
.matmul_ogs_details.opt_flags
import
make_opt_flags
,
update_opt_flags_constraints
,
InapplicableConstraint
from
.specialize
import
specialize
from
.tensor
import
Storage
,
Tensor
,
FP4
,
bitwidth
,
wrap_torch_tensor
@
dataclass
(
frozen
=
True
)
class
FnSpecs
:
name
:
str
fn
:
"triton.runtime.jit.JITFunction"
fn_arg_names
:
tuple
[
str
]
fn_arg_do_not_specialize
:
tuple
[
str
]
=
tuple
()
@
staticmethod
def
default
():
return
FnSpecs
(
"dflt"
,
None
,
tuple
())
@
dataclass
(
frozen
=
True
)
class
FusedActivation
:
specs
:
FnSpecs
=
FnSpecs
.
default
()
fn_args
:
tuple
[
object
]
=
tuple
()
reduction_n
:
int
=
1
@
dataclass
(
frozen
=
True
)
class
Epilogue
:
specs
:
FnSpecs
=
FnSpecs
.
default
()
fn_arg_values_matmul
:
tuple
[
object
]
=
tuple
()
fn_arg_values_finalize
:
tuple
[
object
]
=
tuple
()
effective_itemsize
:
float
=
None
class
FnName
(
Enum
):
QUANTIZE_MXFP8
=
auto
()
EpilogueSpecs
=
FnSpecs
# TODO: remove this alias when callers are updated
_kernels
=
dict
()
def
get_kernels
(
epilogue
:
FnSpecs
=
FnSpecs
.
default
(),
fused_activation
:
FnSpecs
=
FnSpecs
.
default
()):
global
_kernels
key
=
(
fused_activation
.
name
,
epilogue
.
name
)
if
key
in
_kernels
:
return
_kernels
[
key
]
spec_constants
=
{
"ACTIVATION_FN"
:
fused_activation
.
fn
,
"EPILOGUE_FN"
:
epilogue
.
fn
,
}
spec_tuples
=
{
"activation_fn_args"
:
fused_activation
.
fn_arg_names
,
"epilogue_fn_args"
:
epilogue
.
fn_arg_names
,
}
do_not_specialize
=
fused_activation
.
fn_arg_do_not_specialize
+
epilogue
.
fn_arg_do_not_specialize
import
types
module
=
types
.
ModuleType
(
f
"matmul_ogs_
{
'_'
.
join
(
key
)
}
"
)
sys
.
modules
[
module
.
__name__
]
=
module
module
.
_matmul_ogs
=
specialize
(
_matmul_ogs
,
module
,
spec_constants
,
spec_tuples
,
do_not_specialize
=
do_not_specialize
)
module
.
_p_matmul_ogs
=
specialize
(
_p_matmul_ogs
,
module
,
spec_constants
,
spec_tuples
,
do_not_specialize
=
do_not_specialize
)
module
.
_reduce_grouped
=
specialize
(
_reduce_grouped
,
module
,
spec_constants
,
spec_tuples
,
do_not_specialize
=
do_not_specialize
)
_kernels
[
key
]
=
module
return
module
# -----------------------------------------------------------------------------
# Matrix Multiplication + Outer Gather/Scatter
# -----------------------------------------------------------------------------
def
can_overflow_int32
(
tensor
:
torch
.
Tensor
):
max_int32
=
(
1
<<
31
)
-
1
offset
=
0
for
i
in
range
(
tensor
.
ndim
):
offset
+=
(
tensor
.
shape
[
i
]
-
1
)
*
tensor
.
stride
(
i
)
return
offset
>
max_int32
def
should_upcast_indices
(
*
args
):
return
any
(
tensor
is
not
None
and
can_overflow_int32
(
tensor
)
for
tensor
in
args
)
# ---------------------
# Numerics
# ---------------------
# fmt: off
@
dataclass
(
frozen
=
True
)
class
FlexCtx
:
lhs_data
:
InFlexData
=
InFlexData
()
rhs_data
:
InFlexData
=
InFlexData
()
out_data
:
OutFlexData
=
OutFlexData
()
@
dataclass
class
PrecisionConfig
:
max_num_imprecise_acc
:
int
=
None
allow_tf32
:
bool
=
True
flex_ctx
:
FlexCtx
=
FlexCtx
()
acc_scale
:
int
=
1.0
flexpoint_saturate_inf
:
bool
=
False
report_quantization_err_fn
:
callable
=
None
act_scale
:
Tensor
|
None
=
None
weight_scale
:
Tensor
|
None
=
None
out_scale
:
Tensor
|
None
=
None
out_dtype
:
torch
.
dtype
=
None
enforce_bitwise_invariance
:
bool
=
False
# TODO: merge in opt_flags
def
get_swap_xw
(
precision_config
,
opt_flags
):
if
target_info
.
cuda_capability_geq
(
10
,
0
):
return
precision_config
.
weight_scale
is
not
None
and
opt_flags
.
block_m
<=
64
and
opt_flags
.
is_persistent
return
False
# ---------------------
# Allocation
# ---------------------
@
dataclass
class
MatmulAllocation
:
device
:
str
output
:
tuple
[
tuple
[
int
],
torch
.
dtype
]
scratchpads
:
dict
[
str
,
tuple
]
def
init_allocation
(
x
,
w
,
precision_config
,
fused_activation
,
routing_data
,
gather_indx
,
scatter_indx
,
opt_flags
):
# ---- output ------
N
=
w
.
shape
[
-
1
]
# by default - M is number of rows in the activations
M
=
x
.
shape
[
-
2
]
# if the activations are gathered, then M is number of gather indices
if
gather_indx
is
not
None
:
M
=
gather_indx
.
src_indx
.
shape
[
0
]
# final output
if
routing_data
.
n_expts_act
==
1
or
scatter_indx
is
None
:
y_rows
=
M
else
:
Mc
=
scatter_indx
.
src_indx
.
shape
[
0
]
//
routing_data
.
n_expts_act
# compressed number of rows
y_rows
=
Mc
batch_dim
=
x
.
shape
[
0
]
if
x
.
ndim
==
3
else
1
out_shape
=
(
batch_dim
,
y_rows
,
N
//
fused_activation
.
reduction_n
)
out_dtype
=
precision_config
.
out_dtype
or
x
.
dtype
output
=
(
out_shape
,
out_dtype
)
# ---- scratchpad -----#
scratchpad
=
dict
()
if
opt_flags
.
split_k
>
1
or
(
scatter_indx
is
not
None
and
not
opt_flags
.
fused_scatter
):
scratch_out_dtype
=
torch
.
float32
if
opt_flags
.
split_k
>
1
else
out_dtype
scratchpad
[
"matmul"
]
=
((
opt_flags
.
split_k
,
1
,
M
,
N
),
scratch_out_dtype
)
if
"matmul"
in
scratchpad
and
precision_config
.
out_scale
is
not
None
:
scratchpad
[
"mx_out_scale"
]
=
((
opt_flags
.
split_k
,
1
,
M
,
triton
.
cdiv
(
N
,
MXFP_BLOCK_SIZE
)),
torch
.
uint8
)
return
MatmulAllocation
(
x
.
device
,
output
,
scratchpad
)
def
apply_allocation
(
allocation
:
MatmulAllocation
,
output
):
ret
=
dict
()
if
output
is
None
:
output
=
torch
.
empty
(
allocation
.
output
[
0
],
device
=
allocation
.
device
,
dtype
=
allocation
.
output
[
1
])
else
:
assert
output
.
shape
==
allocation
.
output
[
0
]
ret
[
"output"
]
=
output
[
None
,
:,
:]
ret
[
"scratchpad"
]
=
{
k
:
torch
.
empty
(
v
[
0
],
device
=
allocation
.
device
,
dtype
=
v
[
1
])
for
k
,
v
in
allocation
.
scratchpads
.
items
()
}
return
ret
# -----------------------------------------------------------------------------
# Canonicalize
# -----------------------------------------------------------------------------
# the `matmul_ogs` kernel can operate on 2D or 3D inputs depending on the mode being used
# we can canonicalize storages to make the implementation more uniform
def
_canonicalize_storage
(
storage
,
out_ndim
,
flex_data
):
assert
out_ndim
>=
storage
.
data
.
ndim
# Need to use as_strided instead of view because for a tensor with
# shape[-2] == 1 can have ambuiguity related to col-wise. Fo example,
# > t = torch.randn(2, 5, 1).mT
# > t_view = t.view(t.shape)
# > t.stride(), t_view.stride()
# ((5, 1, 1), (5, 5, 1))
# Our check t_view is col-wise fails since t_view.stride(-2) != 1
# This case is covered by (m, n, k) == (1000, 700, 2) in test_matmul.py
new_storage_shape
=
[
1
]
*
(
out_ndim
-
storage
.
data
.
ndim
)
+
list
(
storage
.
data
.
shape
)
new_storage_view
=
storage
.
data
.
view
(
new_storage_shape
)
new_storage_stride
=
[
new_storage_view
.
stride
(
0
)]
*
(
out_ndim
-
storage
.
data
.
ndim
)
+
list
(
storage
.
data
.
stride
())
new_storage_data
=
storage
.
data
.
as_strided
(
new_storage_shape
,
new_storage_stride
)
if
flex_data
is
not
None
:
new_storage_data
=
flex_data
.
reinterpret
(
new_storage_data
)
return
Storage
(
new_storage_data
,
storage
.
layout
)
#
def
reduce_grouped
(
x
:
torch
.
Tensor
,
indx
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
out_mx_scale
:
torch
.
Tensor
,
fused_activation
,
epilogue
,
x_flex
:
InFlexData
|
None
=
None
,
out_flex
:
OutFlexData
|
None
=
None
,
x_mx_scale
:
torch
.
Tensor
|
None
=
None
,
out_dtype
:
bool
=
None
,
flexpoint_saturate_inf
:
bool
=
False
):
"""
In-place grouped row reduction.
Arguments
- x: Tensor[AnyFloat] of shape [(num_groups * K), N]
- indx: Tensor[Int] of shape [num_groups, K]
Description
For each group g in [0, num_groups), this routine sums the K rows of `x`
specified by `indx[g, :]` and overwrites the row corresponding to the first
valid (non-negative) index with the per-group sum. Accumulation is performed
in float32 for numerical stability, and the result is written back in the
dtype of `x`.
Behavior and edge cases
- Invalid (-1) entries are skipped during accumulation and do not generate
memory traffic. If a group has no valid entries, nothing is written for
that group.
- Reduction is performed tile-by-tile along the N dimension within a single
kernel launch (persistent along N) to minimize launch overhead.
Performance notes
- Memory traffic per group is approximately (valid_rows_read + 1) * N * sizeof(x),
plus index reads. With no invalid entries, this becomes (K + 1) reads/writes
of length N per group.
Returns
- The input tensor `x` (modified in place).
"""
if
indx
is
None
and
x
.
shape
[
0
]
==
1
:
return
x
.
squeeze
(
0
),
None
if
indx
is
not
None
:
num_groups
=
indx
.
shape
[
0
]
else
:
num_groups
=
x
.
shape
[
-
2
]
if
x_flex
is
None
:
x_flex
=
InFlexData
()
if
out_flex
is
None
:
out_flex
=
OutFlexData
()
K
=
1
if
indx
is
None
else
indx
.
shape
[
1
]
out_dtype
=
x
.
dtype
if
out_dtype
is
None
else
out_dtype
assert
x
.
shape
[
-
1
]
%
fused_activation
.
reduction_n
==
0
BLOCK_N
=
512
# Resolve scalar flex scales (may be None)
x_expected_scale
=
None
if
x_flex
is
None
else
x_flex
.
scale
out_expected_scale
=
None
if
out_flex
is
None
else
out_flex
.
expected_scale
out_actual_scale
=
None
if
out_flex
is
None
else
out_flex
.
actual_scale
out_checksum_scale
=
None
if
out_flex
is
None
else
out_flex
.
checksum_scale
# Resolve MXFP output scale row stride
stride_mxb
=
0
if
x_mx_scale
is
None
else
x_mx_scale
.
stride
(
0
)
stride_mxs
=
0
if
x_mx_scale
is
None
else
x_mx_scale
.
stride
(
1
)
stride_omxs
=
0
if
out_mx_scale
is
None
else
out_mx_scale
.
stride
(
0
)
kernels
=
get_kernels
(
epilogue
.
specs
,
fused_activation
.
specs
)
kernels
.
_reduce_grouped
[(
num_groups
,
)](
x_flex
.
reinterpret
(
x
),
x
.
stride
(
0
),
x
.
stride
(
2
),
x
.
stride
(
3
),
#
x_expected_scale
,
# scalar input scale
out_flex
.
reinterpret
(
out
),
out
.
stride
(
1
),
out
.
stride
(
2
),
#
out_expected_scale
,
out_actual_scale
,
out_checksum_scale
,
indx
,
#
x
.
shape
[
0
],
x
.
shape
[
-
1
],
#
x_mx_scale
,
stride_mxb
,
stride_mxs
,
#
out_mx_scale
,
stride_omxs
,
#
*
fused_activation
.
fn_args
,
fused_activation
.
reduction_n
,
*
epilogue
.
fn_arg_values_finalize
,
HAS_IN_MX_SCALE
=
x_mx_scale
is
not
None
,
HAS_OUT_MX_SCALE
=
out_mx_scale
is
not
None
,
FLEXPOINT_SATURATE_INF
=
flexpoint_saturate_inf
,
#
BLOCK_N
=
BLOCK_N
,
K
=
K
,
#
num_warps
=
1
,
#
)
return
out
,
out_mx_scale
# -----------------------------------------------------------------------------
# Triton Implementation
# -----------------------------------------------------------------------------
def
matmul_ogs_set_idle_sms
(
num_idle_sms
):
"""
persistent kernels will leave `num_idle_sms` idle
"""
update_opt_flags_constraints
({
"idle_sms"
:
num_idle_sms
})
def
matmul_ogs
(
x
,
w
,
bias
,
routing_data
:
RoutingData
|
None
=
None
,
gather_indx
:
GatherIndx
|
None
=
None
,
scatter_indx
:
ScatterIndx
|
None
=
None
,
precision_config
:
PrecisionConfig
|
None
=
None
,
betas
:
torch
.
Tensor
|
None
=
None
,
gammas
:
torch
.
Tensor
|
None
=
None
,
out_alpha
:
float
|
None
=
None
,
y
:
torch
.
Tensor
|
None
=
None
,
fused_activation
:
FusedActivation
|
None
=
None
,
epilogue
:
Epilogue
|
None
=
None
,
):
"""
Y[:, :] = 0.
for e in num_experts:
Y[idxs_y_m(e), :] += matmul(X[idxs_x_m(e), :], W[e, :, :])
"""
is_input_batched
=
x
.
ndim
==
3
if
is_input_batched
:
assert
gather_indx
is
None
,
"gather not supported in batched mode"
assert
scatter_indx
is
None
,
"scatter not supported in batched mode"
assert
routing_data
is
None
,
"routing not supported in batched mode"
assert
w
.
ndim
==
3
and
w
.
shape
[
0
]
==
x
.
shape
[
0
]
# canonicalize inputs
if
precision_config
is
None
:
precision_config
=
PrecisionConfig
()
if
fused_activation
is
None
:
fused_activation
=
FusedActivation
(
FnSpecs
.
default
(),
tuple
(),
1
)
if
epilogue
is
None
:
epilogue
=
Epilogue
(
FnSpecs
.
default
(),
tuple
(),
tuple
(),
False
)
if
routing_data
is
None
:
routing_data
=
RoutingData
(
None
,
None
,
max
(
1
,
w
.
shape
[
0
]),
1
)
# unpack scales
w_scale
=
precision_config
.
weight_scale
w_has_mx
=
w_scale
is
not
None
is_hopper_fp8
=
is_cuda
()
and
not
target_info
.
cuda_capability_geq
(
10
,
0
)
and
bitwidth
(
w
.
dtype
)
==
8
if
is_hopper_fp8
:
assert
w
.
stride
(
-
2
)
==
1
,
"`w` must be column-major when it has data-type FP8 on capability < 10"
if
not
isinstance
(
w
,
Tensor
):
# TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
dtype
=
FP4
if
w
.
dtype
==
torch
.
uint8
else
w
.
dtype
w
=
wrap_torch_tensor
(
w
,
dtype
=
dtype
)
if
w_scale
is
not
None
and
not
isinstance
(
w_scale
,
Tensor
):
w_scale
=
Tensor
(
w_scale
)
if
w_scale
is
not
None
:
w_scale
.
storage
.
data
=
w_scale
.
data
.
view
(
torch
.
uint8
)
w_scale
.
dtype
=
torch
.
uint8
x_scale
=
precision_config
.
act_scale
x_has_mx
=
x_scale
is
not
None
if
x_has_mx
:
assert
x
.
stride
(
-
1
)
==
1
,
"'x' must be row-major when it has data-type mxfp"
if
x_scale
is
not
None
and
not
isinstance
(
x_scale
,
Tensor
):
x_scale
=
Tensor
(
x_scale
)
if
not
isinstance
(
x
,
Tensor
):
x
=
Tensor
(
x
,
dtype
=
x
.
dtype
)
# determine shapes
has_gather
=
gather_indx
is
not
None
has_scatter
=
scatter_indx
is
not
None
is_ragged
=
routing_data
.
expt_hist
is
not
None
M
=
x
.
shape
[
-
2
]
if
gather_indx
is
None
else
gather_indx
.
src_indx
.
shape
[
0
]
batch_size
=
w
.
shape
[
0
]
if
routing_data
.
expt_hist
is
None
and
w
.
ndim
==
3
else
1
K
,
N
=
w
.
shape
[
-
2
:]
assert
K
==
x
.
shape
[
-
1
]
if
x
.
ndim
==
3
and
w
.
ndim
==
3
:
assert
x
.
shape
[
0
]
==
w
.
shape
[
0
]
# compute optimization flags
out_dtype
=
precision_config
.
out_dtype
or
x
.
dtype
can_use_tma
=
x
.
numel
()
>
0
and
x
.
storage
.
is_tma_compliant
()
and
\
w
.
numel
()
>
0
and
w
.
storage
.
is_tma_compliant
()
and
\
(
w_scale
is
None
or
w_scale
.
storage
.
is_tma_compliant
())
# hopper w/ mxfp4 doesn't support TMA
can_use_tma
=
can_use_tma
and
(
torch
.
cuda
.
get_device_capability
()[
0
]
>
9
or
bitwidth
(
w
.
dtype
)
!=
4
)
can_use_fused_scatter
=
has_scatter
and
(
fused_activation
.
specs
.
fn
is
None
)
and
(
epilogue
.
specs
.
fn
is
None
)
and
(
routing_data
.
n_expts_act
==
1
)
opt_flags
=
make_opt_flags
(
out_dtype
,
x
.
dtype
,
w
.
dtype
,
precision_config
,
M
,
N
,
K
,
routing_data
,
can_use_tma
,
can_use_fused_scatter
,
epilogue
.
effective_itemsize
,
)
if
not
can_use_fused_scatter
and
opt_flags
.
fused_scatter
:
raise
InapplicableConstraint
(
"Fused scatter is not supported"
)
if
w_scale
is
not
None
and
opt_flags
.
is_persistent
and
not
target_info
.
has_native_mxfp
():
raise
NotImplementedError
(
"Must use non-persistent kernel for simulated MXFP"
)
if
w_scale
is
not
None
and
w_scale
.
storage
.
layout
.
name
is
not
None
and
not
opt_flags
.
is_persistent
and
target_info
.
has_native_mxfp
():
raise
NotImplementedError
(
"Must use persistent kernel and be TMA-compliant for native MXFP"
)
# fused activation
matmul_fused_activation
=
fused_activation
reduce_fused_activation
=
FusedActivation
()
if
opt_flags
.
split_k
>
1
or
(
scatter_indx
is
not
None
and
not
opt_flags
.
fused_scatter
):
matmul_fused_activation
,
reduce_fused_activation
=
reduce_fused_activation
,
matmul_fused_activation
# allocate output/scratchpad memory
allocation
=
init_allocation
(
x
,
w
,
precision_config
,
fused_activation
,
routing_data
,
gather_indx
,
scatter_indx
,
opt_flags
)
memory
=
apply_allocation
(
allocation
,
y
)
# early exit
if
batch_size
*
M
*
N
==
0
:
ret
=
memory
[
"output"
].
squeeze
(
0
)
if
not
is_input_batched
:
ret
=
ret
.
squeeze
(
0
)
return
ret
# TMA descriptors require a global memory allocation
if
opt_flags
.
is_persistent
:
triton
.
set_allocator
(
get_per_device_per_stream_alloc_fn
(
x
.
device
))
# Intermediate tensors and postprocess kernels for each situation
has_scratchpad
=
"matmul"
in
memory
[
"scratchpad"
]
# Canonical output tensor (matmul scratchpad if present, otherwise final output tensor)
out_matmul
=
memory
[
"scratchpad"
].
get
(
"matmul"
,
memory
[
"output"
])
out_matmul_flex
=
OutFlexData
()
if
out_matmul
.
dtype
==
torch
.
float32
else
precision_config
.
flex_ctx
.
out_data
# Unified mx-scale pointer; when scratchpad exists, prefer its mx buffer
out_matmul_scale
=
precision_config
.
out_scale
if
out_matmul_scale
is
not
None
:
out_matmul_scale
=
out_matmul_scale
.
data
.
view
(
torch
.
uint8
)
if
has_scratchpad
and
"mx_out_scale"
in
memory
[
"scratchpad"
]:
out_matmul_scale
=
memory
[
"scratchpad"
][
"mx_out_scale"
]
out_matmul_has_mx
=
out_matmul_scale
is
not
None
and
out_matmul
.
element_size
()
==
1
# matrix multiplication
flex
=
precision_config
.
flex_ctx
bias_stride
=
None
if
bias
is
None
else
bias
.
stride
(
0
)
num_indx
=
None
if
scatter_indx
is
None
else
scatter_indx
.
src_indx
.
shape
[
0
]
# moe metadata
expt_data
=
routing_data
.
expt_data
block_m
=
opt_flags
.
block_m
expt_hist
=
None
if
expt_data
is
None
else
expt_data
.
hist
expt_hist_sum
=
None
if
expt_data
is
None
else
expt_data
.
token_offs_pad
[
block_m
][
-
1
]
expt_token_offs_raw
=
None
if
expt_data
is
None
else
expt_data
.
token_offs_raw
expt_block_pid_map
=
None
if
expt_data
is
None
else
expt_data
.
block_pid_map
[
block_m
]
# spmd grid
grid_m
=
triton
.
cdiv
(
M
,
opt_flags
.
block_m
)
if
expt_block_pid_map
is
not
None
:
grid_m
=
routing_data
.
n_blocks
(
M
,
opt_flags
.
block_m
)
grid_n
=
triton
.
cdiv
(
N
,
opt_flags
.
block_n
)
max_grid
=
batch_size
*
grid_m
*
grid_n
*
opt_flags
.
split_k
grid
=
min
(
target_info
.
num_sms
()
-
opt_flags
.
idle_sms
,
max_grid
)
if
opt_flags
.
is_persistent
else
max_grid
# canonicalize storage
has_gather_tma
=
has_gather
and
target_info
.
has_tma_gather
()
has_scatter_tma
=
opt_flags
.
fused_scatter
and
target_info
.
has_tma_gather
()
y
=
wrap_torch_tensor
(
out_matmul
.
view
(
math
.
prod
(
out_matmul
.
shape
[:
-
1
]),
out_matmul
.
shape
[
-
1
])
if
opt_flags
.
fused_scatter
else
out_matmul
.
view
(
math
.
prod
(
out_matmul
.
shape
[:
-
2
]),
*
out_matmul
.
shape
[
-
2
:]))
x_storage
=
_canonicalize_storage
(
x
.
storage
,
2
if
has_gather_tma
else
3
,
flex
.
lhs_data
)
w_storage
=
_canonicalize_storage
(
w
.
storage
,
3
,
flex
.
rhs_data
)
y_storage
=
_canonicalize_storage
(
y
.
storage
,
2
if
has_scatter_tma
else
3
,
flex
.
out_data
)
# create tma descriptor for x
x_has_tma
=
opt_flags
.
is_persistent
and
(
has_gather_tma
or
not
has_gather
)
x_tma_block_size
=
[
1
,
opt_flags
.
block_k
]
if
has_gather_tma
else
[
1
,
opt_flags
.
block_m
,
opt_flags
.
block_k
]
x_tma_mode
=
None
if
not
x_has_tma
else
"ragged"
if
is_ragged
and
not
has_gather_tma
else
"dense"
x_tensor_or_tma
=
x_storage
.
make_tma
(
x_tma_block_size
,
x_tma_mode
)
if
x_has_tma
else
x_storage
.
data
# create tma descriptor for y
y_has_tma
=
opt_flags
.
is_persistent
and
(
has_scatter_tma
or
not
opt_flags
.
fused_scatter
)
block_n
=
opt_flags
.
block_n
//
opt_flags
.
epilogue_subtile
//
matmul_fused_activation
.
reduction_n
y_tma_block_size
=
[
1
,
block_n
]
if
has_scatter_tma
else
[
1
,
opt_flags
.
block_m
,
block_n
]
y_tma_mode
=
None
if
not
y_has_tma
else
"ragged"
if
is_ragged
and
not
has_scatter_tma
else
"dense"
y_tensor_or_tma
=
y_storage
.
make_tma
(
y_tma_block_size
,
y_tma_mode
)
if
y_has_tma
else
y_storage
.
data
# create tma descriptor for w
w_has_tma
=
opt_flags
.
is_persistent
w_tensor_or_tma
=
w_storage
.
make_tma
([
1
,
opt_flags
.
block_k
,
opt_flags
.
block_n
],
"dense"
)
if
w_has_tma
else
w_storage
.
data
# create tma descriptor for w_scale
w_scale_tensor_or_tma
=
w_scale
w_scale_has_tma
=
opt_flags
.
is_persistent
and
w_scale
is
not
None
w_scale_tensor_or_tma
=
w_scale
.
storage
.
make_tma
([
opt_flags
.
block_n
,
opt_flags
.
block_k
],
"dense"
)
if
w_scale_has_tma
else
w_scale
# canonicalize strides
x_strides
=
[
0
]
*
(
3
-
x_storage
.
data
.
ndim
)
+
list
(
x_storage
.
data
.
stride
())
x_scale_strides
=
x_scale
.
stride
()
if
x_has_mx
else
(
None
,
None
,
None
)
x_scale_strides
=
(
0
,
)
*
(
3
-
len
(
x_scale_strides
))
+
x_scale_strides
w_scale_strides
=
w_scale
.
stride
()
if
w_has_mx
and
not
w_scale_has_tma
else
(
None
,
None
,
None
)
w_scale_strides
=
(
0
,
)
*
(
3
-
len
(
w_scale_strides
))
+
w_scale_strides
out_matmul_scale_strides
=
out_matmul_scale
.
stride
()
if
out_matmul_has_mx
else
(
None
,
None
,
None
,
None
)
out_matmul_scale_strides
=
(
0
,
)
*
(
3
-
len
(
out_matmul_scale_strides
))
+
out_matmul_scale_strides
# launch kernel
kernels
=
get_kernels
(
epilogue
.
specs
,
matmul_fused_activation
.
specs
)
# When stride(-2) == stride(-1) == 1, it's ambiguous whether W is transposed
# (i.e. col-wise). Since this matters when w_has_mx is True and w_transpose
# is True the fast code path, stride(-2) == 1 takes precedence, e.g., vs.
# w_transpose = w_storage.data.stride()[-1] != 1
w_transpose
=
w_storage
.
data
.
stride
()[
-
2
]
==
1
(
kernels
.
_p_matmul_ogs
if
opt_flags
.
is_persistent
else
kernels
.
_matmul_ogs
)[(
grid
,)](
y_tensor_or_tma
,
y_storage
.
data
,
*
out_matmul
.
stride
(),
*
((
None
,
out_matmul_scale
,
None
)
if
out_matmul_has_mx
else
out_matmul_flex
),
*
out_matmul_scale_strides
[
-
3
:],
x_tensor_or_tma
,
x_storage
.
data
,
*
x_strides
,
flex
.
lhs_data
.
scale
,
None
if
x_scale
is
None
else
x_scale
.
data
.
view
(
torch
.
uint8
),
*
x_scale_strides
,
w_tensor_or_tma
,
w_storage
.
data
,
*
w_storage
.
data
.
stride
(),
w_transpose
,
flex
.
rhs_data
.
scale
,
w_scale_tensor_or_tma
,
*
w_scale_strides
,
bias
,
bias_stride
,
x
.
shape
[
-
2
],
x
.
shape
[
-
2
]
if
routing_data
.
expt_hist
is
None
else
None
,
N
,
K
,
betas
,
gammas
,
None
if
gather_indx
is
None
else
gather_indx
.
src_indx
,
None
if
scatter_indx
is
None
else
scatter_indx
.
src_indx
,
num_indx
,
None
if
not
opt_flags
.
fused_scatter
else
scatter_indx
.
dst_indx
,
None
if
not
opt_flags
.
fused_scatter
else
scatter_indx
.
dst_indx
.
shape
[
0
],
expt_hist
,
expt_token_offs_raw
,
expt_hist_sum
,
expt_block_pid_map
,
batch_size
,
grid_m
,
grid_n
,
out_alpha
,
*
matmul_fused_activation
.
fn_args
,
matmul_fused_activation
.
reduction_n
,
*
epilogue
.
fn_arg_values_matmul
,
routing_data
.
n_expts_tot
,
routing_data
.
n_expts_act
,
precision_config
.
max_num_imprecise_acc
,
precision_config
.
allow_tf32
,
precision_config
.
flexpoint_saturate_inf
,
flex
.
rhs_data
.
is_per_batch
,
opt_flags
.
block_m
,
opt_flags
.
block_n
,
opt_flags
.
block_k
,
opt_flags
.
group_m
,
XCD_SWIZZLE
=
opt_flags
.
xcd_swizzle
,
SWIZZLE_MX_VALUE
=
w
.
storage
.
layout
.
name
,
SWIZZLE_MX_SCALE
=
None
if
w_scale
is
None
else
w_scale
.
storage
.
layout
.
name
,
EPILOGUE_SUBTILE
=
opt_flags
.
epilogue_subtile
,
SPLIT_K
=
opt_flags
.
split_k
,
EVEN_K
=
K
%
opt_flags
.
block_k
==
0
,
W_CACHE_MODIFIER
=
opt_flags
.
w_cache_modifier
,
TOKENS_PER_EXPT_FOR_ANNOTATION
=
routing_data
.
expected_tokens_per_expt
,
num_warps
=
opt_flags
.
num_warps
,
num_stages
=
opt_flags
.
num_stages
,
arch
=
opt_flags
.
arch
,
UPCAST_INDICES
=
should_upcast_indices
(
x
,
w
,
out_matmul
),
X_TMA_MODE
=
x_tma_mode
,
Y_TMA_MODE
=
y_tma_mode
,
SWAP_XW
=
get_swap_xw
(
precision_config
,
opt_flags
),
IS_EPILOGUE_QUANT_MXFP8
=
epilogue
.
specs
.
name
==
FnName
.
QUANTIZE_MXFP8
.
name
,
NUM_SMS
=
grid
if
opt_flags
.
is_persistent
else
0
,
**
opt_flags
.
target_kernel_kwargs
)
# Build grouped reduction inputs in a uniform way
group_indx
=
None
if
scatter_indx
is
None
or
opt_flags
.
fused_scatter
else
scatter_indx
.
src_indx
.
view
(
-
1
,
routing_data
.
n_expts_act
)
out_final
,
out_final_mx_scale
=
reduce_grouped
(
out_matmul
,
group_indx
,
memory
[
"output"
].
squeeze
(
0
),
precision_config
.
out_scale
,
reduce_fused_activation
,
epilogue
,
x_flex
=
InFlexData
(
dtype
=
out_matmul_flex
.
dtype
,
scale
=
out_matmul_flex
.
expected_scale
),
out_flex
=
precision_config
.
flex_ctx
.
out_data
,
x_mx_scale
=
out_matmul_scale
.
squeeze
(
1
)
if
out_matmul_has_mx
else
None
,
out_dtype
=
memory
[
"output"
].
dtype
,
flexpoint_saturate_inf
=
precision_config
.
flexpoint_saturate_inf
,
)
if
not
is_input_batched
:
out_final
=
out_final
.
squeeze
(
0
)
if
out_final_mx_scale
is
not
None
:
precision_config
.
out_scale
=
out_final_mx_scale
return
out_final
# -----------------------------------------------------------------------------
# Reference Implementation
# -----------------------------------------------------------------------------
def
matmul_ogs_torch
(
x
,
w
,
bias
,
routing_data
:
RoutingData
=
None
,
gather_indx
:
GatherIndx
=
None
,
scatter_indx
:
ScatterIndx
=
None
,
precision_config
:
PrecisionConfig
=
None
,
betas
=
None
,
gammas
=
None
,
round_x
=
None
,
round_y
=
None
,
):
is_input_batched
=
x
.
ndim
==
3
assert
x
.
dtype
.
itemsize
>
1
assert
w
.
dtype
.
itemsize
>
1
if
is_input_batched
:
assert
gather_indx
is
None
,
"gather not supported in batched mode"
assert
scatter_indx
is
None
,
"scatter not supported in batched mode"
assert
routing_data
is
None
,
"routing not supported in batched mode"
assert
w
.
ndim
==
3
and
w
.
shape
[
0
]
==
x
.
shape
[
0
]
if
round_x
is
None
:
round_x
=
lambda
x
,
idx
:
x
if
round_y
is
None
:
round_y
=
lambda
x
:
x
if
bias
is
not
None
and
bias
.
ndim
==
1
:
bias
=
bias
.
view
(
1
,
*
bias
.
shape
)
if
w
.
ndim
==
2
:
w
=
w
.
view
(
1
,
*
w
.
shape
)
if
x
.
ndim
==
2
:
x
=
x
.
view
(
1
,
*
x
.
shape
)
if
routing_data
is
None
:
routing_data
=
RoutingData
(
None
,
None
,
w
.
shape
[
0
],
1
)
n_expts_act
=
routing_data
.
n_expts_act
# memory offsets
if
routing_data
.
n_expts_tot
>
1
and
not
is_input_batched
:
sizes
=
routing_data
.
expt_hist
off
=
torch
.
zeros
(
sizes
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
)
off
[
1
:]
=
torch
.
cumsum
(
sizes
,
0
)
offs
=
list
(
itertools
.
pairwise
(
off
))
else
:
offs
=
[[
0
,
x
.
shape
[
1
]]
for
_
in
range
(
w
.
shape
[
0
])]
# compute
n_rows
=
x
.
shape
[
1
]
if
gather_indx
is
None
else
gather_indx
.
dst_indx
.
shape
[
0
]
y
=
torch
.
zeros
((
x
.
shape
[
0
],
n_rows
,
w
.
shape
[
-
1
]),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
for
i
,
(
lo
,
hi
)
in
enumerate
(
offs
):
if
gather_indx
is
None
:
idx
=
torch
.
arange
(
lo
,
hi
,
device
=
x
.
device
)
else
:
idx
=
gather_indx
.
src_indx
[
lo
:
hi
]
//
n_expts_act
batch
=
i
if
is_input_batched
else
0
out
=
torch
.
matmul
(
round_x
(
x
[
batch
,
idx
,
:],
torch
.
arange
(
lo
,
hi
,
device
=
"cuda"
)).
float
(),
w
[
i
].
float
())
if
bias
is
not
None
:
out
+=
bias
[
i
,
:]
if
betas
is
None
else
bias
[
i
,
:]
*
betas
[
lo
:
hi
,
None
]
if
gammas
is
not
None
:
out
*=
gammas
[
lo
:
hi
,
None
]
y
[
batch
,
lo
:
hi
,
:]
=
round_y
(
out
)
if
not
is_input_batched
:
y
=
y
.
view
(
y
.
shape
[
1
],
y
.
shape
[
2
])
if
scatter_indx
is
None
:
return
y
# accumulate output from all experts
n_rows
=
y
.
shape
[
0
]
//
n_expts_act
out
=
torch
.
zeros
((
n_rows
,
y
.
shape
[
-
1
]),
dtype
=
torch
.
float32
,
device
=
x
.
device
)
for
i
,
(
lo
,
hi
)
in
enumerate
(
offs
):
dst_idx
=
scatter_indx
.
dst_indx
[
lo
:
hi
]
//
n_expts_act
msk
=
dst_idx
!=
-
1
out
[
dst_idx
[
msk
],
:]
+=
y
[
lo
:
hi
,
:][
msk
,
:].
float
()
return
out
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_common.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
triton
import
triton.language
as
tl
# -----------------------------------------------------------------------------
# Utilities
# -----------------------------------------------------------------------------
@
triton
.
constexpr_function
def
get_scaled_dot_format_string
(
dtype
:
tl
.
dtype
):
mapping
=
{
tl
.
float16
:
"fp16"
,
tl
.
bfloat16
:
"bf16"
,
tl
.
uint8
:
"e2m1"
,
tl
.
float8e4nv
:
"e4m3"
,
tl
.
float8e5
:
"e5m2"
,
}
return
mapping
[
dtype
]
@
triton
.
jit
def
xcd_swizzle
(
pid
,
domain_size
,
XCD_SWIZZLE
:
tl
.
constexpr
):
"""
Swizzle the program id based on integer XCD_SWIZZLE.
This is useful for reording how blocks are ordered. A scheduler may, for example,
assign sequential blocks 0, 1, 2, 3, ..., 8, 9, 10.. to its 8 hardware units 0, 1, 2, 3, ..., 0, 1, 2.
This pattern may not be ideal for memory access, and it may be better to swizzle so the assignment
becomes 0, 0, 0, 0, ..., 1, 1, 1, ... In the swizzled arrangement, sequential blocks are assigned to
the same hardware unit.
"""
# Number of pids per group in the new arrangement
pids_per_group
=
domain_size
//
XCD_SWIZZLE
extra_pid_groups
=
domain_size
%
XCD_SWIZZLE
# Compute current current and local pid within the group
group
=
pid
%
XCD_SWIZZLE
local_pid
=
pid
//
XCD_SWIZZLE
# Calculate new pid based on the new grouping
new_pid
=
group
*
pids_per_group
+
min
(
group
,
extra_pid_groups
)
+
local_pid
return
new_pid
@
triton
.
jit
def
swizzle2d
(
pid
,
grid_m
,
grid_n
,
GROUP_M
:
tl
.
constexpr
):
width
=
GROUP_M
*
grid_n
group_id
=
pid
//
width
group_size
=
min
(
grid_m
-
group_id
*
GROUP_M
,
GROUP_M
)
tl
.
assume
(
group_size
>=
0
)
pid_m
=
group_id
*
GROUP_M
+
(
pid
%
group_size
)
pid_n
=
(
pid
%
width
)
//
(
group_size
)
return
pid_m
,
pid_n
def
make_matmul_repr
(
base_name
,
order
):
def
matmul_repr
(
specialization
):
signature
=
specialization
.
signature
constants
=
specialization
.
constants
reorder
=
lambda
L
:
[
L
[
i
]
for
i
in
order
]
layout
=
lambda
stride
:
"N"
if
stride
in
constants
else
"T"
def
convert_dtype
(
dtype
):
if
"tensordesc"
in
dtype
:
ret
=
convert_dtype
(
dtype
.
split
(
"<"
)[
1
].
split
(
"["
)[
0
])
return
ret
elif
"u8"
in
dtype
:
return
"mxfp4"
elif
dtype
[
0
]
==
"*"
:
return
dtype
[
1
:]
else
:
return
dtype
dtypes
=
"x"
.
join
(
[
convert_dtype
(
f
"
{
signature
[
i
]
}
"
)
for
i
in
reorder
([
"Y"
,
"X"
,
"W"
])]
)
layouts
=
""
.
join
(
[
f
"
{
layout
(
i
)
}
"
for
i
in
reorder
([
"stride_y_n"
,
"stride_x_k"
,
"stride_w_n"
])
]
)
blocks
=
"x"
.
join
(
[
f
"
{
constants
[
i
]
}
"
for
i
in
[
"BLOCK_M"
,
"BLOCK_N"
,
"BLOCK_K"
,
"SPLIT_K"
]]
)
# mode = []
# if "GatherIndx" not in constants:
# mode += ['g']
# if "ScatterSrcIndx" not in constants:
# mode += ['s']
# suffix = "" if not mode else "_o" + (''.join(mode))
# if base_name.startswith("_p"):
# suffix += "_ptma"
return
f
"
{
base_name
}
_
{
layouts
}
_
{
dtypes
}
_
{
blocks
}
"
return
matmul_repr
def
matmul_launch_metadata
(
grid
,
kernel
,
args
):
from
..proton_opts
import
launch_metadata_allow_sync
ret
=
dict
()
M
,
N
,
K
=
args
[
"M"
],
args
[
"N"
],
args
[
"K"
]
Y
,
X
,
W
=
args
[
"YPtr"
],
args
[
"XPtr"
],
args
[
"WPtr"
]
tokens_per_expt
=
args
.
get
(
"TOKENS_PER_EXPT_FOR_ANNOTATION"
)
hist
=
args
[
"ExptHist"
]
if
hist
is
not
None
:
# If annotation is given, use that to generate name for profiling.
if
tokens_per_expt
is
not
None
:
n_rows
=
f
"
{
tokens_per_expt
}
*"
elif
launch_metadata_allow_sync
():
n_rows
=
int
(
hist
.
float
().
mean
())
else
:
n_rows
=
"unknown"
if
launch_metadata_allow_sync
():
n_tokens
=
float
(
hist
.
sum
())
n_w_bytes
=
(
W
.
numel
()
*
W
.
element_size
()
//
hist
.
numel
())
*
(
hist
>
0
).
sum
()
elif
tokens_per_expt
is
not
None
:
n_tokens
=
tokens_per_expt
*
args
[
"N_EXPTS_TOT"
]
# This may not be totally correct (e.g., we might not be using all experts)
# but it's better than nothing.
n_w_bytes
=
W
.
numel
()
*
W
.
element_size
()
else
:
n_tokens
=
None
n_w_bytes
=
0
# If annotation is given, use that to generate name for profiling.
tokens_per_expt
=
args
.
get
(
"TOKENS_PER_EXPT_FOR_ANNOTATION"
)
n_rows
=
f
"
{
tokens_per_expt
}
*"
if
tokens_per_expt
is
not
None
else
n_rows
else
:
n_tokens
=
None
n_w_bytes
=
W
.
numel
()
*
W
.
element_size
()
repr
=
(
lambda
s
,
x
:
f
"
{
s
}
=
{
x
}
"
if
x
is
not
None
else
f
"E_
{
len
(
hist
)
}
(
{
s
}
) =
{
n_rows
}
"
)
nbits
=
X
.
dtype
.
itemsize
*
8
batch_repr
=
""
if
"batch_size"
in
args
and
args
[
"batch_size"
]
>
1
:
batch_repr
=
repr
(
"B"
,
args
[
"batch_size"
])
+
", "
ret
[
"name"
]
=
(
f
"
{
kernel
.
name
}
[
{
batch_repr
}{
repr
(
'M'
,
M
)
}
,
{
repr
(
'N'
,
N
)
}
,
{
repr
(
'K'
,
K
)
}
] stg
{
kernel
.
num_stages
}
"
)
ep_subtile
=
args
[
"EPILOGUE_SUBTILE"
]
if
ep_subtile
is
not
None
and
ep_subtile
>
1
:
ret
[
"name"
]
+=
f
" ep/
{
ep_subtile
}
"
if
hist
is
not
None
and
n_tokens
is
None
:
return
ret
# Don't fill metadata because we can't compute them properly.
fM
=
M
if
M
is
not
None
else
n_tokens
fK
=
K
if
K
is
not
None
else
n_tokens
ret
[
f
"flops
{
nbits
}
"
]
=
2.0
*
fM
*
N
*
fK
gindx
=
args
.
get
(
"GatherIndx"
,
None
)
# sindx = args.get("WriteBackIndx", None)
n_x_bytes
=
X
.
numel
()
*
X
.
element_size
()
n_y_bytes
=
Y
.
numel
()
*
Y
.
element_size
()
if
hist
is
not
None
:
assert
n_tokens
is
not
None
n_expts_act
=
args
[
"N_EXPTS_ACT"
]
if
(
gindx
is
not
None
)
and
launch_metadata_allow_sync
():
# recreate inverse GatherIndx.
dst
=
torch
.
full_like
(
gindx
,
-
1
)
idx
=
torch
.
arange
(
len
(
gindx
),
device
=
gindx
.
device
,
dtype
=
torch
.
int32
)
mask
=
gindx
!=
-
1
dst
[
gindx
[
mask
]]
=
idx
[
mask
]
n_read_rows
=
(
dst
.
view
((
-
1
,
n_expts_act
))
!=
-
1
).
any
(
dim
=
1
).
sum
()
else
:
n_read_rows
=
n_tokens
n_x_bytes
=
n_read_rows
*
X
.
shape
[
-
1
]
*
X
.
element_size
()
n_y_bytes
=
n_tokens
*
Y
.
shape
[
-
1
]
*
Y
.
element_size
()
ret
[
"bytes"
]
=
int
(
n_x_bytes
+
n_y_bytes
+
n_w_bytes
)
return
ret
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_matmul_ogs.py
deleted
100644 → 0
View file @
2b7160c6
# isort: off
# fmt: off
import
triton
import
triton.language
as
tl
from
compactor_vllm.triton_kernels.tensor_details.layout_details.blackwell_scale
import
unswizzle_mx_scale_bw
from
compactor_vllm.triton_kernels.tensor_details.layout_details.hopper_scale
import
unswizzle_mxfp4_scale_hopper
from
compactor_vllm.triton_kernels.tensor_details.layout_details.hopper_value
import
mxfp4_to_bf16_triton
from
compactor_vllm.triton_kernels.tensor_details.layout_details.cdna4_scale
import
unswizzle_mx_scale_cdna4
from
compactor_vllm.triton_kernels.numerics_details.flexpoint
import
float_to_flex
,
load_scale
from
compactor_vllm.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp
import
MXFP_BLOCK_SIZE
from
._common
import
make_matmul_repr
,
matmul_launch_metadata
,
swizzle2d
,
xcd_swizzle
,
get_scaled_dot_format_string
@
triton
.
jit
def
_zero_masked_rows
(
pid_m
,
pid_n
,
Y
,
stride_y_m
,
stride_y_n
,
N
,
ScatterSrcIndx
,
num_idxs
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
):
offs_m
=
BLOCK_M
*
pid_m
.
to
(
tl
.
int64
)
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_n
=
BLOCK_N
*
pid_n
+
tl
.
arange
(
0
,
BLOCK_N
)
src_idx
=
tl
.
load
(
ScatterSrcIndx
+
offs_m
,
mask
=
offs_m
<
num_idxs
,
other
=
0
)
YPtrs
=
Y
+
offs_m
[:,
None
]
*
stride_y_m
+
offs_n
[
None
,
:]
*
stride_y_n
mask_n
=
offs_n
<
N
mask
=
(
src_idx
==
-
1
)[:,
None
]
&
mask_n
[
None
,
:]
tl
.
store
(
YPtrs
,
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
),
mask
=
mask
)
_matmul_ogs_repr
=
make_matmul_repr
(
"_matmul_ogs"
,
[
0
,
1
,
2
])
@
triton
.
jit
(
do_not_specialize
=
[
"TOKENS_PER_EXPT_FOR_ANNOTATION"
],
repr
=
_matmul_ogs_repr
,
launch_metadata
=
matmul_launch_metadata
)
def
_matmul_ogs
(
Y
,
YPtr
,
stride_y_k
,
stride_y_z
,
stride_y_m
,
stride_y_n
,
YExpectedScale
,
YActualScale
,
YChecksumScale
,
stride_y_mx_z
,
stride_y_mx_m
,
stride_y_mx_n
,
X
,
XPtr
,
stride_x_z
,
stride_x_m
,
stride_x_k
,
XScale
,
XMxScale
,
stride_x_mx_z
,
stride_x_mx_m
,
stride_x_mx_k
,
W
,
WPtr
,
stride_w_e
,
stride_w_k
,
stride_w_n
,
W_TRANSPOSE
:
tl
.
constexpr
,
WScale
,
WMxScale
,
stride_w_mx_e
,
stride_w_mx_k
,
stride_w_mx_n
,
B
,
stride_b_e
,
# Bias
NRows
,
M
,
N
,
K
,
# shapes
# expt data
Betas
,
Gammas
,
GatherIndx
,
ScatterSrcIndx
,
num_idxs
,
WriteBackIndx
,
writeback_size
,
ExptHist
,
ExptOffs
,
ExptOffsSum
,
ExptData
,
# true grid size
batch_size
,
grid_m
,
grid_n
,
# Out scale
out_alpha
,
# fused activation function
ACTIVATION_FN
:
tl
.
constexpr
,
activation_fn_args
,
ACTIVATION_REDUCTION_N
:
tl
.
constexpr
,
# epilogue transform
EPILOGUE_FN
:
tl
.
constexpr
,
epilogue_fn_args
,
# MoE config
N_EXPTS_TOT
:
tl
.
constexpr
,
N_EXPTS_ACT
:
tl
.
constexpr
,
# precision config
MAX_NUM_IMPRECISE_ACC
:
tl
.
constexpr
,
ALLOW_TF32
:
tl
.
constexpr
,
FLEXPOINT_SATURATE_INF
:
tl
.
constexpr
,
PER_BATCH_SCALE
:
tl
.
constexpr
,
# optimization config
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
GROUP_M
:
tl
.
constexpr
,
XCD_SWIZZLE
:
tl
.
constexpr
,
# One of ["HOPPER", "BLACKWELL", None]
SWIZZLE_MX_VALUE
:
tl
.
constexpr
,
# One of ["HOPPER", "BLACKWELL", None]
SWIZZLE_MX_SCALE
:
tl
.
constexpr
,
EPILOGUE_SUBTILE
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
W_CACHE_MODIFIER
:
tl
.
constexpr
,
NUM_SMS
:
tl
.
constexpr
,
X_TMA_MODE
:
tl
.
constexpr
,
Y_TMA_MODE
:
tl
.
constexpr
,
TOKENS_PER_EXPT_FOR_ANNOTATION
=
None
,
UPCAST_INDICES
:
tl
.
constexpr
=
False
,
SWAP_XW
:
tl
.
constexpr
=
False
,
IS_EPILOGUE_QUANT_MXFP8
:
tl
.
constexpr
=
False
):
tl
.
assume
(
stride_y_k
>=
0
)
tl
.
assume
(
stride_y_z
>=
0
)
tl
.
assume
(
stride_y_m
>=
0
)
tl
.
assume
(
stride_y_n
>=
0
)
tl
.
assume
(
stride_x_z
>=
0
)
tl
.
assume
(
stride_x_m
>=
0
)
tl
.
assume
(
stride_x_k
>=
0
)
tl
.
assume
(
stride_w_e
>=
0
)
tl
.
assume
(
stride_w_k
>=
0
)
tl
.
assume
(
stride_w_n
>=
0
)
if
stride_w_mx_e
is
not
None
:
tl
.
assume
(
stride_w_mx_e
>=
0
)
if
stride_w_mx_k
is
not
None
:
tl
.
assume
(
stride_w_mx_k
>=
0
)
if
stride_w_mx_n
is
not
None
:
tl
.
assume
(
stride_w_mx_n
>=
0
)
if
B
is
not
None
:
tl
.
assume
(
stride_b_e
>=
0
)
tl
.
assume
(
batch_size
>=
0
)
tl
.
assume
(
grid_m
>=
0
)
tl
.
assume
(
grid_n
>=
0
)
is_w_microscaled
:
tl
.
constexpr
=
WMxScale
is
not
None
MX_PACK_DIVISOR
:
tl
.
constexpr
=
MXFP_BLOCK_SIZE
if
is_w_microscaled
:
w_type
:
tl
.
constexpr
=
W
.
dtype
.
element_ty
is_mxfp4
:
tl
.
constexpr
=
w_type
==
tl
.
uint8
tl
.
static_assert
(
w_type
==
tl
.
uint8
or
(
w_type
==
tl
.
float8e4nv
or
w_type
==
tl
.
float8e5
),
"mx_weight_ptr must be uint8 or fp8"
)
tl
.
static_assert
(
WMxScale
.
dtype
.
element_ty
==
tl
.
uint8
,
"mx_scale_ptr must be uint8"
)
tl
.
static_assert
(
BLOCK_K
%
MX_PACK_DIVISOR
==
0
,
"BLOCK_K must be a multiple of MX_PACK_DIVISOR"
)
tl
.
static_assert
(
SWIZZLE_MX_VALUE
==
"HOPPER_VALUE"
or
SWIZZLE_MX_VALUE
is
None
,
"Only Hopper swizzling is supported for values"
)
else
:
tl
.
static_assert
(
SWIZZLE_MX_VALUE
is
None
)
tl
.
static_assert
(
SWIZZLE_MX_SCALE
is
None
)
is_x_microscaled
:
tl
.
constexpr
=
XMxScale
is
not
None
if
is_x_microscaled
:
x_type
:
tl
.
constexpr
=
X
.
dtype
.
element_ty
tl
.
static_assert
(
is_w_microscaled
)
tl
.
static_assert
(
x_type
==
tl
.
float8e4nv
,
"mx_act_ptr must be float8e4nv"
)
tl
.
static_assert
(
XMxScale
.
dtype
.
element_ty
==
tl
.
uint8
,
"mx_scale_ptr must be uint8"
)
tl
.
static_assert
(
BLOCK_K
%
MX_PACK_DIVISOR
==
0
,
"BLOCK_K must be a multiple of MX_PACK_DIVISOR"
)
is_out_microscaled
:
tl
.
constexpr
=
stride_y_mx_z
is
not
None
OUT_BLOCK_N
:
tl
.
constexpr
=
BLOCK_N
//
ACTIVATION_REDUCTION_N
yN
=
N
//
ACTIVATION_REDUCTION_N
pid
=
tl
.
program_id
(
0
)
if
ExptOffsSum
is
not
None
and
XCD_SWIZZLE
>
1
:
# Determine how much padding there is on the expert data. This allows us to
# know the true grid size and avoid processing padding tiles.
padding_m
=
grid_m
-
tl
.
load
(
ExptOffsSum
)
else
:
padding_m
:
tl
.
constexpr
=
0
HAS_FUSED_SCATTER
:
tl
.
constexpr
=
WriteBackIndx
is
not
None
index_type
:
tl
.
constexpr
=
tl
.
int64
if
UPCAST_INDICES
else
tl
.
int32
unpadded_m
=
grid_m
-
padding_m
tl
.
assume
(
unpadded_m
>=
0
)
total_actual_tiles
=
batch_size
*
unpadded_m
*
grid_n
*
SPLIT_K
if
padding_m
>
0
and
pid
>=
total_actual_tiles
:
tl
.
device_assert
(
batch_size
==
0
)
pid_mn
=
pid
-
total_actual_tiles
if
pid_mn
<
padding_m
*
grid_n
:
pid_m
,
pid_n
=
swizzle2d
(
pid_mn
,
padding_m
,
grid_n
,
GROUP_M
)
# set masked out rows to 0
if
HAS_FUSED_SCATTER
and
N_EXPTS_ACT
==
1
:
_zero_masked_rows
(
pid_m
,
pid_n
,
Y
,
stride_y_m
,
stride_y_n
,
yN
,
ScatterSrcIndx
,
num_idxs
,
BLOCK_M
,
OUT_BLOCK_N
)
return
# swizzle program ids
pid_emnk
=
pid
if
XCD_SWIZZLE
!=
1
:
pid_emnk
=
xcd_swizzle
(
pid_emnk
,
total_actual_tiles
,
XCD_SWIZZLE
)
pid_e
=
pid_emnk
//
(
unpadded_m
*
grid_n
*
SPLIT_K
)
pid_mnk
=
pid_emnk
%
(
unpadded_m
*
grid_n
*
SPLIT_K
)
pid_k
=
pid_mnk
%
SPLIT_K
pid_mn
=
pid_mnk
//
SPLIT_K
pid_m
,
pid_n
=
swizzle2d
(
pid_mn
,
unpadded_m
,
grid_n
,
GROUP_M
)
# For split-k, advance to the output k slice
if
SPLIT_K
>
1
:
Y
+=
pid_k
.
to
(
index_type
)
*
stride_y_k
if
is_out_microscaled
:
YActualScale
+=
pid_k
.
to
(
index_type
)
*
stride_x_mx_k
# set masked out rows to 0
if
HAS_FUSED_SCATTER
and
N_EXPTS_ACT
==
1
:
_zero_masked_rows
(
pid_m
,
pid_n
,
Y
,
stride_y_m
,
stride_y_n
,
yN
,
ScatterSrcIndx
,
num_idxs
,
BLOCK_M
,
OUT_BLOCK_N
)
# unpack expert data
if
ExptData
is
None
:
tl
.
static_assert
(
M
is
not
None
)
expt_id
,
start_z
,
start_m
,
block_id
=
pid_e
,
pid_e
,
0
,
pid_m
else
:
tl
.
static_assert
(
M
is
None
)
expt_data
=
tl
.
load
(
ExptData
+
pid_m
)
if
expt_data
==
-
1
:
return
expt_id
=
expt_data
&
0x0000FFFF
block_id
=
expt_data
>>
16
M
=
tl
.
load
(
ExptHist
+
expt_id
)
start_m
=
tl
.
load
(
ExptOffs
+
expt_id
)
start_z
=
0
expt_id
,
block_id
=
expt_id
.
to
(
index_type
),
block_id
.
to
(
index_type
)
start_m
,
start_z
=
start_m
.
to
(
index_type
),
start_z
.
to
(
index_type
)
pid_n
,
pid_k
=
pid_n
.
to
(
index_type
),
pid_k
.
to
(
index_type
)
# A pointers
offs_x_m
=
BLOCK_M
*
block_id
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_x_m
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offs_x_m
%
M
,
BLOCK_M
),
BLOCK_M
)
X
+=
start_z
*
stride_x_z
if
GatherIndx
is
None
:
X
+=
start_m
*
stride_x_m
else
:
GatherIndx
+=
start_m
# no needs to bounds-check here because `offs_x_m` wraps around M dim
offs_x_m
=
tl
.
load
(
GatherIndx
+
offs_x_m
)
//
N_EXPTS_ACT
offs_k
=
BLOCK_K
*
pid_k
+
tl
.
arange
(
0
,
BLOCK_K
)
XPtrs
=
X
+
offs_x_m
.
to
(
index_type
)[:,
None
]
*
stride_x_m
+
offs_k
.
to
(
index_type
)[
None
,
:]
*
stride_x_k
# TODO: refactor if/else when triton front end improves
if
is_w_microscaled
:
if
SWIZZLE_MX_VALUE
==
"HOPPER_VALUE"
:
tl
.
static_assert
(
is_mxfp4
,
"Only mxfp4 is supported for HOPPER swizzling"
)
tl
.
static_assert
(
not
is_x_microscaled
)
# We have pack 2 fp4 values in a byte but we divide the dimension by 2
# when swizzling
W_K_DIVISOR
:
tl
.
constexpr
=
1
W_K_MULTIPLIER
:
tl
.
constexpr
=
2
W_N_DIVISOR
:
tl
.
constexpr
=
4
else
:
# We have pack 2 fp4 values in a byte
W_K_DIVISOR
:
tl
.
constexpr
=
2
if
is_mxfp4
else
1
W_K_MULTIPLIER
:
tl
.
constexpr
=
1
W_N_DIVISOR
:
tl
.
constexpr
=
1
if
W_TRANSPOSE
:
# When weight is transposed, 2 fp4 values are packed per Byte along
# the contiguous dimension, K.
PACKED_BLOCK_K_W
:
tl
.
constexpr
=
(
BLOCK_K
//
W_K_DIVISOR
)
*
W_K_MULTIPLIER
PACKED_BLOCK_N_W
:
tl
.
constexpr
=
BLOCK_N
//
W_N_DIVISOR
else
:
# When weight is not transposed, fp4 values are *not* packed along
# the contiguous dimension, N.
PACKED_BLOCK_K_W
:
tl
.
constexpr
=
BLOCK_K
PACKED_BLOCK_N_W
:
tl
.
constexpr
=
BLOCK_N
//
W_K_DIVISOR
MX_SCALE_BLOCK_K
:
tl
.
constexpr
=
BLOCK_K
//
MX_PACK_DIVISOR
WMxScale
+=
expt_id
*
stride_w_mx_e
if
SWIZZLE_MX_SCALE
==
"BLACKWELL_SCALE"
:
# TODO: support non W_TRANSPOSE with blackwell swizzling
tl
.
static_assert
(
W_TRANSPOSE
)
tl
.
static_assert
(
BLOCK_N
%
128
==
0
)
tl
.
static_assert
(
MX_SCALE_BLOCK_K
%
4
==
0
)
PACKED_MX_BLOCK
:
tl
.
constexpr
=
(
MX_SCALE_BLOCK_K
//
4
)
*
32
*
4
*
4
SCALE_BLOCK_N
:
tl
.
constexpr
=
BLOCK_N
//
128
stride_scale_k
:
tl
.
constexpr
=
1
elif
SWIZZLE_MX_SCALE
==
"HOPPER_SCALE"
:
# TODO: support non W_TRANSPOSE with Hopper swizzling
tl
.
static_assert
(
W_TRANSPOSE
)
n_warps
:
tl
.
constexpr
=
tl
.
extra
.
cuda
.
num_warps
()
tl
.
static_assert
(
BLOCK_N
%
(
2
*
n_warps
*
2
*
8
)
==
0
)
tl
.
static_assert
(
MX_SCALE_BLOCK_K
%
2
==
0
)
PACKED_MX_BLOCK
:
tl
.
constexpr
=
MX_SCALE_BLOCK_K
*
32
SCALE_BLOCK_N
:
tl
.
constexpr
=
BLOCK_N
//
32
stride_scale_k
=
stride_w_mx_k
elif
SWIZZLE_MX_SCALE
==
"CDNA4_SCALE"
:
tl
.
static_assert
(
stride_w_mx_k
is
not
None
)
tl
.
static_assert
(
stride_w_mx_n
is
not
None
)
NON_K_PRESHUFFLE_BLOCK_SIZE
:
tl
.
constexpr
=
32
PACKED_MX_BLOCK
:
tl
.
constexpr
=
MX_SCALE_BLOCK_K
*
NON_K_PRESHUFFLE_BLOCK_SIZE
SCALE_BLOCK_N
:
tl
.
constexpr
=
BLOCK_N
//
NON_K_PRESHUFFLE_BLOCK_SIZE
stride_scale_k
=
stride_w_mx_k
else
:
PACKED_MX_BLOCK
:
tl
.
constexpr
=
MX_SCALE_BLOCK_K
SCALE_BLOCK_N
:
tl
.
constexpr
=
BLOCK_N
stride_scale_k
=
stride_w_mx_k
offs_n_scale
=
(
pid_n
*
SCALE_BLOCK_N
+
tl
.
arange
(
0
,
SCALE_BLOCK_N
))
%
N
offs_n_scale
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offs_n_scale
,
SCALE_BLOCK_N
),
SCALE_BLOCK_N
)
# K dimension must be the last dimension for the scales
offs_k_scale
=
PACKED_MX_BLOCK
*
pid_k
+
tl
.
arange
(
0
,
PACKED_MX_BLOCK
)
WMxScalePtrs
=
WMxScale
+
offs_k_scale
.
to
(
index_type
)[
None
,
:]
*
stride_scale_k
+
offs_n_scale
.
to
(
index_type
)[:,
None
]
*
stride_w_mx_n
else
:
WMxScalePtrs
=
None
offs_k_scale
=
None
W_K_DIVISOR
:
tl
.
constexpr
=
1
W_K_MULTIPLIER
:
tl
.
constexpr
=
1
W_N_DIVISOR
:
tl
.
constexpr
=
1
PACKED_BLOCK_K_W
:
tl
.
constexpr
=
BLOCK_K
PACKED_BLOCK_N_W
:
tl
.
constexpr
=
BLOCK_N
# B pointers
offs_w_n
=
pid_n
*
PACKED_BLOCK_N_W
+
tl
.
arange
(
0
,
PACKED_BLOCK_N_W
)
offs_w_n
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offs_w_n
%
(
N
//
W_N_DIVISOR
),
PACKED_BLOCK_N_W
),
PACKED_BLOCK_N_W
)
if
is_x_microscaled
:
XMxScale
+=
start_z
.
to
(
index_type
)
*
stride_x_mx_z
if
GatherIndx
is
None
:
XMxScale
+=
start_m
*
stride_x_mx_m
offs_x_k_scale
=
MX_SCALE_BLOCK_K
*
pid_k
+
tl
.
arange
(
0
,
MX_SCALE_BLOCK_K
)
XMxScalePtrs
=
XMxScale
+
offs_x_m
.
to
(
index_type
)[:,
None
]
*
stride_x_mx_m
+
offs_x_k_scale
.
to
(
index_type
)[
None
,
:]
*
stride_x_mx_k
else
:
XMxScalePtrs
=
None
offs_w_k
=
PACKED_BLOCK_K_W
*
pid_k
+
tl
.
arange
(
0
,
PACKED_BLOCK_K_W
)
W
+=
expt_id
*
stride_w_e
WPtrs
=
W
+
(
offs_w_k
.
to
(
index_type
)[:,
None
]
*
stride_w_k
+
offs_w_n
.
to
(
index_type
)[
None
,
:]
*
stride_w_n
)
# compute output
acc
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
K
,
BLOCK_K
*
pid_k
,
-
(
BLOCK_K
*
SPLIT_K
)):
if
EVEN_K
:
mask_k
=
tl
.
full
([
BLOCK_K
],
True
,
dtype
=
tl
.
int1
)
mask_k_w
=
tl
.
full
([
PACKED_BLOCK_K_W
],
True
,
dtype
=
tl
.
int1
)
if
is_w_microscaled
and
SWIZZLE_MX_SCALE
is
None
:
mask_k_scale
=
tl
.
full
([
PACKED_MX_BLOCK
],
True
,
dtype
=
tl
.
int1
)
if
is_x_microscaled
:
mask_x_k_scale
=
tl
.
full
([
MX_SCALE_BLOCK_K
],
True
,
dtype
=
tl
.
int1
)
else
:
mask_k
=
offs_k
<
k
mask_k_w
=
offs_w_k
<
((
k
//
(
W_K_DIVISOR
if
W_TRANSPOSE
else
1
))
*
W_K_MULTIPLIER
)
if
is_w_microscaled
and
SWIZZLE_MX_SCALE
is
None
:
mask_k_scale
=
offs_k_scale
*
MX_PACK_DIVISOR
<
k
if
is_x_microscaled
:
mask_x_k_scale
=
offs_x_k_scale
*
MX_PACK_DIVISOR
<
k
x
=
tl
.
load
(
XPtrs
,
mask
=
mask_k
[
None
,
:],
other
=
0.0
)
w
=
tl
.
load
(
WPtrs
,
mask
=
mask_k_w
[:,
None
],
other
=
0.0
,
cache_modifier
=
W_CACHE_MODIFIER
)
if
is_w_microscaled
:
x_format
:
tl
.
constexpr
=
get_scaled_dot_format_string
(
x
.
dtype
)
w_format
:
tl
.
constexpr
=
get_scaled_dot_format_string
(
w
.
dtype
)
if
is_x_microscaled
:
x_scales
=
tl
.
load
(
XMxScalePtrs
,
mask
=
mask_x_k_scale
[
None
,
:])
elif
x_format
==
"fp16"
or
x_format
==
"bf16"
:
x_scales
:
tl
.
constexpr
=
None
else
:
# Scale of 1 in E8M0 format
x_scales
=
tl
.
full
((
BLOCK_M
,
MX_SCALE_BLOCK_K
),
127
,
dtype
=
tl
.
uint8
)
if
SWIZZLE_MX_SCALE
==
"BLACKWELL_SCALE"
:
w_scales
=
unswizzle_mx_scale_bw
(
tl
.
load
(
WMxScalePtrs
))
elif
SWIZZLE_MX_SCALE
==
"HOPPER_SCALE"
:
# Handshake with the swizzling code
num_warps
:
tl
.
constexpr
=
tl
.
extra
.
cuda
.
num_warps
()
w_scales
=
unswizzle_mxfp4_scale_hopper
(
tl
.
load
(
WMxScalePtrs
),
mx_axis
=
1
,
num_warps
=
num_warps
)
elif
SWIZZLE_MX_SCALE
==
"CDNA4_SCALE"
:
w_scales
=
unswizzle_mx_scale_cdna4
(
tl
.
load
(
WMxScalePtrs
),
BLOCK_N
,
MX_SCALE_BLOCK_K
)
else
:
w_scales
=
tl
.
load
(
WMxScalePtrs
,
mask
=
mask_k_scale
[
None
,
:])
if
SWIZZLE_MX_VALUE
==
"HOPPER_VALUE"
:
# Handshake with the swizzling code
tl
.
static_assert
(
x_format
==
"bf16"
)
tl
.
static_assert
(
w_format
==
"e2m1"
)
w
=
mxfp4_to_bf16_triton
(
w
.
trans
(),
w_scales
,
1
)
tl
.
static_assert
(
w
.
dtype
==
tl
.
bfloat16
)
acc
=
acc
.
trans
()
x
=
x
.
trans
()
# w = w.trans()
acc
=
tl
.
dot
(
w
,
x
,
acc
,
max_num_imprecise_acc
=
MAX_NUM_IMPRECISE_ACC
,
allow_tf32
=
ALLOW_TF32
)
acc
=
acc
.
trans
()
else
:
rhs_k_pack
:
tl
.
constexpr
=
W_TRANSPOSE
or
not
is_w_microscaled
or
W_K_DIVISOR
!=
2
acc
=
tl
.
dot_scaled
(
x
,
x_scales
,
x_format
,
w
,
w_scales
,
w_format
,
acc
=
acc
,
fast_math
=
True
,
rhs_k_pack
=
rhs_k_pack
)
if
SWIZZLE_MX_SCALE
==
"BLACKWELL_SCALE"
:
WMxScalePtrs
+=
(
MX_SCALE_BLOCK_K
//
4
*
SPLIT_K
)
*
stride_w_mx_k
else
:
WMxScalePtrs
+=
(
PACKED_MX_BLOCK
*
SPLIT_K
)
*
stride_w_mx_k
if
is_x_microscaled
:
XMxScalePtrs
+=
(
MX_SCALE_BLOCK_K
*
SPLIT_K
)
*
stride_x_mx_k
else
:
acc
=
tl
.
dot
(
x
,
w
,
acc
,
max_num_imprecise_acc
=
MAX_NUM_IMPRECISE_ACC
,
allow_tf32
=
ALLOW_TF32
)
XPtrs
+=
(
BLOCK_K
*
SPLIT_K
)
*
stride_x_k
WPtrs
+=
(
PACKED_BLOCK_K_W
*
SPLIT_K
)
*
stride_w_k
# bias + scale
offs_m
=
BLOCK_M
*
block_id
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_y_n
=
BLOCK_N
*
pid_n
+
tl
.
arange
(
0
,
BLOCK_N
)
mask_m
=
offs_m
<
M
mask_n
=
offs_y_n
<
N
if
B
is
not
None
:
BPtrs
=
B
+
expt_id
*
stride_b_e
+
offs_y_n
if
pid_k
==
0
:
bias
=
tl
.
load
(
BPtrs
,
mask
=
mask_n
,
other
=
0
)
else
:
bias
=
tl
.
full
([
BLOCK_N
],
0
,
dtype
=
tl
.
float32
)
else
:
bias
=
tl
.
full
([
BLOCK_N
],
0
,
dtype
=
tl
.
float32
)
if
Betas
is
not
None
:
betas
=
tl
.
load
(
Betas
+
start_m
+
offs_m
,
mask
=
mask_m
,
other
=
0.0
)
else
:
betas
=
tl
.
full
([
BLOCK_M
],
1
,
dtype
=
tl
.
float32
)
if
Gammas
is
not
None
:
gammas
=
tl
.
load
(
Gammas
+
start_m
+
offs_m
,
mask
=
mask_m
,
other
=
0.0
)
else
:
gammas
=
tl
.
full
([
BLOCK_M
],
1
,
dtype
=
tl
.
float32
)
# flexpoint
x_scale
=
load_scale
(
XScale
)
if
PER_BATCH_SCALE
:
w_scale
=
load_scale
(
WScale
+
expt_id
)
else
:
w_scale
=
load_scale
(
WScale
)
acc
*=
x_scale
*
w_scale
acc
=
acc
+
bias
[
None
,
:]
*
betas
[:,
None
]
if
out_alpha
is
not
None
:
acc
*=
out_alpha
if
ACTIVATION_FN
is
not
None
:
out
=
ACTIVATION_FN
(
acc
,
*
activation_fn_args
)
tl
.
static_assert
(
out
.
shape
[
1
]
==
OUT_BLOCK_N
,
f
"Activation fn out.shape[1] (
{
out
.
shape
[
1
]
}
) doesn't match computed OUT_BLOCK_N (
{
OUT_BLOCK_N
}
)"
)
offs_y_n
=
OUT_BLOCK_N
*
pid_n
+
tl
.
arange
(
0
,
OUT_BLOCK_N
)
mask_n
=
offs_y_n
<
yN
else
:
tl
.
static_assert
(
ACTIVATION_REDUCTION_N
==
1
,
"Activation reduction must be 1 if no activation fn is provided"
)
out
=
acc
out
*=
gammas
[:,
None
]
# write-back
Y
+=
start_z
.
to
(
index_type
)
*
stride_y_z
if
WriteBackIndx
is
not
None
:
WriteBackIndx
+=
start_m
dst_idx
=
tl
.
load
(
WriteBackIndx
+
offs_m
,
mask
=
start_m
+
offs_m
<
writeback_size
,
other
=-
1
)
mask_m
=
mask_m
&
(
dst_idx
!=
-
1
)
offs_y_m
=
dst_idx
else
:
Y
+=
start_m
*
stride_y_m
offs_y_m
=
offs_m
YPtrs
=
Y
+
offs_y_m
.
to
(
index_type
)[:,
None
]
*
stride_y_m
+
offs_y_n
.
to
(
index_type
)[
None
,
:]
*
stride_y_n
mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
if
is_out_microscaled
:
MX_SCALE_BLOCK_N
:
tl
.
constexpr
=
BLOCK_N
//
MXFP_BLOCK_SIZE
N_MX_BLOCK
:
tl
.
constexpr
=
tl
.
cdiv
(
N
,
MXFP_BLOCK_SIZE
)
tl
.
static_assert
(
EPILOGUE_FN
is
not
None
)
out
,
out_scale
=
EPILOGUE_FN
(
out
,
mask
,
*
epilogue_fn_args
)
tl
.
static_assert
(
BLOCK_N
%
MX_SCALE_BLOCK_N
==
0
,
""
)
offs_y_n_scale
=
MX_SCALE_BLOCK_N
*
pid_n
+
tl
.
arange
(
0
,
MX_SCALE_BLOCK_N
)
mask_n_scale
=
offs_y_n_scale
<
N_MX_BLOCK
YActualScale
+=
start_z
.
to
(
index_type
)
*
stride_y_mx_z
if
WriteBackIndx
is
None
:
YActualScale
+=
start_m
*
stride_y_mx_m
YActualScalePtrs
=
YActualScale
+
offs_y_m
.
to
(
index_type
)[:,
None
]
*
stride_y_mx_m
+
offs_y_n_scale
.
to
(
index_type
)[
None
,
:]
*
stride_y_mx_n
else
:
YActualScalePtrs
=
YActualScale
+
(
offs_y_m
-
NRows
).
to
(
index_type
)[:,
None
]
*
stride_y_mx_m
+
offs_y_n_scale
.
to
(
index_type
)[
None
,
:]
*
stride_y_mx_n
tl
.
store
(
YActualScalePtrs
,
out_scale
,
mask
=
mask_m
[:,
None
]
&
mask_n_scale
[
None
,
:])
else
:
out
=
float_to_flex
(
out
,
YExpectedScale
,
YActualScale
,
YChecksumScale
,
mask
,
Y
,
FLEXPOINT_SATURATE_INF
)
if
EPILOGUE_FN
is
not
None
and
not
IS_EPILOGUE_QUANT_MXFP8
:
out
=
EPILOGUE_FN
(
out
,
*
epilogue_fn_args
,
target_dtype
=
YPtrs
.
dtype
.
element_ty
)
tl
.
store
(
YPtrs
,
out
,
mask
=
mask
)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py
deleted
100644 → 0
View file @
2b7160c6
# isort: off
# fmt: off
import
torch
import
triton
import
triton.language
as
tl
from
triton.tools.ragged_tma
import
load_ragged
,
store_ragged
from
compactor_vllm.triton_kernels
import
target_info
from
compactor_vllm.triton_kernels.tensor_details.layout_details.blackwell_scale
import
unswizzle_mx_scale_bw
from
compactor_vllm.triton_kernels.numerics_details.flexpoint
import
(
float_to_flex
,
load_scale
,
nan_propagating_absmax_reduce
,
compute_scale
,
)
from
compactor_vllm.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp
import
MXFP_BLOCK_SIZE
from
._common
import
make_matmul_repr
,
matmul_launch_metadata
,
swizzle2d
,
xcd_swizzle
,
get_scaled_dot_format_string
@
triton
.
constexpr_function
def
cuda_capability_geq
(
major
,
minor
):
return
target_info
.
cuda_capability_geq
(
major
,
minor
)
@
triton
.
constexpr_function
def
get_dtype
(
tensor_or_desc
:
tl
.
tensor
|
tl
.
tensor_descriptor
)
->
tl
.
dtype
:
if
isinstance
(
tensor_or_desc
,
tl
.
tensor
):
return
tensor_or_desc
.
dtype
.
element_ty
elif
isinstance
(
tensor_or_desc
,
tl
.
tensor_descriptor
):
return
tensor_or_desc
.
dtype
else
:
raise
ValueError
(
f
"Invalid type:
{
type
(
tensor_or_desc
)
}
"
)
@
triton
.
jit
def
_load_tile_attrs
(
tile_id
,
num_tiles
,
grid_m
,
grid_n
,
padding_m
,
M
,
ExptData
,
ExptHist
,
ExptOffs
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
GROUP_M
:
tl
.
constexpr
,
XCD_SWIZZLE
:
tl
.
constexpr
):
# unpack and swizzle program ids
pid_emnk
=
tile_id
if
XCD_SWIZZLE
!=
1
:
pid_emnk
=
xcd_swizzle
(
pid_emnk
,
num_tiles
//
SPLIT_K
,
XCD_SWIZZLE
)
pid_e
=
pid_emnk
//
((
grid_m
-
padding_m
)
*
grid_n
*
SPLIT_K
)
pid_mnk
=
pid_emnk
%
((
grid_m
-
padding_m
)
*
grid_n
*
SPLIT_K
)
if
SPLIT_K
>
1
:
pid_k
=
pid_mnk
%
SPLIT_K
pid_mn
=
pid_mnk
//
SPLIT_K
else
:
pid_k
:
tl
.
constexpr
=
0
pid_mn
=
pid_mnk
pid_m
,
pid_n
=
swizzle2d
(
pid_mn
,
(
grid_m
-
padding_m
),
grid_n
,
GROUP_M
)
# unpack expert data
if
ExptData
is
None
:
tl
.
static_assert
(
M
is
not
None
)
expt_id
,
start_z
,
start_m
,
block_id
,
eM
=
pid_e
,
pid_e
,
0
,
pid_m
,
-
1
else
:
tl
.
static_assert
(
M
is
None
)
expt_data
=
tl
.
load
(
ExptData
+
pid_m
)
expt_id
=
expt_data
&
0x0000FFFF
block_id
=
expt_data
>>
16
eM
=
tl
.
load
(
ExptHist
+
expt_id
)
start_m
=
tl
.
load
(
ExptOffs
+
expt_id
)
start_z
=
0
off_m
=
BLOCK_M
*
block_id
off_n
=
BLOCK_N
*
pid_n
return
expt_id
,
start_z
,
start_m
,
eM
,
off_m
,
off_n
,
pid_k
@
triton
.
jit
def
_load_writeback_idx_and_mask
(
WriteBackIndx
,
writeback_size
,
offs
,
mask
):
mask
=
mask
&
(
offs
<
writeback_size
)
offs
=
tl
.
load
(
WriteBackIndx
+
offs
,
mask
=
mask
,
other
=-
1
)
mask
=
offs
!=
-
1
return
(
offs
,
mask
)
_matmul_ogs_repr
=
make_matmul_repr
(
"_p_matmul_ogs"
,
[
0
,
1
,
2
])
@
triton
.
jit
(
do_not_specialize
=
[
"TOKENS_PER_EXPT_FOR_ANNOTATION"
],
repr
=
_matmul_ogs_repr
,
launch_metadata
=
matmul_launch_metadata
)
def
_p_matmul_ogs
(
Y
,
YPtr
,
stride_y_k
,
stride_y_z
,
stride_y_m
,
stride_y_n
,
YExpectedScale
,
YActualScale
,
YChecksumScale
,
stride_y_mx_z
,
stride_y_mx_m
,
stride_y_mx_n
,
X
,
XPtr
,
stride_x_z
,
stride_x_m
,
stride_x_k
,
XScale
,
XMxScale
,
stride_x_mx_z
,
stride_x_mx_m
,
stride_x_mx_k
,
W
,
WPtr
,
stride_w_e
,
stride_w_k
,
stride_w_n
,
W_TRANSPOSE
:
tl
.
constexpr
,
WScale
,
MxScale
,
stride_mx_e
,
stride_mx_k
,
stride_mx_n
,
B
,
stride_b_e
,
# Bias
NRows
,
M
,
N
,
K
,
# shapes
# expt data
Betas
,
Gammas
,
GatherIndx
,
ScatterSrcIndx
,
num_idxs
,
WriteBackIndx
,
writeback_size
,
ExptHist
,
ExptOffs
,
ExptOffsSum
,
ExptData
,
# true grid size
batch_size
,
grid_m
,
grid_n
,
# Out scale
out_alpha
,
# fused activation function
ACTIVATION_FN
:
tl
.
constexpr
,
activation_fn_args
,
ACTIVATION_REDUCTION_N
:
tl
.
constexpr
,
# epilogue transform
EPILOGUE_FN
:
tl
.
constexpr
,
epilogue_fn_args
,
# MoE config
N_EXPTS_TOT
:
tl
.
constexpr
,
N_EXPTS_ACT
:
tl
.
constexpr
,
# precision config
MAX_NUM_IMPRECISE_ACC
:
tl
.
constexpr
,
ALLOW_TF32
:
tl
.
constexpr
,
FLEXPOINT_SATURATE_INF
:
tl
.
constexpr
,
PER_BATCH_SCALE
:
tl
.
constexpr
,
# optimization config
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
GROUP_M
:
tl
.
constexpr
,
XCD_SWIZZLE
:
tl
.
constexpr
,
# NYI: Must be None
SWIZZLE_MX_VALUE
:
tl
.
constexpr
,
# One of ["BLACKWELL", None]
SWIZZLE_MX_SCALE
:
tl
.
constexpr
,
EPILOGUE_SUBTILE
:
tl
.
constexpr
,
EVEN_K
:
tl
.
constexpr
,
SPLIT_K
:
tl
.
constexpr
,
W_CACHE_MODIFIER
:
tl
.
constexpr
,
NUM_SMS
:
tl
.
constexpr
,
X_TMA_MODE
:
tl
.
constexpr
,
Y_TMA_MODE
:
tl
.
constexpr
,
TOKENS_PER_EXPT_FOR_ANNOTATION
=
None
,
UPCAST_INDICES
:
tl
.
constexpr
=
False
,
SWAP_XW
:
tl
.
constexpr
=
False
,
IS_EPILOGUE_QUANT_MXFP8
:
tl
.
constexpr
=
False
):
# tl.static_assert(SWIZZLE_MX_VALUE is None, "NYI. Value swizzling")
# why is this faster than using host-side tensor descriptor?!
if
Y_TMA_MODE
is
not
None
:
Y
=
tl
.
make_tensor_descriptor
(
YPtr
,
Y
.
shape
,
Y
.
strides
[:
-
1
]
+
(
1
,),
Y
.
block_shape
)
is_microscaled_format
:
tl
.
constexpr
=
MxScale
is
not
None
tl
.
static_assert
(
not
is_microscaled_format
or
W_TRANSPOSE
,
"NYI. Non-transposed mxfp4 weights"
)
MX_PACK_DIVISOR
:
tl
.
constexpr
=
MXFP_BLOCK_SIZE
if
is_microscaled_format
:
w_type
:
tl
.
constexpr
=
get_dtype
(
W
)
tl
.
static_assert
(
w_type
==
tl
.
uint8
or
(
w_type
==
tl
.
float8e4nv
or
w_type
==
tl
.
float8e5
),
"mx_weight_ptr must be uint8"
)
tl
.
static_assert
(
get_dtype
(
MxScale
)
==
tl
.
uint8
,
"mx_scale_ptr must be uint8"
)
tl
.
static_assert
(
BLOCK_K
%
MX_PACK_DIVISOR
==
0
,
"BLOCK_K must be a multiple of MX_PACK_DIVISOR"
)
tl
.
static_assert
(
SWIZZLE_MX_SCALE
==
"BLACKWELL_SCALE"
or
SWIZZLE_MX_SCALE
is
None
,
"Only Blackwell swizzling is supported for scales"
)
# We have pack 2 fp4 values in a byte
W_PACK_DIVISOR
:
tl
.
constexpr
=
2
if
w_type
==
tl
.
uint8
else
1
PACKED_BLOCK_K_W
:
tl
.
constexpr
=
BLOCK_K
//
W_PACK_DIVISOR
MX_SCALE_BLOCK_K
:
tl
.
constexpr
=
BLOCK_K
//
MX_PACK_DIVISOR
else
:
W_PACK_DIVISOR
:
tl
.
constexpr
=
1
MX_SCALE_BLOCK_K
:
tl
.
constexpr
=
1
PACKED_BLOCK_K_W
:
tl
.
constexpr
=
BLOCK_K
tl
.
static_assert
(
SWIZZLE_MX_SCALE
is
None
)
if
ExptOffsSum
is
not
None
:
# Determine how much padding there is on the expert data. This allows us to
# know the true grid size and avoid processing padding tiles.
padding_m
=
grid_m
-
tl
.
load
(
ExptOffsSum
)
else
:
padding_m
:
tl
.
constexpr
=
0
index_type
:
tl
.
constexpr
=
tl
.
int64
USE_FLEXPOINT_SCALE
:
tl
.
constexpr
=
YActualScale
is
not
None
or
YChecksumScale
is
not
None
HAS_SCATTER
:
tl
.
constexpr
=
WriteBackIndx
is
not
None
HAS_GATHER
:
tl
.
constexpr
=
GatherIndx
is
not
None
USE_GATHER_TMA
:
tl
.
constexpr
=
HAS_GATHER
and
X_TMA_MODE
==
"dense"
USE_SCATTER_TMA
:
tl
.
constexpr
=
HAS_SCATTER
and
Y_TMA_MODE
==
"dense"
if
EPILOGUE_SUBTILE
is
None
:
SUBTILE_FACTOR
:
tl
.
constexpr
=
1
else
:
SUBTILE_FACTOR
:
tl
.
constexpr
=
EPILOGUE_SUBTILE
EPILOGUE_BLOCK_N
:
tl
.
constexpr
=
BLOCK_N
//
SUBTILE_FACTOR
OUT_BLOCK_N
:
tl
.
constexpr
=
EPILOGUE_BLOCK_N
//
ACTIVATION_REDUCTION_N
yN
=
N
//
ACTIVATION_REDUCTION_N
# set masked out rows to 0
if
HAS_SCATTER
and
N_EXPTS_ACT
==
1
:
# Iterate with reversed pids so that later pids will get more tiles if the number of
# tiles isn't evenly divisible by the number of SMs.
# The main loop after this iterates in the forward direction such that earlier
# pids get more tiles if the number of tiles isn't evenly divisible.
# This helps balance the work across the SMs.
for
pid_mnk
in
range
(
NUM_SMS
-
tl
.
program_id
(
0
)
-
1
,
batch_size
*
grid_m
*
grid_n
*
SPLIT_K
,
NUM_SMS
):
pid_k
=
pid_mnk
%
SPLIT_K
pid_mn
=
pid_mnk
//
SPLIT_K
pid_m
,
pid_n
=
swizzle2d
(
pid_mn
,
grid_m
,
grid_n
,
GROUP_M
)
z
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
//
ACTIVATION_REDUCTION_N
],
dtype
=
tl
.
float32
)
offs_m
=
z
.
shape
[
0
]
*
pid_m
+
tl
.
arange
(
0
,
z
.
shape
[
0
])
offs_n
=
z
.
shape
[
1
]
*
pid_n
+
tl
.
arange
(
0
,
z
.
shape
[
1
])
src_idx
=
tl
.
load
(
ScatterSrcIndx
+
offs_m
,
mask
=
offs_m
<
num_idxs
,
other
=
0
)
YPtrs
=
YPtr
+
offs_m
.
to
(
index_type
)[:,
None
]
*
stride_y_m
+
offs_n
[
None
,
:]
*
stride_y_n
mask_n
=
offs_n
<
yN
mask
=
(
src_idx
==
-
1
)[:,
None
]
&
mask_n
[
None
,
:]
tl
.
store
(
YPtrs
+
pid_k
*
stride_y_k
,
z
,
mask
=
mask
)
k_tiles
=
tl
.
cdiv
(
K
,
BLOCK_K
*
SPLIT_K
)
num_tiles
=
batch_size
*
(
grid_m
-
padding_m
)
*
grid_n
*
SPLIT_K
# If true, do not share loop-carried variables between the prologue and the
# epilogue to enable better pipelining with mmav5
INDEPENDENT_EPILOGUE
:
tl
.
constexpr
=
cuda_capability_geq
(
10
,
0
)
# start negative; will be incremented at the top of the loop
if
INDEPENDENT_EPILOGUE
:
tile_id1
=
tl
.
program_id
(
0
)
-
NUM_SMS
# Keep track of local max for updating flexpoint scales.
THREADS_PER_BLOCK
:
tl
.
constexpr
=
tl
.
extra
.
cuda
.
num_threads
()
local_absmax
=
tl
.
full
([
THREADS_PER_BLOCK
],
0.0
,
tl
.
uint32
)
DISALLOW_ACC_MULTI_BUFFER
:
tl
.
constexpr
=
is_microscaled_format
and
BLOCK_M
*
BLOCK_N
>=
128
*
256
for
tile_id
in
tl
.
range
(
tl
.
program_id
(
0
),
num_tiles
,
NUM_SMS
,
flatten
=
True
,
disallow_acc_multi_buffer
=
DISALLOW_ACC_MULTI_BUFFER
,
warp_specialize
=
True
):
expt_id
,
start_z
,
start_m
,
eM
,
off_m
,
off_n
,
pid_k
=
_load_tile_attrs
(
tile_id
,
num_tiles
,
grid_m
,
grid_n
,
padding_m
,
M
,
ExptData
,
ExptHist
,
ExptOffs
,
BLOCK_M
,
BLOCK_N
,
SPLIT_K
,
GROUP_M
,
XCD_SWIZZLE
)
# Base pointers and offsets.
if
X_TMA_MODE
is
None
:
XBase
=
X
+
start_z
.
to
(
index_type
)
*
stride_x_z
offs_x_k
=
tl
.
arange
(
0
,
BLOCK_K
)[
None
,
:]
*
stride_x_k
if
SPLIT_K
>
1
:
offs_x_k
+=
pid_k
.
to
(
index_type
)
*
BLOCK_K
*
stride_x_k
if
USE_GATHER_TMA
:
offs_m
=
off_m
+
tl
.
arange
(
0
,
BLOCK_M
)
mask_m
=
offs_m
<
(
M
if
M
is
not
None
else
eM
)
if
ExptData
is
None
:
offs_x_m
=
tl
.
load
(
GatherIndx
+
start_m
.
to
(
index_type
)
+
offs_m
,
mask
=
mask_m
)
# Bump rows to account for the Z offset.
offs_x_m
+=
start_z
*
(
stride_x_z
//
stride_x_m
)
offs_x_m
=
tl
.
where
(
mask_m
,
offs_x_m
,
-
1
)
else
:
offs_x_m
=
tl
.
load
(
GatherIndx
+
start_m
.
to
(
index_type
)
+
offs_m
,
mask
=
mask_m
,
other
=-
N_EXPTS_ACT
)
//
N_EXPTS_ACT
elif
X_TMA_MODE
is
None
:
tl
.
static_assert
(
HAS_GATHER
)
offs_m
=
off_m
+
tl
.
arange
(
0
,
BLOCK_M
)
if
M
is
not
None
:
offs_m
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offs_m
%
M
,
BLOCK_M
),
BLOCK_M
)
else
:
offs_m
=
tl
.
max_contiguous
(
tl
.
multiple_of
(
offs_m
%
eM
,
BLOCK_M
),
BLOCK_M
)
# no needs to bounds-check here because `offs_m` wraps around M dim
offs_m
=
tl
.
load
(
GatherIndx
+
start_m
.
to
(
index_type
)
+
offs_m
)
//
N_EXPTS_ACT
offs_x_m
=
offs_m
.
to
(
index_type
)[:,
None
]
*
stride_x_m
acc
=
tl
.
zeros
((
BLOCK_N
,
BLOCK_M
)
if
SWAP_XW
else
(
BLOCK_M
,
BLOCK_N
),
dtype
=
tl
.
float32
)
for
ki
in
tl
.
range
(
k_tiles
,
disallow_acc_multi_buffer
=
DISALLOW_ACC_MULTI_BUFFER
):
off_k
=
pid_k
*
BLOCK_K
+
ki
*
BLOCK_K
*
SPLIT_K
off_k_w
=
pid_k
*
PACKED_BLOCK_K_W
+
ki
*
PACKED_BLOCK_K_W
*
SPLIT_K
off_k_mx
=
pid_k
*
MX_SCALE_BLOCK_K
+
ki
*
MX_SCALE_BLOCK_K
*
SPLIT_K
# --- load x ---
if
USE_GATHER_TMA
:
x
=
X
.
gather
(
offs_x_m
,
off_k
)
elif
X_TMA_MODE
==
"dense"
:
x
=
X
.
load
([
start_z
,
start_m
+
off_m
,
off_k
])
x
=
x
.
reshape
(
BLOCK_M
,
BLOCK_K
)
elif
X_TMA_MODE
==
"ragged"
:
x
=
load_ragged
(
X
,
start_m
,
eM
,
[
start_z
,
off_m
,
off_k
],
ragged_dim
=
1
)
x
=
x
.
reshape
(
BLOCK_M
,
BLOCK_K
)
else
:
tl
.
static_assert
(
X_TMA_MODE
is
None
)
XPtrs
=
XBase
+
offs_x_m
+
offs_x_k
XBase
+=
BLOCK_K
*
SPLIT_K
*
stride_x_k
mask_k
=
tl
.
arange
(
0
,
BLOCK_K
)
<
K
-
off_k
if
EVEN_K
:
if
SPLIT_K
>
1
:
x
=
tl
.
load
(
XPtrs
,
mask
=
mask_k
[
None
,
:],
other
=
0.0
)
else
:
x
=
tl
.
load
(
XPtrs
)
else
:
x
=
tl
.
load
(
XPtrs
,
mask
=
mask_k
[
None
,
:],
other
=
0.0
)
# --- load w ---
if
W_TRANSPOSE
:
w
=
tl
.
reshape
(
W
.
load
([
expt_id
,
off_n
,
off_k_w
]),
W
.
block_shape
[
1
:]).
T
else
:
w
=
tl
.
reshape
(
W
.
load
([
expt_id
,
off_k_w
,
off_n
]),
W
.
block_shape
[
1
:])
# --- load w_scale ---
if
is_microscaled_format
:
x_format
:
tl
.
constexpr
=
get_scaled_dot_format_string
(
x
.
dtype
)
mx_format
:
tl
.
constexpr
=
get_scaled_dot_format_string
(
w
.
dtype
)
if
x_format
==
"fp16"
or
x_format
==
"bf16"
:
x_scales
:
tl
.
constexpr
=
None
else
:
x_scales
=
tl
.
full
((
BLOCK_M
,
BLOCK_K
//
MX_PACK_DIVISOR
),
127
,
dtype
=
tl
.
uint8
)
if
SWIZZLE_MX_SCALE
==
"BLACKWELL_SCALE"
:
flattened_expt_n_idx
=
expt_id
*
((
N
+
127
)
//
128
)
+
(
off_n
//
128
)
w_scales
=
MxScale
.
load
([
0
,
flattened_expt_n_idx
,
pid_k
*
MX_SCALE_BLOCK_K
//
4
+
ki
*
(
MX_SCALE_BLOCK_K
//
4
*
SPLIT_K
),
0
,
0
])
w_scales
=
w_scales
.
reshape
((
w_scales
.
shape
[
1
],
w_scales
.
shape
[
2
]
*
w_scales
.
shape
[
-
2
]
*
w_scales
.
shape
[
-
1
]))
w_scales
=
unswizzle_mx_scale_bw
(
w_scales
)
else
:
w_scales
=
MxScale
.
load
([
expt_id
,
off_k_mx
,
off_n
])
w_scales
=
tl
.
reshape
(
w_scales
,
*
w_scales
.
shape
[
1
:]).
T
# --- update accumulator ---
if
is_microscaled_format
:
if
SWAP_XW
:
acc
=
tl
.
dot_scaled
(
w
.
T
,
w_scales
,
mx_format
,
x
.
T
,
x_scales
,
x_format
,
acc
=
acc
,
fast_math
=
True
)
else
:
acc
=
tl
.
dot_scaled
(
x
,
x_scales
,
x_format
,
w
,
w_scales
,
mx_format
,
acc
=
acc
,
fast_math
=
True
)
else
:
if
SWAP_XW
:
acc
=
tl
.
dot
(
w
.
T
,
x
.
T
,
acc
,
max_num_imprecise_acc
=
MAX_NUM_IMPRECISE_ACC
,
allow_tf32
=
ALLOW_TF32
)
else
:
acc
=
tl
.
dot
(
x
,
w
,
acc
,
max_num_imprecise_acc
=
MAX_NUM_IMPRECISE_ACC
,
allow_tf32
=
ALLOW_TF32
)
if
INDEPENDENT_EPILOGUE
:
tile_id1
+=
NUM_SMS
expt_id1
,
start_z1
,
start_m1
,
eM1
,
off_m1
,
off_n1
,
pid_k1
=
_load_tile_attrs
(
tile_id1
,
num_tiles
,
grid_m
,
grid_n
,
padding_m
,
M
,
ExptData
,
ExptHist
,
ExptOffs
,
BLOCK_M
,
BLOCK_N
,
SPLIT_K
,
GROUP_M
,
XCD_SWIZZLE
)
else
:
tile_id1
,
expt_id1
,
start_z1
,
start_m1
,
eM1
=
tile_id
,
expt_id
,
start_z
,
start_m
,
eM
off_m1
,
off_n1
,
pid_k1
=
off_m
,
off_n
,
pid_k
offs_m
=
off_m1
+
tl
.
arange
(
0
,
BLOCK_M
)
mask_m
=
offs_m
<
(
M
if
M
is
not
None
else
eM1
)
if
USE_SCATTER_TMA
:
offs_y_m
,
mask_m
=
_load_writeback_idx_and_mask
(
WriteBackIndx
,
writeback_size
,
start_m1
+
offs_m
,
mask_m
)
MASK_ACC
:
tl
.
constexpr
=
USE_FLEXPOINT_SCALE
if
SPLIT_K
>
1
:
# Compute the split k offset in number of rows, and add it to offs_y_m.
# This allows us to write to the correct slice in the output tensor while using
# a 2D TMA scatter.
tl
.
device_assert
(
stride_y_k
//
stride_y_m
==
tl
.
cdiv
(
stride_y_k
,
stride_y_m
))
split_k_row_offs
=
pid_k1
*
(
stride_y_k
//
stride_y_m
)
offs_y_m
=
tl
.
where
(
mask_m
,
offs_y_m
+
split_k_row_offs
,
offs_y_m
)
elif
Y_TMA_MODE
is
None
:
tl
.
static_assert
(
HAS_SCATTER
)
offs_y_m
,
mask_m
=
_load_writeback_idx_and_mask
(
WriteBackIndx
,
writeback_size
,
start_m1
+
offs_m
,
mask_m
)
MASK_ACC
:
tl
.
constexpr
=
USE_FLEXPOINT_SCALE
else
:
offs_y_m
=
start_m1
+
offs_m
MASK_ACC
=
False
if
USE_GATHER_TMA
else
USE_FLEXPOINT_SCALE
# bias + scale
offs_y_n
=
off_n1
+
tl
.
arange
(
0
,
BLOCK_N
)
mask_n
=
offs_y_n
<
N
if
B
is
not
None
:
BPtrs
=
B
+
expt_id1
*
stride_b_e
+
offs_y_n
if
pid_k1
==
0
:
bias
=
tl
.
load
(
BPtrs
,
mask
=
mask_n
,
other
=
0
)
else
:
bias
=
tl
.
full
([
BLOCK_N
],
0
,
dtype
=
tl
.
float32
)
else
:
bias
=
tl
.
full
([
BLOCK_N
],
0
,
dtype
=
tl
.
float32
)
if
Betas
is
not
None
:
betas
=
tl
.
load
(
Betas
+
start_m1
+
offs_m
,
mask
=
mask_m
,
other
=
0.0
)
else
:
betas
=
tl
.
full
([
BLOCK_M
],
1
,
dtype
=
tl
.
float32
)
if
Gammas
is
not
None
:
gammas
=
tl
.
load
(
Gammas
+
start_m1
+
offs_m
,
mask
=
mask_m
,
other
=
0.0
)
else
:
gammas
=
tl
.
full
([
BLOCK_M
],
1
,
dtype
=
tl
.
float32
)
x_scale
=
load_scale
(
XScale
)
if
PER_BATCH_SCALE
:
w_scale
=
load_scale
(
WScale
+
expt_id1
)
else
:
w_scale
=
load_scale
(
WScale
)
accs
=
(
acc
,)
biases
=
(
bias
,)
if
SUBTILE_FACTOR
>=
2
:
acc0
,
acc1
=
acc
.
reshape
(
BLOCK_M
,
2
,
BLOCK_N
//
2
).
permute
(
0
,
2
,
1
).
split
()
accs
=
(
acc0
,
acc1
)
bias0
,
bias1
=
bias
.
reshape
(
2
,
BLOCK_N
//
2
).
permute
(
1
,
0
).
split
()
biases
=
(
bias0
,
bias1
)
if
SUBTILE_FACTOR
>=
4
:
acc00
,
acc01
=
acc0
.
reshape
(
BLOCK_M
,
2
,
BLOCK_N
//
4
).
permute
(
0
,
2
,
1
).
split
()
acc10
,
acc11
=
acc1
.
reshape
(
BLOCK_M
,
2
,
BLOCK_N
//
4
).
permute
(
0
,
2
,
1
).
split
()
accs
=
(
acc00
,
acc01
,
acc10
,
acc11
)
bias00
,
bias01
=
bias0
.
reshape
(
2
,
BLOCK_N
//
4
).
permute
(
1
,
0
).
split
()
bias10
,
bias11
=
bias1
.
reshape
(
2
,
BLOCK_N
//
4
).
permute
(
1
,
0
).
split
()
biases
=
(
bias00
,
bias01
,
bias10
,
bias11
)
tl
.
static_assert
(
EPILOGUE_BLOCK_N
==
BLOCK_N
//
SUBTILE_FACTOR
)
tl
.
static_assert
(
len
(
accs
)
==
SUBTILE_FACTOR
)
for
a_i
in
tl
.
static_range
(
len
(
accs
)):
acc_tile
=
accs
[
a_i
]
acc_tile
*=
x_scale
*
w_scale
if
SWAP_XW
:
acc_tile
=
acc_tile
.
T
acc_tile
=
acc_tile
+
biases
[
a_i
][
None
,
:]
*
betas
[:,
None
]
if
out_alpha
is
not
None
:
acc_tile
*=
out_alpha
if
ACTIVATION_FN
is
not
None
:
out
=
ACTIVATION_FN
(
acc_tile
,
*
activation_fn_args
)
tl
.
static_assert
(
out
.
shape
[
1
]
==
OUT_BLOCK_N
,
f
"Activation fn out.shape[1] (
{
out
.
shape
[
1
]
}
) doesn't match computed OUT_BLOCK_N (
{
OUT_BLOCK_N
}
)"
)
else
:
tl
.
static_assert
(
ACTIVATION_REDUCTION_N
==
1
,
"Activation reduction must be 1 if no activation fn is provided"
)
out
=
acc_tile
out
*=
gammas
[:,
None
]
if
MASK_ACC
:
out
=
tl
.
where
(
mask_m
[:,
None
],
out
,
0.0
)
# Flexpoint
out_view
=
tl
.
reshape
(
out
,
[
out
.
numel
//
THREADS_PER_BLOCK
,
THREADS_PER_BLOCK
],
can_reorder
=
True
)
local_absmax
=
tl
.
maximum
(
local_absmax
,
nan_propagating_absmax_reduce
(
out_view
,
axis
=
0
))
out
=
float_to_flex
(
out
,
YExpectedScale
,
None
,
# ActualScale: local absmax is tracked and updated after the loop
YChecksumScale
,
None
,
# mask: out is manually masked to 0
YPtr
,
FLEXPOINT_SATURATE_INF
)
if
EPILOGUE_FN
is
not
None
:
out
=
EPILOGUE_FN
(
out
,
*
epilogue_fn_args
,
target_dtype
=
YPtr
.
dtype
.
element_ty
,
pid
=
len
(
accs
)
*
tile_id1
+
a_i
)
out_off_n
=
off_n1
//
ACTIVATION_REDUCTION_N
+
a_i
*
OUT_BLOCK_N
out
=
out
.
to
(
YPtr
.
dtype
.
element_ty
)
if
USE_SCATTER_TMA
:
# Convert -1 offsets to INT_MAX. We do this by clearing the leading bit. Note that
# there shouldn't be any other negative values.
offs_y_m
=
(
offs_y_m
.
to
(
tl
.
uint32
,
bitcast
=
True
)
&
0x7FFFFFFF
).
to
(
tl
.
int32
,
bitcast
=
True
)
Y
.
scatter
(
out
,
offs_y_m
,
out_off_n
)
elif
Y_TMA_MODE
==
"dense"
:
out
=
tl
.
reshape
(
out
,
[
1
]
+
out
.
shape
)
off_kz
=
pid_k
*
batch_size
+
start_z1
Y
.
store
([
off_kz
,
off_m1
,
out_off_n
],
out
)
elif
Y_TMA_MODE
==
"ragged"
:
out
=
tl
.
reshape
(
out
,
[
1
]
+
out
.
shape
)
store_ragged
(
Y
,
start_m1
,
eM1
,
[
pid_k
,
off_m1
,
out_off_n
],
out
,
ragged_dim
=
1
)
else
:
tl
.
static_assert
(
Y_TMA_MODE
is
None
)
offs_y_n
=
out_off_n
+
tl
.
arange
(
0
,
OUT_BLOCK_N
)
mask_n
=
offs_y_n
<
yN
YPtrs
=
YPtr
+
pid_k1
.
to
(
index_type
)
*
stride_y_k
+
start_z1
.
to
(
index_type
)
*
stride_y_z
+
offs_y_m
.
to
(
index_type
)[:,
None
]
*
stride_y_m
+
offs_y_n
[
None
,
:]
*
stride_y_n
mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
tl
.
store
(
YPtrs
,
out
,
mask
=
mask
)
# Update the flexpoint scales
if
YActualScale
is
not
None
:
tl
.
atomic_max
(
YActualScale
,
compute_scale
(
local_absmax
.
to
(
tl
.
float32
,
bitcast
=
True
),
YPtr
),
sem
=
"relaxed"
)
_per_device_alloc_fns
=
{}
def
get_per_device_per_stream_alloc_fn
(
device
):
if
device
not
in
_per_device_alloc_fns
:
_per_stream_tensors
=
{}
def
alloc_fn
(
size
:
int
,
alignment
:
int
,
stream
):
assert
alignment
==
128
if
stream
not
in
_per_stream_tensors
or
_per_stream_tensors
[
stream
].
numel
()
<
size
:
_per_stream_tensors
[
stream
]
=
torch
.
empty
(
size
,
device
=
device
,
dtype
=
torch
.
int8
)
_per_stream_tensors
[
stream
].
__hibernate__
=
{
"type"
:
"ignore"
}
return
_per_stream_tensors
[
stream
]
_per_device_alloc_fns
[
device
]
=
alloc_fn
return
_per_device_alloc_fns
[
device
]
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/_reduce_grouped.py
deleted
100644 → 0
View file @
2b7160c6
from
compactor_vllm.triton_kernels.numerics_details.flexpoint
import
(
float_to_flex
,
load_scale
,
)
from
compactor_vllm.triton_kernels.numerics_details.mxfp
import
quantize_mxfp8_fn
import
triton
import
triton.language
as
tl
@
triton
.
jit
def
_reduce_grouped
(
X
,
stride_xb
:
tl
.
uint64
,
stride_xm
:
tl
.
uint64
,
stride_xn
,
#
XScale
,
# input scalar flex scale
Out
,
stride_om
:
tl
.
uint64
,
stride_on
,
# output tensor
OutExpectedScale
,
OutActualScale
,
OutChecksumScale
,
# output scalar flex scales
InIndx
,
B
,
N
,
#
XMxScale
,
stride_mxb
:
tl
.
uint64
,
stride_mxs
:
tl
.
uint64
,
# optional per-32-col output MXFP scales (uint8)
OutMxScale
,
stride_omxs
:
tl
.
uint64
,
# optional per-32-col output MXFP scales (uint8)
# fused activation function
ACTIVATION_FN
:
tl
.
constexpr
,
activation_fn_args
,
ACTIVATION_REDUCTION_N
:
tl
.
constexpr
,
# epilogue transform
EPILOGUE_FN
:
tl
.
constexpr
,
epilogue_fn_args
,
#
HAS_IN_MX_SCALE
:
tl
.
constexpr
,
HAS_OUT_MX_SCALE
:
tl
.
constexpr
,
FLEXPOINT_SATURATE_INF
:
tl
.
constexpr
,
K
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
):
pid_t
=
tl
.
program_id
(
0
)
BLOCK_N_OUT
:
tl
.
constexpr
=
BLOCK_N
//
ACTIVATION_REDUCTION_N
# persistent along N: single program on N, iterate tiles of size BLOCK_N
start
=
pid_t
*
K
# load indices into a tuple
if
InIndx
is
None
:
indxs
=
(
pid_t
,)
else
:
indxs
=
()
for
i
in
tl
.
static_range
(
0
,
K
):
indxs
=
indxs
+
(
tl
.
load
(
InIndx
+
start
+
i
),)
# determine first valid topk row
fi
=
indxs
[(
K
-
1
)]
for
i
in
tl
.
static_range
(
K
-
2
,
-
1
,
-
1
):
fi
=
tl
.
where
(
indxs
[
i
]
!=
-
1
,
indxs
[
i
],
fi
)
# record overwritten row index (may be -1 if none)
XPtrs
=
X
+
tl
.
arange
(
0
,
BLOCK_N
)
*
stride_xn
OutPtrs
=
Out
+
tl
.
arange
(
0
,
BLOCK_N_OUT
)
*
stride_on
if
HAS_IN_MX_SCALE
:
XScalePtrs
=
XMxScale
+
tl
.
arange
(
0
,
BLOCK_N
//
32
)
*
stride_xn
if
HAS_OUT_MX_SCALE
:
OutScalePtrs
=
OutMxScale
+
tl
.
arange
(
0
,
BLOCK_N_OUT
//
32
)
*
stride_on
x_scale
=
load_scale
(
XScale
)
for
n_curr
in
tl
.
range
(
0
,
N
,
BLOCK_N
,
num_stages
=
4
):
acc
=
tl
.
zeros
([
BLOCK_N_OUT
],
dtype
=
tl
.
float32
)
x_n_mask
=
tl
.
arange
(
0
,
BLOCK_N
)
<
N
-
n_curr
x_n_mask_scale
=
tl
.
arange
(
0
,
BLOCK_N
//
32
)
<
tl
.
cdiv
(
N
-
n_curr
,
32
)
# accumulate contributions for this tile
for
i
in
tl
.
static_range
(
0
,
K
):
curr
=
tl
.
zeros
([
BLOCK_N
],
dtype
=
tl
.
float32
)
# iterate over split_k partial values
for
b
in
tl
.
range
(
0
,
B
):
is_valid
=
indxs
[
i
]
!=
-
1
x_row_ptr
=
XPtrs
+
indxs
[
i
]
*
stride_xm
+
b
*
stride_xb
vals
=
tl
.
load
(
x_row_ptr
,
mask
=
x_n_mask
&
is_valid
,
other
=
0.0
)
vals
=
vals
.
to
(
tl
.
float32
)
if
HAS_IN_MX_SCALE
:
scale_row_ptr
=
XScalePtrs
+
indxs
[
i
]
*
stride_mxs
+
b
*
stride_mxb
scale
=
tl
.
load
(
scale_row_ptr
,
mask
=
x_n_mask_scale
&
is_valid
,
other
=
0.0
)
scale
=
(
scale
.
to
(
tl
.
uint32
)
<<
23
).
to
(
tl
.
float32
,
bitcast
=
True
)
vals
=
vals
.
reshape
([
BLOCK_N
//
32
,
32
])
vals
=
(
scale
[:,
None
]
*
vals
).
reshape
([
BLOCK_N
])
curr
+=
vals
# apply nonlinearity to split-k output
if
ACTIVATION_FN
is
not
None
:
curr
=
ACTIVATION_FN
(
curr
[
None
,
:],
*
activation_fn_args
)
curr
=
tl
.
reshape
(
curr
,
[
curr
.
shape
[
-
1
]])
# update final accumulator
acc
+=
curr
acc
*=
x_scale
# Compute per-32-col MXFP scales for this tile if requested
Nrem
=
(
N
-
n_curr
)
//
ACTIVATION_REDUCTION_N
out_n_mask
=
tl
.
arange
(
0
,
BLOCK_N_OUT
)
<
Nrem
out_n_mask_scale
=
tl
.
arange
(
0
,
BLOCK_N_OUT
//
32
)
<
tl
.
cdiv
(
Nrem
,
32
)
if
HAS_OUT_MX_SCALE
:
acc
,
acc_scale
=
quantize_mxfp8_fn
(
acc
[
None
,
:],
out_n_mask
[
None
,
:])
acc
=
tl
.
reshape
(
acc
,
[
acc
.
shape
[
-
1
]])
acc_scale
=
tl
.
reshape
(
acc_scale
,
[
acc_scale
.
shape
[
-
1
]])
# Convert to flexpoint output if configured (scalar scales)
acc
=
float_to_flex
(
acc
,
OutExpectedScale
,
OutActualScale
,
OutChecksumScale
,
None
,
Out
,
FLEXPOINT_SATURATE_INF
,
)
# write-back for this tile
out_ptr
=
OutPtrs
+
pid_t
*
stride_om
tl
.
store
(
out_ptr
,
acc
,
mask
=
out_n_mask
)
if
HAS_OUT_MX_SCALE
:
out_scale_ptr
=
OutScalePtrs
+
pid_t
*
stride_omxs
tl
.
store
(
out_scale_ptr
,
acc_scale
,
mask
=
out_n_mask_scale
)
XPtrs
+=
BLOCK_N
*
stride_xn
OutPtrs
+=
BLOCK_N_OUT
*
stride_on
if
HAS_IN_MX_SCALE
:
XScalePtrs
+=
BLOCK_N
//
32
*
stride_xn
if
HAS_OUT_MX_SCALE
:
OutScalePtrs
+=
BLOCK_N_OUT
//
32
*
stride_xn
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags.py
deleted
100644 → 0
View file @
2b7160c6
# isort: off
# fmt: off
from
dataclasses
import
dataclass
import
triton
from
compactor_vllm.triton_kernels.target_info
import
get_cdna_version
import
torch
from
.opt_flags_details
import
opt_flags_amd
,
opt_flags_nvidia
@
dataclass
class
OptFlags
:
block_m
:
int
block_n
:
int
block_k
:
int
num_warps
:
int
num_stages
:
int
group_m
:
int
xcd_swizzle
:
int
w_cache_modifier
:
str
split_k
:
int
is_persistent
:
bool
fused_scatter
:
bool
idle_sms
:
int
epilogue_subtile
:
int
|
None
arch
:
str
target_kernel_kwargs
:
dict
def
__post_init__
(
self
):
if
self
.
fused_scatter
and
self
.
split_k
!=
1
:
raise
ValueError
(
"Not supported"
)
def
make_default_opt_flags_amd
(
out_dtype
,
lhs_dtype
,
rhs_dtype
,
precision_config
,
m
,
n
,
k
,
routing_data
,
can_use_persistent_tma
,
can_use_fused_scatter
,
enforce_bitwise_invariance
,
epilogue_effective_itemsize
,
constraints
,
):
constraints_supported
=
[
"block_m"
,
"block_n"
,
"block_k"
,
"split_k"
,
"fused_scatter"
,
"is_persistent"
,
"epilogue_subtile"
]
assert
not
any
([
c
not
in
constraints_supported
for
c
in
constraints
]),
constraints
.
keys
()
# tokens per expert
if
routing_data
is
None
:
tokens_per_expt
=
m
elif
routing_data
.
expected_tokens_per_expt
is
None
:
tokens_per_expt
=
max
(
1
,
m
//
routing_data
.
n_expts_tot
)
else
:
tokens_per_expt
=
routing_data
.
expected_tokens_per_expt
is_cdna4
=
get_cdna_version
()
==
4
# block_m
if
constraints
.
get
(
"block_m"
,
None
):
block_m
=
constraints
[
"block_m"
]
elif
enforce_bitwise_invariance
:
block_m
=
256
if
is_cdna4
else
128
elif
tokens_per_expt
>=
512
and
n
>=
2048
:
block_m
=
256
if
is_cdna4
else
128
elif
is_cdna4
and
m
>=
512
:
block_m
=
128
else
:
block_m
=
max
(
32
,
min
(
triton
.
next_power_of_2
(
tokens_per_expt
),
64
))
if
routing_data
is
not
None
:
grid_m
=
routing_data
.
n_blocks
(
m
,
block_m
)
else
:
grid_m
=
triton
.
cdiv
(
m
,
block_m
)
# group_m:
group_m
=
4
# number of xcds
num_xcds
=
8
xcd_swizzle
=
num_xcds
# block_nk:
block_n
,
block_k
=
opt_flags_amd
.
compute_block_nk
(
n
,
block_m
,
grid_m
,
num_xcds
,
lhs_dtype
,
rhs_dtype
,
precision_config
)
# Replace block_k if provided in constraints.
# TODO: Does opt_flags_amd.compute_block_nk need to be refactored?
if
constraints
.
get
(
"block_k"
,
None
)
is
not
None
:
block_k
=
constraints
[
"block_k"
]
if
constraints
.
get
(
"block_n"
,
None
)
is
not
None
:
block_n
=
constraints
[
"block_n"
]
is_persistent
=
constraints
.
get
(
"is_persistent"
,
False
)
# split_k:
if
constraints
.
get
(
"split_k"
,
None
)
is
not
None
:
split_k
=
constraints
[
"split_k"
]
elif
is_persistent
or
enforce_bitwise_invariance
:
split_k
=
1
else
:
grid_size
=
grid_m
*
((
n
+
block_n
-
1
)
//
block_n
)
n_cu
=
torch
.
cuda
.
get_device_properties
(
0
).
multi_processor_count
split_k
=
max
(
1
,
n_cu
//
grid_size
)
# w_cache_modifier:
w_cache_modifier
=
".cg"
if
block_m
<=
32
else
None
# num_warps, num_stages
num_warps
=
2
if
(
m
is
not
None
and
m
<=
16
)
else
8
num_stages
=
2
# AMD-specific
target_kernel_kwargs
=
{
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
}
epilogue_subtile
=
constraints
.
get
(
'epilogue_subtile'
,
None
)
if
epilogue_subtile
is
None
:
epilogue_subtile
=
1
ret
=
OptFlags
(
block_m
=
block_m
,
block_n
=
block_n
,
block_k
=
block_k
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
group_m
=
group_m
,
xcd_swizzle
=
xcd_swizzle
,
w_cache_modifier
=
w_cache_modifier
,
split_k
=
split_k
,
is_persistent
=
is_persistent
,
fused_scatter
=
constraints
.
get
(
'fused_scatter'
,
False
),
idle_sms
=
0
,
epilogue_subtile
=
epilogue_subtile
,
arch
=
None
,
target_kernel_kwargs
=
target_kernel_kwargs
,
)
# check constraints
assert
all
(
getattr
(
ret
,
ck
)
==
cv
for
ck
,
cv
in
constraints
.
items
()
if
cv
is
not
None
),
f
"
{
ret
}
!=
{
constraints
}
"
return
ret
def
make_default_opt_flags_nvidia
(
out_dtype
,
lhs_dtype
,
rhs_dtype
,
precision_config
,
m
,
n
,
k
,
routing_data
,
can_use_persistent_tma
,
can_use_fused_scatter
,
enforce_bitwise_invariance
,
epilogue_effective_itemsize
,
constraints
,
):
constraints_supported
=
[
"block_m"
,
"block_k"
,
"split_k"
,
"is_persistent"
,
"fused_scatter"
,
"epilogue_subtile"
,
"num_stages"
,
"idle_sms"
]
assert
not
any
([
c
not
in
constraints_supported
for
c
in
constraints
]),
constraints
.
keys
()
# tokens per expert
if
routing_data
is
None
:
tokens_per_expt
=
m
elif
routing_data
.
expected_tokens_per_expt
is
None
:
tokens_per_expt
=
max
(
1
,
m
//
routing_data
.
n_expts_tot
)
else
:
tokens_per_expt
=
routing_data
.
expected_tokens_per_expt
# pid swizzling
group_m
=
8
xcd_swizzle
=
1
# block_m
if
constraints
.
get
(
"block_m"
,
None
):
block_m
=
constraints
[
"block_m"
]
elif
enforce_bitwise_invariance
:
block_m
=
128
else
:
block_m
=
max
(
16
,
min
(
triton
.
next_power_of_2
(
tokens_per_expt
),
128
))
# block n
arch
=
None
block_n
=
opt_flags_nvidia
.
compute_block_n
(
n
,
arch
,
precision_config
)
# is_persistent
grid_size
=
opt_flags_nvidia
.
compute_grid_size
(
routing_data
,
m
,
n
,
block_m
,
block_n
)
n_sms
=
torch
.
cuda
.
get_device_properties
(
0
).
multi_processor_count
tiles_per_sm
=
grid_size
/
n_sms
supports_persistent
=
can_use_persistent_tma
and
(
arch
is
None
or
int
(
arch
[
2
:
-
1
])
>=
9
)
if
constraints
.
get
(
"is_persistent"
,
None
)
is
not
None
:
is_persistent
=
constraints
[
"is_persistent"
]
else
:
has_simple_epilogue
=
precision_config
.
max_num_imprecise_acc
is
None
is_persistent
=
supports_persistent
and
has_simple_epilogue
and
(
tiles_per_sm
>=
2.0
or
lhs_dtype
.
itemsize
<=
1
)
and
out_dtype
.
itemsize
<
4
# TEMP CHANGE
if
precision_config
.
act_scale
is
not
None
or
precision_config
.
out_scale
is
not
None
:
is_persistent
=
False
# block k
if
constraints
.
get
(
"block_k"
,
None
)
is
not
None
:
block_k
=
constraints
[
"block_k"
]
else
:
block_k
=
opt_flags_nvidia
.
compute_block_k
(
m
,
k
,
is_persistent
,
lhs_dtype
,
rhs_dtype
,
precision_config
)
# split_k
if
constraints
.
get
(
"split_k"
,
None
)
is
not
None
:
split_k
=
constraints
[
"split_k"
]
elif
is_persistent
or
enforce_bitwise_invariance
or
precision_config
.
act_scale
is
not
None
or
precision_config
.
out_scale
is
not
None
:
split_k
=
1
else
:
estimated_actual_grid_size
=
opt_flags_nvidia
.
compute_grid_size
(
None
,
m
,
n
,
block_m
,
block_n
)
split_k
=
opt_flags_nvidia
.
compute_split_k
(
block_k
,
k
,
estimated_actual_grid_size
)
if
split_k
>
1
:
# With split_k, results are written in f32. Use that for the following computations.
out_dtype
=
torch
.
float32
compute_num_stages_args
=
(
precision_config
,
is_persistent
,
block_m
,
block_n
,
block_k
,
out_dtype
,
lhs_dtype
,
rhs_dtype
,
)
if
constraints
.
get
(
"epilogue_subtile"
,
None
)
is
not
None
:
subtiles_to_check
=
[
constraints
[
"epilogue_subtile"
]]
else
:
subtiles_to_check
=
[
1
,
2
,
4
]
num_stages
=
-
1
for
ep
in
subtiles_to_check
:
ns
=
opt_flags_nvidia
.
compute_num_stages
(
*
compute_num_stages_args
,
ep
,
epilogue_effective_itemsize
)
if
ns
>
num_stages
:
epilogue_subtile
,
num_stages
=
ep
,
ns
assert
num_stages
>=
1
if
constraints
.
get
(
"num_stages"
,
None
):
num_stages
=
constraints
[
"num_stages"
]
# fused scatter scratchpad
if
constraints
.
get
(
"fused_scatter"
,
None
)
is
not
None
:
fused_scatter
=
constraints
[
"fused_scatter"
]
else
:
fused_scatter
=
can_use_fused_scatter
and
split_k
==
1
# Handshake with the HBM swizzling
num_warps
=
opt_flags_nvidia
.
compute_num_warps
(
block_m
,
block_n
,
precision_config
)
ret
=
OptFlags
(
block_m
=
block_m
,
block_n
=
block_n
,
block_k
=
block_k
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
fused_scatter
=
fused_scatter
,
group_m
=
group_m
,
xcd_swizzle
=
xcd_swizzle
,
w_cache_modifier
=
None
,
split_k
=
split_k
,
is_persistent
=
is_persistent
,
epilogue_subtile
=
epilogue_subtile
,
arch
=
arch
,
target_kernel_kwargs
=
dict
(),
idle_sms
=
constraints
.
get
(
"idle_sms"
,
0
),
)
# check constraints
assert
all
(
getattr
(
ret
,
ck
)
==
cv
for
ck
,
cv
in
constraints
.
items
()
if
cv
is
not
None
),
f
"
{
ret
}
!=
{
constraints
}
"
return
ret
# --------------
# User Interface
# --------------
_opt_flags_constraints
:
dict
=
dict
()
_opt_flags
:
OptFlags
|
None
=
None
def
update_opt_flags_constraints
(
constraints
:
dict
[
str
,
int
]):
global
_opt_flags_constraints
_opt_flags_constraints
.
update
(
constraints
)
def
reset_opt_flags_constraints
():
global
_opt_flags_constraints
_opt_flags_constraints
=
dict
()
def
set_opt_flags
(
opt_flags
:
OptFlags
):
global
_opt_flags
assert
not
_opt_flags_constraints
,
"setting constraints is incompatible with manual flags override"
assert
not
_opt_flags
,
"opt_flags already set; please reset to None first"
_opt_flags
=
opt_flags
class
InapplicableConstraint
(
Exception
):
pass
def
make_opt_flags
(
out_dtype
,
lhs_dtype
,
rhs_dtype
,
precision_config
,
m
,
n
,
k
,
routing_data
,
can_use_persistent_tma
,
can_use_fused_scatter
,
epilogue_effective_itemsize
,
):
if
_opt_flags_constraints
.
get
(
"is_persistent"
,
False
)
and
not
can_use_persistent_tma
:
raise
InapplicableConstraint
(
"cannot enforce `is_persistent=True` constraint"
)
if
_opt_flags_constraints
.
get
(
"fused_scatter"
,
False
)
and
not
can_use_fused_scatter
:
raise
InapplicableConstraint
(
"cannot enforce `fused_scatter=True` constraint"
)
enforce_bitwise_invariance
=
precision_config
.
enforce_bitwise_invariance
if
_opt_flags
is
not
None
:
assert
not
_opt_flags_constraints
return
_opt_flags
args
=
[
out_dtype
,
lhs_dtype
,
rhs_dtype
,
precision_config
,
m
,
n
,
k
,
routing_data
,
can_use_persistent_tma
,
can_use_fused_scatter
,
enforce_bitwise_invariance
,
epilogue_effective_itemsize
,
_opt_flags_constraints
]
backend
=
triton
.
runtime
.
driver
.
active
.
get_current_target
().
backend
if
backend
==
"hip"
:
return
make_default_opt_flags_amd
(
*
args
)
if
backend
==
"cuda"
:
return
make_default_opt_flags_nvidia
(
*
args
)
assert
False
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_amd.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
triton
from
compactor_vllm.triton_kernels.target_info
import
get_cdna_version
from
compactor_vllm.triton_kernels.tensor
import
bitwidth
def
compute_block_nk
(
n
,
block_m
,
grid_m
,
num_xcds
,
lhs_dtype
,
rhs_dtype
,
precision_config
):
lhs_width
=
bitwidth
(
lhs_dtype
)
/
8
rhs_width
=
bitwidth
(
rhs_dtype
)
/
8
# block_n:
n_cu
=
torch
.
cuda
.
get_device_properties
(
0
).
multi_processor_count
if
n
is
not
None
:
if
n
<=
128
and
(
n
&
(
n
-
1
))
==
0
:
block_n
=
n
else
:
block_n
=
max
(
32
,
min
(
256
,
triton
.
next_power_of_2
(
grid_m
*
n
*
num_xcds
//
n_cu
))
)
elif
block_m
>
64
:
block_n
=
256
else
:
block_n
=
128
if
get_cdna_version
()
==
4
and
block_m
==
128
:
block_n
=
512
# block_k needs to match the cacheline size (128B)
block_k
=
int
(
128
//
min
(
lhs_width
,
rhs_width
))
# TODO: block_k = 128 seems to work better for now.
# perhaps due to increased number of k loops to pipeline
if
precision_config
.
weight_scale
is
not
None
and
get_cdna_version
()
!=
4
:
block_k
=
128
return
block_n
,
block_k
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
import
triton
from
compactor_vllm.triton_kernels
import
target_info
from
compactor_vllm.triton_kernels.tensor
import
get_layout
,
bitwidth
,
FP4
from
compactor_vllm.triton_kernels.tensor_details.layout
import
HopperMXScaleLayout
from
compactor_vllm.triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp
import
(
MXFP_BLOCK_SIZE
,
)
def
compute_grid_size
(
routing_data
,
m
,
n
,
block_m
,
block_n
):
if
routing_data
is
not
None
:
grid_m
=
routing_data
.
n_blocks
(
m
,
block_m
)
else
:
grid_m
=
triton
.
cdiv
(
m
,
block_m
)
grid_n
=
(
n
+
block_n
-
1
)
//
block_n
return
grid_m
*
grid_n
def
compute_block_n
(
n
:
int
,
arch
,
precision_config
):
# block_n:
layout
=
get_layout
(
precision_config
.
weight_scale
)
if
isinstance
(
layout
,
HopperMXScaleLayout
)
and
layout
.
num_warps
==
4
:
return
128
elif
precision_config
.
max_num_imprecise_acc
is
None
and
n
>
128
:
return
256
else
:
return
max
(
16
,
min
(
128
,
triton
.
next_power_of_2
(
n
)))
def
compute_block_k
(
m
:
int
,
k
:
int
|
None
,
is_persistent
:
bool
,
lhs_dtype
,
rhs_dtype
,
precision_config
):
lhs_width
=
bitwidth
(
lhs_dtype
)
rhs_width
=
bitwidth
(
rhs_dtype
)
# block_k needs to match the cacheline size (1024 bits)
block_k
=
int
(
1024
//
min
(
lhs_width
,
rhs_width
))
has_native_mxfp
=
target_info
.
cuda_capability_geq
(
10
,
0
)
if
rhs_width
==
4
and
not
has_native_mxfp
:
block_k
=
128
elif
k
is
not
None
:
block_k
=
max
(
32
,
min
(
triton
.
next_power_of_2
(
k
),
block_k
))
has_mx_weight_scale
=
(
precision_config
is
not
None
and
precision_config
.
weight_scale
is
not
None
)
if
has_native_mxfp
and
is_persistent
and
has_mx_weight_scale
:
block_k
=
min
(
block_k
,
128
)
return
block_k
def
compute_split_k
(
block_k
:
int
,
k
:
int
|
None
,
grid_size
:
int
)
->
int
:
device_props
=
torch
.
cuda
.
get_device_properties
(
0
)
n_sms
=
device_props
.
multi_processor_count
split_k
=
n_sms
//
grid_size
if
k
is
not
None
:
# avoid split_k for small k
num_block_k
=
triton
.
cdiv
(
k
,
block_k
)
split_k
=
min
(
split_k
,
num_block_k
//
4
)
split_k
=
max
(
split_k
,
1
)
return
split_k
def
compute_num_warps
(
block_m
,
block_n
,
precision_config
):
layout
=
get_layout
(
precision_config
.
weight_scale
)
if
isinstance
(
layout
,
HopperMXScaleLayout
):
return
layout
.
num_warps
return
max
(
block_m
*
block_n
//
4096
,
4
)
def
compute_num_stages
(
precision_config
,
is_persistent
,
block_m
,
block_n
,
block_k
,
out_dtype
,
lhs_dtype
,
rhs_dtype
,
epilogue_subtile
,
epilogue_effective_itemsize
,
):
if
precision_config
.
max_num_imprecise_acc
is
not
None
:
return
3
weight_size
=
bitwidth
(
rhs_dtype
)
/
8
stage_size
=
(
block_m
*
block_k
*
lhs_dtype
.
itemsize
+
block_k
*
block_n
*
weight_size
)
device_props
=
torch
.
cuda
.
get_device_properties
(
0
)
smem_capacity
=
device_props
.
shared_memory_per_block_optin
has_native_mxfp
=
target_info
.
cuda_capability_geq
(
10
,
0
)
if
has_native_mxfp
and
getattr
(
precision_config
,
"weight_scale"
,
None
)
is
not
None
:
if
rhs_dtype
==
FP4
:
# 4-bit e2m1 weights are padded 2x
# https://docs.nvidia.com/cuda/parallel-thread-execution/#packing-format-used-for-matrix-a-and-b-by-kind-mxf8f6f4-in-shared-memory
stage_size
+=
block_k
*
block_n
*
weight_size
if
is_persistent
:
# Per-stage wait barrier
stage_size
+=
8
if
target_info
.
cuda_capability_geq
(
10
,
0
):
acc_size
=
epilogue_effective_itemsize
or
out_dtype
.
itemsize
else
:
acc_size
=
out_dtype
.
itemsize
if
target_info
.
cuda_capability_geq
(
10
,
0
)
and
epilogue_subtile
is
not
None
:
acc_block_n
=
block_n
//
epilogue_subtile
else
:
acc_block_n
=
block_n
# pipelined TMA store local to global, or
# pipelined layout conversion before store of the accumulator
# note: layout conversion has some padding
smem_capacity
-=
int
((
block_m
+
4
)
*
acc_block_n
*
acc_size
)
if
precision_config
.
weight_scale
is
not
None
:
# mx scales
stage_size
+=
block_n
*
(
block_k
//
int
(
MXFP_BLOCK_SIZE
))
elif
has_native_mxfp
:
# mx scales
stage_size
+=
block_n
*
(
block_k
//
int
(
MXFP_BLOCK_SIZE
))
num_stages
=
min
(
4
,
smem_capacity
//
int
(
stage_size
))
return
num_stages
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics.py
deleted
100644 → 0
View file @
2b7160c6
import
torch
from
dataclasses
import
dataclass
MAX_FINITE_FLOAT8E5
=
57344.0
MAX_FINITE_FLOAT8E4NV
=
448.0
MAX_FINITE_FLOAT8E4B8
=
240.0
@
dataclass
(
frozen
=
True
)
class
BaseFlexData
:
dtype
:
torch
.
dtype
|
None
=
None
def
view
(
self
,
x
:
torch
.
Tensor
):
if
self
.
dtype
is
None
:
return
x
return
x
.
view
(
self
.
dtype
)
def
reinterpret
(
self
,
x
):
if
self
.
dtype
is
None
or
x
.
dtype
.
itemsize
>
1
:
return
x
return
x
.
view
(
self
.
dtype
)
@
dataclass
(
frozen
=
True
)
class
InFlexData
(
BaseFlexData
):
scale
:
torch
.
Tensor
|
None
=
None
@
property
def
is_per_batch
(
self
):
return
False
if
self
.
scale
is
None
else
len
(
self
.
scale
)
>
1
@
dataclass
(
frozen
=
True
)
class
OutFlexData
(
BaseFlexData
):
expected_scale
:
torch
.
Tensor
|
None
=
None
actual_scale
:
torch
.
Tensor
|
None
=
None
checksum_scale
:
torch
.
Tensor
|
None
=
None
def
__iter__
(
self
):
yield
self
.
expected_scale
yield
self
.
actual_scale
yield
self
.
checksum_scale
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/flexpoint.py
deleted
100644 → 0
View file @
2b7160c6
from
..numerics
import
MAX_FINITE_FLOAT8E4B8
,
MAX_FINITE_FLOAT8E4NV
,
MAX_FINITE_FLOAT8E5
import
triton
import
triton.language
as
tl
from
compactor_vllm.triton_kernels.target_info
import
cuda_capability_geq
# -------------------------------
# Kernels stuff
# -------------------------------
TL_MAX_FINITE_FLOAT8E5
=
tl
.
constexpr
(
MAX_FINITE_FLOAT8E5
)
TL_MAX_FINITE_FLOAT8E4NV
=
tl
.
constexpr
(
MAX_FINITE_FLOAT8E4NV
)
TL_MAX_FINITE_FLOAT8E4B8
=
tl
.
constexpr
(
MAX_FINITE_FLOAT8E4B8
)
TL_MAX_FINITE_FLOAT8E4B15
=
tl
.
constexpr
(
1.750
)
TL_MAX_FINITE_FLOAT16
=
tl
.
constexpr
(
65472.0
)
TL_RCP_MAX_FINITE_FLOAT8E5
=
tl
.
constexpr
(
0x37924925
)
# 0x1.24924Ap-16
TL_RCP_MAX_FINITE_FLOAT8E4NV
=
tl
.
constexpr
(
0x3B124925
)
# 0x1.24924Ap-9
TL_RCP_MAX_FINITE_FLOAT8E4B8
=
tl
.
constexpr
(
0x3B888889
)
# 0x1.111112p-8
TL_RCP_MAX_FINITE_FLOAT8E4B15
=
tl
.
constexpr
(
0x3F124925
)
# 0x1.24924Ap-1
TL_RCP_MAX_FINITE_FLOAT16
=
tl
.
constexpr
(
0x37802008
)
# 0x1.004010p-16
@
triton
.
jit
def
max_finite
(
dtype
):
if
dtype
==
tl
.
constexpr
(
tl
.
float8e5
):
return
TL_MAX_FINITE_FLOAT8E5
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4nv
):
return
TL_MAX_FINITE_FLOAT8E4NV
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4b8
):
return
TL_MAX_FINITE_FLOAT8E4B8
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4b15
):
return
TL_MAX_FINITE_FLOAT8E4B15
elif
dtype
==
tl
.
constexpr
(
tl
.
float16
):
return
TL_MAX_FINITE_FLOAT16
else
:
tl
.
static_assert
(
tl
.
constexpr
(
False
),
f
"
{
dtype
}
not supported in flexpoint"
)
@
triton
.
jit
def
rcp_max_finite
(
dtype
):
if
dtype
==
tl
.
constexpr
(
tl
.
float8e5
):
return
TL_RCP_MAX_FINITE_FLOAT8E5
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4nv
):
return
TL_RCP_MAX_FINITE_FLOAT8E4NV
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4b8
):
return
TL_RCP_MAX_FINITE_FLOAT8E4B8
elif
dtype
==
tl
.
constexpr
(
tl
.
float8e4b15
):
return
TL_RCP_MAX_FINITE_FLOAT8E4B15
elif
dtype
==
tl
.
constexpr
(
tl
.
float16
):
return
TL_RCP_MAX_FINITE_FLOAT16
else
:
tl
.
static_assert
(
tl
.
constexpr
(
False
),
f
"
{
dtype
}
not supported in flexpoint"
)
@
triton
.
jit
def
sm86_min_nan_xorsign_abs_f32
(
a
,
b
):
"""Wrapper for min.NaN.xorsign.abs.f32 PTX instruction.
Computes the minimum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
NaN inputs are propagated to the output.
Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
"""
tl
.
static_assert
(
cuda_capability_geq
(
8
,
6
),
"min.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+"
,
)
tl
.
static_assert
(
a
.
dtype
==
tl
.
float32
,
"min.NaN.xorsign.abs.f32 requires float32 inputs"
)
tl
.
static_assert
(
b
.
dtype
==
tl
.
float32
,
"min.NaN.xorsign.abs.f32 requires float32 inputs"
)
return
tl
.
inline_asm_elementwise
(
"""{
min.NaN.xorsign.abs.f32 $0, $1, $2;
}"""
,
"=r,r,r"
,
[
a
,
b
],
dtype
=
tl
.
float32
,
is_pure
=
True
,
pack
=
1
,
)
@
triton
.
jit
def
sm86_max_nan_xorsign_abs_f32
(
a
,
b
):
"""Wrapper for max.NaN.xorsign.abs.f32 PTX instruction.
Computes the maximum of the absolute values of the two inputs and sets its sign to the XOR of the signs of the inputs.
NaN inputs are propagated to the output.
Requires CUDA compute capability 8.6+ (A100 and A30 Ampere GPUs don't support it, but A40/A16/A10/A2, Ada, and Hopper GPUs do).
"""
tl
.
static_assert
(
cuda_capability_geq
(
8
,
6
),
"max.NaN.xorsign.abs.f32 requires CUDA compute capability 8.6+"
,
)
tl
.
static_assert
(
a
.
dtype
==
tl
.
float32
,
"max.NaN.xorsign.abs.f32 requires float32 inputs"
)
tl
.
static_assert
(
b
.
dtype
==
tl
.
float32
,
"max.NaN.xorsign.abs.f32 requires float32 inputs"
)
return
tl
.
inline_asm_elementwise
(
"""{
max.NaN.xorsign.abs.f32 $0, $1, $2;
}"""
,
"=r,r,r"
,
[
a
,
b
],
dtype
=
tl
.
float32
,
is_pure
=
True
,
pack
=
1
,
)
@
triton
.
jit
def
load_scale
(
scale_ptr
):
return
1.0
if
scale_ptr
is
None
else
tl
.
load
(
scale_ptr
)
@
triton
.
jit
def
flex_to_float
(
x
,
scale_ptr
):
scale
=
load_scale
(
scale_ptr
)
return
x
.
to
(
tl
.
float32
)
*
scale
@
triton
.
jit
def
clip
(
x
,
limit
):
res
=
tl
.
minimum
(
x
,
limit
)
res
=
tl
.
maximum
(
-
limit
,
res
)
return
res
@
triton
.
jit
def
nan_propagating_absmax_reduce
(
x
,
axis
=
None
):
if
cuda_capability_geq
(
8
,
6
):
# abs-max-reduce as floating-point if `max.NaN.xorsign.abs.f32` is supported.
x_absmax
=
tl
.
reduce
(
x
,
axis
,
sm86_max_nan_xorsign_abs_f32
)
# Note: sign of reduction result is the xor of signs of all inputs, explicitly clear the sign bit to fix it.
x_absmax
=
x_absmax
.
to
(
tl
.
uint32
,
bitcast
=
True
)
&
0x7FFFFFFF
else
:
# Clear the sign bit, max-reduce as integer (same as NaN-propagating max-reduce as float)
masked_abs_x
=
x
.
to
(
tl
.
uint32
,
bitcast
=
True
)
&
0x7FFFFFFF
x_absmax
=
tl
.
max
(
masked_abs_x
,
axis
)
return
x_absmax
@
triton
.
jit
def
compute_scale
(
x
,
Out
):
x_absmax
=
nan_propagating_absmax_reduce
(
tl
.
ravel
(
x
,
can_reorder
=
True
))
# atomic_max does not propagate NaNs, so we replace them with +inf (0x7f800000).
# We use integer minimum because NaNs are above +inf in integer representation.
x_absmax
=
tl
.
minimum
(
x_absmax
,
0x7F800000
).
to
(
tl
.
float32
,
bitcast
=
True
)
RCP_MAX_VALUE
=
rcp_max_finite
(
Out
.
dtype
.
element_ty
)
return
tl
.
fma
(
x_absmax
,
RCP_MAX_VALUE
.
to
(
tl
.
float32
,
bitcast
=
True
),
1.0e-30
)
@
triton
.
jit
def
update_scale
(
x
,
scale_ptr
,
Out
)
->
None
:
if
scale_ptr
is
not
None
:
scale
=
compute_scale
(
x
,
Out
)
tl
.
atomic_max
(
scale_ptr
,
scale
,
sem
=
"relaxed"
)
@
triton
.
jit
def
float_to_flex
(
x
,
expected_scale_ptr_or_val
,
actual_scale_ptr
,
checksum_scale_ptr
,
mask
,
Out
,
saturate_infs
:
tl
.
constexpr
,
):
if
expected_scale_ptr_or_val
is
not
None
:
if
expected_scale_ptr_or_val
.
dtype
.
is_ptr
():
invscale
=
1.0
/
tl
.
load
(
expected_scale_ptr_or_val
)
else
:
invscale
=
1.0
/
expected_scale_ptr_or_val
else
:
invscale
=
1.0
if
checksum_scale_ptr
is
not
None
:
x_int32
=
x
.
to
(
tl
.
int32
,
bitcast
=
True
)
zero
=
tl
.
cast
(
0.0
,
tl
.
int32
)
if
mask
is
not
None
:
x_int32
=
tl
.
where
(
mask
,
x_int32
,
zero
)
checksum_local
=
tl
.
xor_sum
(
tl
.
ravel
(
x_int32
,
can_reorder
=
True
),
0
)
tl
.
atomic_add
(
checksum_scale_ptr
,
checksum_local
)
if
mask
is
not
None
:
if
actual_scale_ptr
is
not
None
:
x
=
tl
.
where
(
mask
,
x
,
0.0
)
update_scale
(
x
,
actual_scale_ptr
,
Out
)
x
=
x
*
invscale
# if expected_scale_ptr is not None, we applied flexpoint scale. We only want to clip in this case.
if
expected_scale_ptr_or_val
is
not
None
:
if
saturate_infs
:
CLIP_VALUE
=
max_finite
(
Out
.
dtype
.
element_ty
)
x
=
clip
(
x
,
CLIP_VALUE
)
return
x
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp.py
deleted
100644 → 0
View file @
2b7160c6
# isort: off
# fmt: off
from
enum
import
Enum
import
triton
import
torch
import
torch.nn.functional
as
F
from
.mxfp_details._upcast_from_mxfp
import
_upcast_from_mxfp
from
.mxfp_details._downcast_to_mxfp
import
_downcast_to_mxfp
,
MXFP_BLOCK_SIZE
,
_quantize_mxfp8_fn
# -----------------------------------------------------------------------------
# Dequantization / Quantization Utilities
# -----------------------------------------------------------------------------
class
DequantScaleRoundingMode
(
Enum
):
ROUND_UP
=
0
ROUND_DOWN
=
1
def
downcast_to_mxfp
(
src_tensor
:
torch
.
Tensor
,
out_quant_type
:
torch
.
dtype
,
axis
:
int
,
DEQUANT_SCALE_ROUNDING_MODE
:
DequantScaleRoundingMode
=
DequantScaleRoundingMode
.
ROUND_UP
):
"""
Convert the src weights to mx format. The src weight is quantized along the axis dimension.
If weight_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte.
Note that this means the k_dim of the tensor will be half of the logical k_dim.
If weight_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8s are stored
in their respective formats.
"""
ndim
=
src_tensor
.
ndim
assert
-
ndim
<=
axis
<
ndim
,
f
"Invalid axis
{
axis
=
}
"
axis
=
axis
if
axis
>=
0
else
axis
+
ndim
# downcast
src_tensor
=
src_tensor
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
)
is_fp4
=
out_quant_type
==
torch
.
uint8
is_fp8
=
out_quant_type
in
(
torch
.
float8_e4m3fn
,
torch
.
float8_e5m2
)
assert
is_fp4
or
is_fp8
divisor
=
2
if
is_fp4
else
1
L
=
src_tensor
.
shape
[
-
1
]
if
is_fp4
:
assert
L
%
2
==
0
,
f
"axis dim must be divisible by 2 for e2m1. Got
{
L
}
"
out_shape
=
src_tensor
.
shape
[:
-
1
]
+
(
L
//
divisor
,
)
out_scale_shape
=
src_tensor
.
shape
[:
-
1
]
+
(
triton
.
cdiv
(
L
,
MXFP_BLOCK_SIZE
),
)
out_quant_tensor
=
src_tensor
.
new_empty
(
out_shape
,
dtype
=
out_quant_type
)
out_scale
=
src_tensor
.
new_empty
(
out_scale_shape
,
dtype
=
torch
.
uint8
)
if
src_tensor
.
numel
()
>
0
:
kernel_src_tensor
=
src_tensor
.
reshape
(
-
1
,
src_tensor
.
shape
[
-
1
])
kernel_quant_tensor
=
out_quant_tensor
.
view
(
-
1
,
out_quant_tensor
.
shape
[
-
1
])
kernel_scale
=
out_scale
.
view
(
-
1
,
out_scale
.
shape
[
-
1
])
BLOCK_OUT_DIM
=
128
BLOCK_QUANT_DIM
=
MXFP_BLOCK_SIZE
.
value
grid_out
=
triton
.
cdiv
(
kernel_src_tensor
.
shape
[
0
],
BLOCK_OUT_DIM
)
grid_quant
=
triton
.
cdiv
(
kernel_src_tensor
.
shape
[
1
],
BLOCK_QUANT_DIM
)
_downcast_to_mxfp
[(
grid_out
,
grid_quant
)](
kernel_quant_tensor
,
*
kernel_quant_tensor
.
stride
(),
kernel_scale
,
*
kernel_scale
.
stride
(),
kernel_src_tensor
,
*
kernel_src_tensor
.
stride
(),
*
kernel_src_tensor
.
shape
,
BLOCK_OUT_DIM
,
BLOCK_QUANT_DIM
,
DEQUANT_SCALE_ROUNDING_MODE
.
value
,
num_warps
=
8
)
out_quant_tensor
=
out_quant_tensor
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
)
out_scale
=
out_scale
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
)
return
out_quant_tensor
,
out_scale
def
upcast_from_mxfp
(
tensor
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
target_dtype
:
torch
.
dtype
,
axis
:
int
):
"""
Upcasts an mxfp (packed) weight tensor back to float16 or bfloat16.
The function assumes that the tensors were quantized along the given axis.
It permutes the tensor so that the quantized axis is last, reshapes to 2D,
launches the Triton upcast kernel, and then unpermutes back to the original order.
"""
ndim
=
tensor
.
ndim
assert
-
ndim
<=
axis
<
ndim
,
f
"Invalid axis
{
axis
=
}
"
axis
=
axis
if
axis
>=
0
else
axis
+
ndim
assert
tensor
.
ndim
==
scale
.
ndim
,
(
f
"Weight and scale must have the same number of dimensions. "
f
"Got
{
tensor
.
ndim
=
}
and
{
scale
.
ndim
=
}
"
)
# dtype checks
assert
tensor
.
dtype
in
{
torch
.
uint8
,
torch
.
float8_e5m2
,
torch
.
float8_e4m3fn
},
\
f
"Invalid tensor dtype
{
tensor
.
dtype
=
}
"
assert
scale
.
dtype
==
torch
.
uint8
,
f
"Invalid scale dtype
{
scale
.
dtype
=
}
"
assert
target_dtype
in
(
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
),
f
"Invalid output dtype
{
target_dtype
=
}
"
# upcast
logical_quant_dim
=
tensor
.
shape
[
axis
]
*
(
2
if
tensor
.
dtype
==
torch
.
uint8
else
1
)
tensor
=
tensor
.
transpose
(
axis
,
tensor
.
ndim
-
1
).
contiguous
()
scale
=
scale
.
transpose
(
axis
,
scale
.
ndim
-
1
).
contiguous
()
out
=
torch
.
empty
((
*
tensor
.
shape
[:
-
1
],
logical_quant_dim
),
dtype
=
target_dtype
,
device
=
tensor
.
device
)
reshaped_out
=
out
.
view
(
-
1
,
out
.
shape
[
-
1
])
reshaped_tensor
=
tensor
.
view
(
-
1
,
tensor
.
shape
[
-
1
])
reshaped_scale
=
scale
.
view
(
-
1
,
scale
.
shape
[
-
1
])
BLOCK_OUT_DIM
=
128
BLOCK_QUANT_DIM
=
MXFP_BLOCK_SIZE
.
value
blocks_out_dim
=
triton
.
cdiv
(
reshaped_out
.
shape
[
0
],
BLOCK_OUT_DIM
)
blocks_quant_dim
=
triton
.
cdiv
(
reshaped_out
.
shape
[
1
],
BLOCK_QUANT_DIM
)
_upcast_from_mxfp
[(
blocks_out_dim
,
blocks_quant_dim
)](
reshaped_out
,
*
reshaped_out
.
stride
(),
reshaped_scale
,
*
reshaped_scale
.
stride
(),
reshaped_tensor
,
*
reshaped_tensor
.
stride
(),
*
reshaped_out
.
shape
,
BLOCK_OUT_DIM
,
BLOCK_QUANT_DIM
,
num_warps
=
8
)
out
=
out
.
transpose
(
axis
,
scale
.
ndim
-
1
).
contiguous
()
return
out
# ------------
def
right_shift_unsigned
(
x
,
shift
):
# CUDA torch does not support bit ops on uint32, so we need to mask to get unsigned right shift
return
(
x
>>
shift
)
&
((
1
<<
(
32
-
shift
))
-
1
)
def
get_max_quant_val
(
dtype
:
torch
.
dtype
):
d
=
{
torch
.
uint8
:
6.0
,
torch
.
float8_e5m2
:
57344.0
,
torch
.
float8_e4m3fn
:
448.0
}
assert
dtype
in
d
return
d
[
dtype
]
def
downcast_to_mxfp_torch
(
src_tensor
:
torch
.
Tensor
,
out_quant_type
:
torch
.
dtype
,
axis
:
int
,
DEQUANT_SCALE_ROUNDING_MODE
:
DequantScaleRoundingMode
=
DequantScaleRoundingMode
.
ROUND_UP
):
"""
Converts the src tensor to the output format specified by out_quant_type.
axis: The axis along which the tensors are contiguous and quantization is applied.
DEQUANT_SCALE_ROUNDING_MODE: 0 for ROUND_UP, 1 for ROUND_DOWN.
Returns:
out_quant_tensor: Quantized tensor in mx format.
• For mxfp8, the output has the same shape as src_tensor.
• For mxfp4, the size along the axis is halved, and the tensor is returned as a torch.uint8.
scale: Scale tensor (stored as uint8) computed per group of 32 elements along the axis.
Its shape is the same as src_tensor except that the axis is replaced by ceil(L/32),
where L is the original length along that axis.
"""
# This should probably be packed into its own tiny class
ndim
=
src_tensor
.
ndim
assert
-
ndim
<=
axis
<
ndim
,
f
"Invalid axis
{
axis
=
}
"
assert
src_tensor
.
dtype
in
{
torch
.
float32
,
torch
.
bfloat16
,
torch
.
float16
},
f
"Invalid input tensor dtype
{
src_tensor
.
dtype
}
"
axis
=
axis
if
axis
>=
0
else
axis
+
ndim
is_fp4
=
out_quant_type
==
torch
.
uint8
is_fp8
=
"float8"
in
str
(
out_quant_type
)
assert
is_fp4
or
is_fp8
,
f
"Invalid input tensor dtype
{
out_quant_type
}
"
device
=
src_tensor
.
device
# For mxfp4 conversion, we assume the contiguous axis length is even.
if
is_fp4
:
axis_shape
=
src_tensor
.
size
(
axis
)
assert
axis_shape
%
2
==
0
,
"For mxfp4 conversion the contiguous axis length must be even."
# Permute the tensor so that the contiguous axis becomes the last dimension.
src
=
src_tensor
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
).
to
(
torch
.
float32
)
axis_shape
=
src
.
shape
[
-
1
]
# Pad the axis to be divisible by 32, in case it is not.
next_multiple
=
triton
.
cdiv
(
axis_shape
,
MXFP_BLOCK_SIZE
)
*
MXFP_BLOCK_SIZE
pad_amount
=
next_multiple
-
axis_shape
padded_src
=
F
.
pad
(
src
,
(
0
,
pad_amount
))
valid_mask
=
F
.
pad
(
torch
.
ones_like
(
src
,
dtype
=
torch
.
bool
),
(
0
,
pad_amount
))
padded_axis_shape
=
padded_src
.
size
(
-
1
)
# now divisible by 32
# --- Compute per-group maximums for scale ---
# Set padded entries to -1 so they don’t affect the max.
abs_f
=
torch
.
abs
(
padded_src
)
abs_f
=
torch
.
where
(
valid_mask
,
abs_f
,
torch
.
tensor
(
-
1.0
,
device
=
device
,
dtype
=
padded_src
.
dtype
))
# Reshape the last dimension into groups of 32.
new_shape
=
padded_src
.
shape
[:
-
1
]
+
(
padded_axis_shape
//
MXFP_BLOCK_SIZE
,
MXFP_BLOCK_SIZE
)
abs_groups
=
abs_f
.
view
(
*
new_shape
)
# Compute maximum along the group dimension (of size 32).
max_val
,
_
=
abs_groups
.
max
(
dim
=-
1
,
keepdim
=
True
)
# Choose a max quantization value depending on type.
max_quant_val
=
get_max_quant_val
(
out_quant_type
)
dequant_scale
=
max_val
/
max_quant_val
# shape: (..., padded_axis_shape//32, 1)
# Convert to int to round the FP32 scale, prior to quantization!
ds_int
=
dequant_scale
.
view
(
torch
.
int32
)
if
DEQUANT_SCALE_ROUNDING_MODE
==
DequantScaleRoundingMode
.
ROUND_UP
:
ds_int_rounded
=
(
ds_int
+
0x007FFFFF
)
&
0x7F800000
else
:
ds_int_rounded
=
ds_int
&
0x7F800000
# Reinterpret back as float32.
dequant_scale_rounded
=
ds_int_rounded
.
view
(
torch
.
float32
)
# Compute the quantization scale.
quant_scale
=
torch
.
where
(
dequant_scale_rounded
==
0
,
torch
.
tensor
(
0.0
,
device
=
device
),
1.0
/
dequant_scale_rounded
)
# Quantize the tensor
orig_padded_shape
=
padded_src
.
shape
padded_src_groups
=
padded_src
.
view
(
*
new_shape
)
quant_tensor
=
padded_src_groups
*
quant_scale
# Reshape back to the original shape and trim padding
quant_tensor
=
quant_tensor
.
view
(
orig_padded_shape
)
quant_tensor
=
quant_tensor
[...,
:
axis_shape
]
# Finally, convert the quantized tensor to the target format
if
is_fp8
:
# Conversion must use satfinite PTX, so clamp before the conversion in torch to emulate this behavior
quant_tensor
=
torch
.
clamp
(
quant_tensor
,
-
max_quant_val
,
max_quant_val
)
out_weight
=
quant_tensor
.
to
(
out_quant_type
)
else
:
assert
is_fp4
,
f
"Invalid output quantization type
{
out_quant_type
}
"
# For mxfp4, perform bit-level manipulation and pack two 4-bit values per uint8.
# First, reinterpret the quantized tensor bits.
q_int
=
quant_tensor
.
contiguous
().
view
(
torch
.
int32
)
# Extract sign, exponent, and mantissa.
signs
=
q_int
&
0x80000000
exponents
=
right_shift_unsigned
(
q_int
,
23
)
&
0xFF
mantissas
=
q_int
&
0x7FFFFF
E8_BIAS
=
127
E2_BIAS
=
1
# Adjust mantissas for subnormals.
mantissas
=
torch
.
where
(
exponents
<
E8_BIAS
,
(
0x400000
|
right_shift_unsigned
(
mantissas
,
1
))
>>
(
E8_BIAS
-
exponents
-
1
),
mantissas
)
exponents
=
torch
.
maximum
(
exponents
,
torch
.
tensor
(
E8_BIAS
-
E2_BIAS
,
device
=
device
))
-
(
E8_BIAS
-
E2_BIAS
)
e2m1_tmp
=
right_shift_unsigned
(((
exponents
<<
2
)
|
right_shift_unsigned
(
mantissas
,
21
))
+
1
,
1
)
e2m1_tmp
=
torch
.
minimum
(
e2m1_tmp
,
torch
.
tensor
(
0x7
,
device
=
device
))
e2m1_value
=
(
right_shift_unsigned
(
signs
,
28
)
|
e2m1_tmp
).
to
(
torch
.
uint8
)
# shape: (..., even_axis_shape)
# Pack pairs of 4-bit values along the last dimension.
e2m1_value
=
e2m1_value
.
view
(
*
e2m1_value
.
shape
[:
-
1
],
axis_shape
//
2
,
2
)
evens
=
e2m1_value
[...,
0
]
odds
=
e2m1_value
[...,
1
]
out_weight
=
evens
|
(
odds
<<
4
)
# shape: (..., axis_shape//2)
# --- Process and output the scale ---
dq_scale
=
(
ds_int_rounded
.
view
(
*
dequant_scale
.
shape
)
>>
23
).
to
(
torch
.
uint8
)
# shape: (..., axis_shape//32, 1)
dq_scale
=
dq_scale
.
squeeze
(
-
1
)
out_weight
=
out_weight
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
)
dq_scale
=
dq_scale
.
transpose
(
axis
,
src_tensor
.
ndim
-
1
)
return
out_weight
,
dq_scale
def
cvt_e2m1_to_fp32
(
input_tensor
):
assert
input_tensor
.
dtype
==
torch
.
uint8
input_tensor
=
input_tensor
.
to
(
torch
.
int32
)
evens
=
input_tensor
&
0xF
odds
=
(
input_tensor
>>
4
)
&
0xF
vals
=
[
0.0
,
0.5
,
1
,
1.5
,
2
,
3
,
4
,
6
]
outputs
=
torch
.
tensor
(
vals
,
dtype
=
torch
.
float32
,
device
=
input_tensor
.
device
)
outputs
=
torch
.
cat
([
outputs
,
-
outputs
])
even_floats
=
outputs
[
evens
]
odd_floats
=
outputs
[
odds
]
output_tensor
=
torch
.
stack
([
even_floats
,
odd_floats
],
dim
=-
1
)
output_tensor
=
output_tensor
.
view
(
*
input_tensor
.
shape
[:
-
1
],
-
1
)
return
output_tensor
def
upcast_from_mxfp_torch
(
tensor
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
,
target_dtype
:
torch
.
dtype
,
axis
:
int
):
"""
Converts the mxfp4/mxfp8 tensor to the target format specified by target_dtype.
axis: The axis along which dequantization is applied.
Returns:
out_weight: Tensor in the target format.
"""
ndim
=
tensor
.
ndim
assert
-
ndim
<=
axis
<
ndim
,
f
"Invalid axis
{
axis
=
}
"
is_fp8
=
tensor
.
dtype
==
torch
.
float8_e4m3fn
or
tensor
.
dtype
==
torch
.
float8_e5m2
assert
is_fp8
or
tensor
.
dtype
==
torch
.
uint8
,
f
"Invalid input quantization type
{
tensor
.
dtype
}
"
# Permute the tensor and scale so that the quantization axis becomes the last dimension
axis
=
axis
if
axis
>=
0
else
axis
+
ndim
scale
=
scale
.
transpose
(
axis
,
scale
.
ndim
-
1
)
tensor
=
tensor
.
transpose
(
axis
,
tensor
.
ndim
-
1
)
dq_scale
=
(
scale
.
to
(
torch
.
int32
)
<<
23
).
view
(
torch
.
float32
)
# Shift to the exponent and bitcast to fp32
if
tensor
.
dtype
==
torch
.
uint8
:
fp32_tensor
=
cvt_e2m1_to_fp32
(
tensor
)
else
:
fp32_tensor
=
tensor
.
to
(
torch
.
float32
)
logical_quant_dim
=
tensor
.
shape
[
-
1
]
*
(
2
if
tensor
.
dtype
==
torch
.
uint8
else
1
)
axis_shape
=
fp32_tensor
.
size
(
-
1
)
padded_axis_shape
=
triton
.
cdiv
(
logical_quant_dim
,
MXFP_BLOCK_SIZE
)
*
MXFP_BLOCK_SIZE
pad_size
=
padded_axis_shape
-
axis_shape
padded_tensor
=
F
.
pad
(
fp32_tensor
,
(
0
,
pad_size
))
new_axis_shape
=
padded_tensor
.
shape
[
-
1
]
new_shape
=
padded_tensor
.
shape
[:
-
1
]
+
(
new_axis_shape
//
MXFP_BLOCK_SIZE
,
MXFP_BLOCK_SIZE
)
padded_tensor
=
padded_tensor
.
view
(
*
new_shape
)
dq_scale_padded
=
dq_scale
.
unsqueeze
(
-
1
)
# shape: [..., ceil(axis_shape/32), 1]
out_padded
=
padded_tensor
*
dq_scale_padded
# Flatten back and remove the padded tail
out_padded
=
out_padded
.
view
(
*
fp32_tensor
.
shape
[:
-
1
],
new_axis_shape
)
out_tensor
=
out_padded
[...,
:
axis_shape
]
out_tensor
=
out_tensor
.
to
(
target_dtype
).
contiguous
()
out_tensor
=
out_tensor
.
transpose
(
axis
,
tensor
.
ndim
-
1
)
return
out_tensor
quantize_mxfp8_fn
=
_quantize_mxfp8_fn
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/__init__.py
deleted
100644 → 0
View file @
2b7160c6
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
deleted
100644 → 0
View file @
2b7160c6
import
triton
import
triton.language
as
tl
# fmt: off
MXFP_BLOCK_SIZE
=
tl
.
constexpr
(
32
)
@
triton
.
jit
def
_get_max_quant_val
(
dtype
:
tl
.
constexpr
):
if
dtype
==
tl
.
uint8
:
return
6.0
elif
dtype
==
tl
.
float8e5
:
return
57344.0
elif
dtype
==
tl
.
float8e4nv
:
return
448.0
else
:
tl
.
static_assert
(
False
,
f
"Invalid
{
dtype
=
}
"
)
@
triton
.
jit
def
_compute_quant_and_scale
(
src_tensor
,
valid_src_mask
,
mx_tensor_dtype
:
tl
.
constexpr
,
DEQUANT_SCALE_ROUNDING_MODE
:
tl
.
constexpr
=
0
):
is_fp8
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
BLOCK_SIZE_OUT_DIM
:
tl
.
constexpr
=
src_tensor
.
shape
[
0
]
BLOCK_SIZE_QUANT_DIM
:
tl
.
constexpr
=
src_tensor
.
shape
[
1
]
BLOCK_SIZE_QUANT_MX_SCALE
:
tl
.
constexpr
=
src_tensor
.
shape
[
1
]
//
MXFP_BLOCK_SIZE
# Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
f32_tensor
=
src_tensor
.
to
(
tl
.
float32
)
abs_tensor
=
tl
.
abs
(
f32_tensor
)
abs_tensor
=
tl
.
where
(
valid_src_mask
,
abs_tensor
,
-
1.0
)
# Don't consider padding tensors in scale computation
abs_tensor
=
tl
.
reshape
(
abs_tensor
,
[
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
MXFP_BLOCK_SIZE
])
max_val
=
tl
.
max
(
abs_tensor
,
axis
=
2
,
keep_dims
=
True
)
dequant_scale
=
max_val
/
_get_max_quant_val
(
mx_tensor_dtype
)
if
DEQUANT_SCALE_ROUNDING_MODE
==
0
:
# DequantScaleRoundingMode.ROUND_UP
# compute 2 ** ceil(log2(dequant_scale))
# Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros
# A corner case: exponent is 0xFF that will overflow but that's already
# NaN so assume we don't care.
dequant_scale_exponent
=
(
dequant_scale
.
to
(
tl
.
uint32
,
bitcast
=
True
)
+
0x007FFFFF
)
&
0x7F800000
else
:
# DequantScaleRoundingMode.ROUND_DOWN
# compute 2 ** floor(log2(dequant_scale))
assert
DEQUANT_SCALE_ROUNDING_MODE
==
1
dequant_scale_exponent
=
dequant_scale
.
to
(
tl
.
uint32
,
bitcast
=
True
)
&
0x7F800000
dequant_scale_rounded
=
dequant_scale_exponent
.
to
(
tl
.
float32
,
bitcast
=
True
)
quant_scale
=
tl
.
where
(
dequant_scale_rounded
==
0
,
0
,
1.0
/
dequant_scale_rounded
)
f32_tensor
=
tl
.
reshape
(
f32_tensor
,
[
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
MXFP_BLOCK_SIZE
])
quant_tensor
=
f32_tensor
*
quant_scale
# Reshape the tensors after scaling
quant_tensor
=
quant_tensor
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_DIM
])
# Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format.
quant_tensor
=
tl
.
where
(
valid_src_mask
,
quant_tensor
,
0
)
dequant_scale_exponent
=
dequant_scale_exponent
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
])
# First, we simply extract the exponent part of the scales and store the result
dequant_scale_exponent
=
(
dequant_scale_exponent
>>
23
).
to
(
tl
.
uint8
)
# Now we must convert the tensors to the mx format.
if
is_fp8
:
out_tensor
=
quant_tensor
.
to
(
mx_tensor_dtype
)
else
:
quant_tensor
=
quant_tensor
.
to
(
tl
.
uint32
,
bitcast
=
True
)
signs
=
quant_tensor
&
0x80000000
exponents
=
(
quant_tensor
>>
23
)
&
0xFF
mantissas
=
(
quant_tensor
&
0x7FFFFF
)
# 0.25 <= x < 0.75 maps to 0.5, a denormal number
E8_BIAS
=
127
E2_BIAS
=
1
# Move implicit bit 1 at the beginning to mantissa for denormals
adjusted_exponents
=
tl
.
core
.
sub
(
E8_BIAS
,
exponents
+
1
,
sanitize_overflow
=
False
)
mantissas
=
tl
.
where
(
exponents
<
E8_BIAS
,
(
0x400000
|
(
mantissas
>>
1
))
>>
adjusted_exponents
,
mantissas
)
# For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
exponents
=
tl
.
maximum
(
exponents
,
E8_BIAS
-
E2_BIAS
)
-
(
E8_BIAS
-
E2_BIAS
)
# Combine sign, exponent, and mantissa, while saturating
# rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
e2m1_tmp
=
tl
.
minimum
((((
exponents
<<
2
)
|
(
mantissas
>>
21
))
+
1
)
>>
1
,
0x7
)
e2m1_value
=
((
signs
>>
28
)
|
e2m1_tmp
).
to
(
tl
.
uint8
)
e2m1_value
=
tl
.
reshape
(
e2m1_value
,
[
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_DIM
//
2
,
2
])
evens
,
odds
=
tl
.
split
(
e2m1_value
)
out_tensor
=
evens
|
(
odds
<<
4
)
return
out_tensor
,
dequant_scale_exponent
@
triton
.
jit
def
_downcast_to_mxfp
(
mx_tensor_ptr
,
stride_mxt_outer
,
stride_mxt_quant
:
tl
.
constexpr
,
mx_scale_ptr
,
stride_mx_scale_outer
,
stride_mx_scale_quant
,
src_ptr
,
stride_src_outer
,
stride_src_quant
,
outer_dim
,
quant_dim
,
BLOCK_SIZE_OUT_DIM
:
tl
.
constexpr
,
BLOCK_SIZE_QUANT_DIM
:
tl
.
constexpr
,
DEQUANT_SCALE_ROUNDING_MODE
:
tl
.
constexpr
):
tl
.
static_assert
(
stride_mxt_quant
==
1
,
f
"Output stride,
{
stride_mxt_quant
=
}
must be 1."
)
tl
.
static_assert
(
BLOCK_SIZE_QUANT_DIM
%
MXFP_BLOCK_SIZE
==
0
,
f
"
{
BLOCK_SIZE_QUANT_DIM
=
}
must be a multiple of 32"
)
# uint8 signifies two fp4 e2m1 values packed into a single byte
mx_tensor_dtype
:
tl
.
constexpr
=
mx_tensor_ptr
.
dtype
.
element_ty
tl
.
static_assert
(
mx_tensor_dtype
==
tl
.
uint8
or
(
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
),
f
"Invalid
{
mx_tensor_dtype
=
}
. Must be uint8 or float8."
)
src_dtype
:
tl
.
constexpr
=
src_ptr
.
dtype
.
element_ty
tl
.
static_assert
(
mx_scale_ptr
.
dtype
.
element_ty
==
tl
.
uint8
,
f
"
{
mx_scale_ptr
.
dtype
.
element_ty
=
}
must be uint8"
)
tl
.
static_assert
((
src_dtype
==
tl
.
bfloat16
)
or
(
src_dtype
==
tl
.
float16
)
or
(
src_dtype
==
tl
.
float32
),
f
"
{
src_dtype
=
}
must be bfloat16 or float16 or float32"
)
is_fp4
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
uint8
outer_block
=
tl
.
program_id
(
0
).
to
(
tl
.
int64
)
quant_block
=
tl
.
program_id
(
1
).
to
(
tl
.
int64
)
K_DIVISOR
:
tl
.
constexpr
=
2
if
is_fp4
else
1
BLOCK_SIZE_QUANT_MX_SCALE
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
MXFP_BLOCK_SIZE
BLOCK_SIZE_QUANT_MX_TENSOR
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
K_DIVISOR
start_src_quant
=
quant_block
*
BLOCK_SIZE_QUANT_DIM
start_mx_scale_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_SCALE
start_mx_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_TENSOR
start_out
=
outer_block
*
BLOCK_SIZE_OUT_DIM
src_ptr
+=
start_src_quant
*
stride_src_quant
+
start_out
*
stride_src_outer
mx_scale_ptr
+=
start_mx_scale_quant
*
stride_mx_scale_quant
+
start_out
*
stride_mx_scale_outer
mx_tensor_ptr
+=
start_mx_quant
*
stride_mxt_quant
+
start_out
*
stride_mxt_outer
offs_src_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_DIM
)[
None
,
:].
to
(
tl
.
int64
)
offs_mxt_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_TENSOR
)[
None
,
:].
to
(
tl
.
int64
)
offs_scale_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_SCALE
)[
None
,
:].
to
(
tl
.
int64
)
offs_outer
=
tl
.
arange
(
0
,
BLOCK_SIZE_OUT_DIM
)[:,
None
].
to
(
tl
.
int64
)
mask_src_quant
=
start_src_quant
+
offs_src_quant
<
quant_dim
mask_n
=
start_out
+
offs_outer
<
outer_dim
full_mask_src
=
mask_src_quant
&
mask_n
mask_mxt_quant
=
start_mx_quant
+
offs_mxt_quant
<
tl
.
cdiv
(
quant_dim
,
K_DIVISOR
)
full_mask_mxt
=
mask_mxt_quant
&
mask_n
scale_mask_k
=
start_mx_scale_quant
+
offs_scale_quant
<
tl
.
cdiv
(
quant_dim
,
MXFP_BLOCK_SIZE
)
full_scale_mask
=
scale_mask_k
&
mask_n
src_tensor_offsets
=
offs_src_quant
*
stride_src_quant
+
offs_outer
*
stride_src_outer
mx_scale_offsets
=
offs_scale_quant
*
stride_mx_scale_quant
+
offs_outer
*
stride_mx_scale_outer
mx_tensor_offsets
=
offs_mxt_quant
*
stride_mxt_quant
+
offs_outer
*
stride_mxt_outer
src_tensor
=
tl
.
load
(
src_ptr
+
src_tensor_offsets
,
mask
=
full_mask_src
)
out_tensor
,
scale_tensor
=
_compute_quant_and_scale
(
src_tensor
,
full_mask_src
,
mx_tensor_dtype
,
DEQUANT_SCALE_ROUNDING_MODE
)
tl
.
store
(
mx_scale_ptr
+
mx_scale_offsets
,
scale_tensor
,
mask
=
full_scale_mask
)
tl
.
store
(
mx_tensor_ptr
+
mx_tensor_offsets
,
out_tensor
,
mask
=
full_mask_mxt
)
@
triton
.
jit
(
repr
=
lambda
_
:
"_dequantize_mxfp8"
)
def
_quantize_mxfp8_fn
(
input
,
mask
,
pid
=
None
):
return
_compute_quant_and_scale
(
input
,
mask
,
tl
.
float8e4nv
)
vllm/compactor-vllm/src/compactor_vllm/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py
deleted
100644 → 0
View file @
2b7160c6
import
triton
import
triton.language
as
tl
from
._downcast_to_mxfp
import
MXFP_BLOCK_SIZE
# fmt: off
@
triton
.
jit
def
_upcast_from_mxfp
(
out_ptr
,
stride_o_outer
,
stride_o_quant
:
tl
.
constexpr
,
mx_scale_ptr
,
stride_scale_outer
,
stride_scale_quant
,
mx_tensor_ptr
,
stride_tensor_outer
,
stride_tensor_quant
:
tl
.
constexpr
,
outer_dim
,
quant_dim
,
BLOCK_SIZE_OUT_DIM
:
tl
.
constexpr
,
BLOCK_SIZE_QUANT_DIM
:
tl
.
constexpr
):
tl
.
static_assert
(
stride_o_quant
==
1
,
"the weight must be contiguous in the k dimension for mx"
)
tl
.
static_assert
(
BLOCK_SIZE_QUANT_DIM
%
MXFP_BLOCK_SIZE
==
0
,
"BLOCK_SIZE_K must be a multiple of 32"
)
# uint8 signifies two fp4 e2m1 values packed into a single byte
mx_tensor_dtype
:
tl
.
constexpr
=
mx_tensor_ptr
.
dtype
.
element_ty
dst_dtype
:
tl
.
constexpr
=
out_ptr
.
dtype
.
element_ty
tl
.
static_assert
(
dst_dtype
==
tl
.
float16
or
dst_dtype
==
tl
.
bfloat16
or
dst_dtype
==
tl
.
float32
)
tl
.
static_assert
(
mx_tensor_dtype
==
tl
.
uint8
or
((
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
)
or
mx_tensor_dtype
==
dst_dtype
),
"mx_tensor_ptr must be uint8 or float8 or dst_dtype"
)
tl
.
static_assert
(
mx_scale_ptr
.
dtype
.
element_ty
==
tl
.
uint8
,
"mx_scale_ptr must be uint8"
)
# Determine if we are dealing with fp8 types.
is_fp4
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
uint8
is_fp8
:
tl
.
constexpr
=
mx_tensor_dtype
==
tl
.
float8e4nv
or
mx_tensor_dtype
==
tl
.
float8e5
K_DIVISOR
:
tl
.
constexpr
=
2
if
is_fp4
else
1
BLOCK_SIZE_QUANT_MX_SCALE
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
MXFP_BLOCK_SIZE
BLOCK_SIZE_QUANT_MX_TENSOR
:
tl
.
constexpr
=
BLOCK_SIZE_QUANT_DIM
//
K_DIVISOR
# Compute starting indices for the quantized (packed) dimension and the outer dimension.
outer_block
=
tl
.
program_id
(
0
).
to
(
tl
.
int64
)
quant_block
=
tl
.
program_id
(
1
).
to
(
tl
.
int64
)
start_mxt_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_TENSOR
start_out_quant
=
quant_block
*
BLOCK_SIZE_QUANT_DIM
start_mx_scale_quant
=
quant_block
*
BLOCK_SIZE_QUANT_MX_SCALE
start_out
=
outer_block
*
BLOCK_SIZE_OUT_DIM
mx_tensor_ptr
+=
start_mxt_quant
*
stride_tensor_quant
+
start_out
*
stride_tensor_outer
mx_scale_ptr
+=
start_mx_scale_quant
*
stride_scale_quant
+
start_out
*
stride_scale_outer
out_ptr
+=
start_out
*
stride_o_outer
+
start_out_quant
*
stride_o_quant
# Compute offsets and masks.
offs_src_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_TENSOR
)[
None
,
:].
to
(
tl
.
int64
)
offs_out_quant
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_DIM
)[
None
,
:].
to
(
tl
.
int64
)
offs_outer
=
tl
.
arange
(
0
,
BLOCK_SIZE_OUT_DIM
)[:,
None
].
to
(
tl
.
int64
)
offs_scale
=
tl
.
arange
(
0
,
BLOCK_SIZE_QUANT_MX_SCALE
)[
None
,
:].
to
(
tl
.
int64
)
mask_outer
=
start_out
+
offs_outer
<
outer_dim
mask_out_quant
=
start_out_quant
+
offs_out_quant
<
quant_dim
full_mask_out
=
mask_out_quant
&
mask_outer
mask_src_quant
=
start_mxt_quant
+
offs_src_quant
<
tl
.
cdiv
(
quant_dim
,
K_DIVISOR
)
full_mask_src
=
mask_src_quant
&
mask_outer
mask_scale
=
start_mx_scale_quant
+
offs_scale
<
tl
.
cdiv
(
quant_dim
,
MXFP_BLOCK_SIZE
)
full_scale_mask
=
mask_scale
&
mask_outer
tensor_offsets
=
offs_src_quant
*
stride_tensor_quant
+
offs_outer
*
stride_tensor_outer
scale_offsets
=
offs_scale
*
stride_scale_quant
+
offs_outer
*
stride_scale_outer
out_offsets
=
offs_out_quant
*
stride_o_quant
+
offs_outer
*
stride_o_outer
# Load the packed tensor and scale.
tensor
=
tl
.
load
(
mx_tensor_ptr
+
tensor_offsets
,
mask
=
full_mask_src
)
scale
=
tl
.
load
(
mx_scale_ptr
+
scale_offsets
,
mask
=
full_scale_mask
)
# Upcast the scale to the destination type.
if
dst_dtype
==
tl
.
bfloat16
:
dst_scale
=
(
scale
.
to
(
tl
.
uint16
)
<<
7
).
to
(
dst_dtype
,
bitcast
=
True
)
else
:
dst_scale
=
(
scale
.
to
(
tl
.
uint32
)
<<
23
).
to
(
tl
.
float32
,
bitcast
=
True
)
if
dst_dtype
==
tl
.
float16
:
dst_scale
=
dst_scale
.
to
(
tl
.
float16
)
# Now upcast the tensor.
intermediate_dtype
:
tl
.
constexpr
=
tl
.
bfloat16
if
dst_dtype
==
tl
.
float32
else
dst_dtype
if
is_fp8
:
dst_tensor
=
tensor
.
to
(
intermediate_dtype
)
if
tensor
.
dtype
==
tl
.
float8e5
:
from_e_bits
:
tl
.
constexpr
=
5
from_m_bits
:
tl
.
constexpr
=
2
to_e_bits
:
tl
.
constexpr
=
8
if
intermediate_dtype
==
tl
.
bfloat16
else
5
to_m_bits
:
tl
.
constexpr
=
7
if
intermediate_dtype
==
tl
.
bfloat16
else
10
# Preserve infs and nans. FIXME Fp8E5M2_to_Bf16 doesn't preserve them!
non_finite_mask_src
:
tl
.
constexpr
=
((
1
<<
from_e_bits
)
-
1
)
<<
from_m_bits
non_finite_mask_dst
:
tl
.
constexpr
=
((
1
<<
to_e_bits
)
-
1
)
<<
to_m_bits
dst_tensor
=
tl
.
where
(
(
tensor
.
to
(
tl
.
uint8
,
bitcast
=
True
)
&
non_finite_mask_src
)
==
non_finite_mask_src
,
(
dst_tensor
.
to
(
tl
.
uint16
,
bitcast
=
True
)
|
non_finite_mask_dst
).
to
(
intermediate_dtype
,
bitcast
=
True
),
dst_tensor
,
)
else
:
assert
is_fp4
dst_bias
:
tl
.
constexpr
=
127
if
intermediate_dtype
==
tl
.
bfloat16
else
15
dst_0p5
:
tl
.
constexpr
=
16128
if
intermediate_dtype
==
tl
.
bfloat16
else
0x3800
dst_m_bits
:
tl
.
constexpr
=
7
if
intermediate_dtype
==
tl
.
bfloat16
else
10
# e2m1
em0
=
tensor
&
0x07
em1
=
tensor
&
0x70
x0
=
(
em0
.
to
(
tl
.
uint16
)
<<
(
dst_m_bits
-
1
))
|
((
tensor
&
0x08
).
to
(
tl
.
uint16
)
<<
12
)
x1
=
(
em1
.
to
(
tl
.
uint16
)
<<
(
dst_m_bits
-
5
))
|
((
tensor
&
0x80
).
to
(
tl
.
uint16
)
<<
8
)
# Three cases:
# 1) x is normal and non-zero: Correct bias
x0
=
tl
.
where
((
em0
&
0x06
)
!=
0
,
x0
+
((
dst_bias
-
1
)
<<
dst_m_bits
),
x0
)
x1
=
tl
.
where
((
em1
&
0x60
)
!=
0
,
x1
+
((
dst_bias
-
1
)
<<
dst_m_bits
),
x1
)
# 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type
x0
=
tl
.
where
(
em0
==
0x01
,
dst_0p5
|
(
x0
&
0x8000
),
x0
)
x1
=
tl
.
where
(
em1
==
0x10
,
dst_0p5
|
(
x1
&
0x8000
),
x1
)
# 3) x is zero, do nothing
dst_tensor
=
tl
.
interleave
(
x0
,
x1
).
to
(
intermediate_dtype
,
bitcast
=
True
)
dst_tensor
=
dst_tensor
.
to
(
dst_dtype
)
# Reshape for proper broadcasting: the scale was stored with a 32‐sized “inner” grouping.
dst_tensor
=
dst_tensor
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
MXFP_BLOCK_SIZE
])
dst_scale
=
dst_scale
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_MX_SCALE
,
1
])
scale
=
scale
.
reshape
(
dst_scale
.
shape
)
out_tensor
=
dst_tensor
*
dst_scale
# Correct any NaNs encoded via the scale.
out_tensor
=
tl
.
where
(
scale
==
0xFF
,
float
(
"nan"
),
out_tensor
)
out_tensor
=
out_tensor
.
reshape
([
BLOCK_SIZE_OUT_DIM
,
BLOCK_SIZE_QUANT_DIM
])
tl
.
store
(
out_ptr
+
out_offsets
,
out_tensor
,
mask
=
full_mask_out
)
Prev
1
2
3
4
5
6
7
8
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment