Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
06d49028
Unverified
Commit
06d49028
authored
Dec 21, 2025
by
Michael Goin
Committed by
GitHub
Dec 21, 2025
Browse files
[NVFP4][Perf] Tune NVFP4 input quant kernel for small batch size (#30897)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
b471092d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
243 additions
and
97 deletions
+243
-97
benchmarks/kernels/bench_nvfp4_quant.py
benchmarks/kernels/bench_nvfp4_quant.py
+177
-0
csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
...quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
+4
-1
csrc/quantization/fp4/nvfp4_experts_quant.cu
csrc/quantization/fp4/nvfp4_experts_quant.cu
+14
-17
csrc/quantization/fp4/nvfp4_quant_kernels.cu
csrc/quantization/fp4/nvfp4_quant_kernels.cu
+20
-42
csrc/quantization/fp4/nvfp4_utils.cuh
csrc/quantization/fp4/nvfp4_utils.cuh
+28
-37
No files found.
benchmarks/kernels/bench_nvfp4_quant.py
0 → 100644
View file @
06d49028
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
copy
import
itertools
import
torch
from
weight_shapes
import
WEIGHT_SHAPES
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.triton_utils
import
triton
from
vllm.utils.flashinfer
import
flashinfer_fp4_quantize
# Refuse to run on devices that do not report capability 100; the error text
# below states the requirement (compute capability 10.0, Blackwell).
if not current_platform.has_device_capability(100):
    raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")

# Maximum values of the FP4 (e2m1) and FP8 (e4m3) types; both feed into
# compute_global_scale().
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

# Benchmark providers keyed by the name shown in the report. Flip "enabled"
# to False to leave a provider out of the run.
PROVIDER_CFGS = {
    "vllm": {"backend": "vllm", "enabled": True},
    "flashinfer": {"backend": "flashinfer", "enabled": True},
}

# Provider names that are switched on, in declaration order.
_enabled = [name for name, cfg in PROVIDER_CFGS.items() if cfg["enabled"]]
def compute_global_scale(tensor: torch.Tensor) -> torch.Tensor:
    """Return the global scale used for FP4 quantization of ``tensor``.

    The result is ``FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax``, where ``amax``
    is the largest absolute value found in ``tensor``, converted to float32.
    """
    abs_max = tensor.abs().max().to(torch.float32)
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / abs_max
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="us (lower is better)",
        plot_name="NVFP4 Input Quantization Latency (us)",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Measure NVFP4 input-quantization latency for one provider.

    Args:
        batch_size: Number of rows (M) of the random bfloat16 activation.
        provider: Key into ``PROVIDER_CFGS`` selecting the backend to time.
        N: Accepted because the caller passes it alongside ``K``; it is not
            used by the quantization itself.
        K: Number of columns of the activation.

    Returns:
        ``(median, low, high)`` latency in microseconds, where ``low`` and
        ``high`` are the 20th and 80th percentile timings. The order matches
        the ``(y, y_min, y_max)`` triple that ``perf_report`` unpacks.

    Raises:
        ValueError: If the provider's backend is not recognised.
    """
    M = batch_size
    device = "cuda"
    dtype = torch.bfloat16

    # Random activation and its global scale.
    a = torch.randn((M, K), device=device, dtype=dtype)
    a_global_scale = compute_global_scale(a)

    # Timings come back in the same order as the requested quantiles:
    # median, 20th percentile, 80th percentile.
    quantiles = [0.5, 0.2, 0.8]

    backend = PROVIDER_CFGS[provider]["backend"]
    if backend == "vllm":
        # vLLM's FP4 quantization
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: ops.scaled_fp4_quant(a, a_global_scale),
            quantiles=quantiles,
        )
    elif backend == "flashinfer":
        # FlashInfer's FP4 quantization
        # Use is_sf_swizzled_layout=True to match vLLM's output format
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: flashinfer_fp4_quantize(
                a, a_global_scale, is_sf_swizzled_layout=True
            ),
            quantiles=quantiles,
        )
    else:
        # Without this branch an unknown backend would surface later as an
        # opaque UnboundLocalError on ``ms``.
        raise ValueError(f"Unknown backend {backend!r} for provider {provider!r}")

    # Convert ms to us for better readability at small batch sizes.
    us_per_ms = 1000
    # Latency is a "lower is better" metric, so the low quantile is y_min and
    # the high quantile is y_max. The swapped (max, min) order is only right
    # for inverse metrics such as throughput.
    return ms * us_per_ms, min_ms * us_per_ms, max_ms * us_per_ms
def prepare_shapes(args):
    """Build the list of ``[dim0, dim1, model]`` entries to benchmark.

    For every (model, tp_size) combination taken from ``args.models`` and
    ``args.tp_sizes``, each weight shape registered for the model in
    ``WEIGHT_SHAPES`` has its tensor-parallel dimension floor-divided by
    ``tp_size``, and the model name is appended as a third element.
    """
    shapes = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        # Work on a deep copy so the shared WEIGHT_SHAPES table is not mutated.
        for dims, split_axis in copy.deepcopy(WEIGHT_SHAPES[model]):
            dims[split_axis] = dims[split_axis] // tp_size
            shapes.append(dims + [model])
    return shapes
def _test_accuracy_once(M: int, K: int, dtype: torch.dtype, device: str):
    """Check that vLLM and FlashInfer produce matching FP4 output for one shape.

    A random ``(M, K)`` tensor is quantized by both implementations with the
    same global scale, and the quantized tensors are compared with
    ``torch.testing.assert_close``. A PASSED line is printed on success.
    """
    activation = torch.randn((M, K), device=device, dtype=dtype)
    global_scale = compute_global_scale(activation)

    # Reference: vLLM's quantization op.
    vllm_fp4, vllm_scale = ops.scaled_fp4_quant(activation, global_scale)

    # FlashInfer, using the swizzled scale-factor layout to match vLLM's output.
    flashinfer_fp4, flashinfer_scale = flashinfer_fp4_quantize(
        activation,
        global_scale,
        is_sf_swizzled_layout=True,
    )
    flashinfer_scale = flashinfer_scale.view(torch.float8_e4m3fn)

    # NOTE(review): only the quantized values are compared; the two scale
    # tensors are obtained above but are not checked against each other.
    torch.testing.assert_close(vllm_fp4, flashinfer_fp4)
    print(f"M={M}, K={K}, dtype={dtype}: PASSED")
def test_accuracy():
    """Run the vLLM-vs-FlashInfer accuracy check over a small grid of shapes."""
    banner = "=" * 60
    print("\n" + banner)
    print("Running accuracy tests: vLLM vs FlashInfer")
    print(banner)

    device = "cuda"
    dtype = torch.bfloat16

    # Batch sizes and hidden dimensions to cover.
    batch_sizes = [1, 1024]
    hidden_sizes = [4096]
    cases = [(m, k) for m in batch_sizes for k in hidden_sizes]

    for M, K in cases:
        _test_accuracy_once(M, K, dtype, device)

    print("\nAll accuracy tests passed!")
if __name__ == "__main__":
    # Command-line interface: choose models / TP sizes, optionally save the
    # results and optionally run the accuracy check before benchmarking.
    parser = argparse.ArgumentParser(
        description="Benchmark NVFP4 quantization: vLLM vs FlashInfer"
    )
    parser.add_argument(
        "--models",
        type=str,
        nargs="+",
        choices=list(WEIGHT_SHAPES),
        default=["meta-llama/Llama-3.1-8B-Instruct"],
    )
    parser.add_argument("--tp-sizes", type=int, nargs="+", default=[1])
    parser.add_argument(
        "--save-path",
        type=str,
        default=None,
        help="Path to save benchmark results",
    )
    parser.add_argument(
        "--accuracy",
        action="store_true",
        help="Run accuracy tests",
    )
    args = parser.parse_args()

    # The accuracy check is opt-in; the benchmark below runs either way.
    if args.accuracy:
        test_accuracy()

    # One benchmark sweep per (K, N, model) entry from prepare_shapes().
    for K, N, model in prepare_shapes(args):
        print(f"\n{model}, N={N} K={K}")
        benchmark.run(
            print_data=True,
            save_path=args.save_path,
            N=N,
            K=K,
        )

    print("\nBenchmark finished!")
csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
View file @
06d49028
...
...
@@ -74,6 +74,9 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
"Vec size is not matched."
);
// Precompute SF layout parameter (constant for entire kernel).
int32_t
const
numKTiles
=
(
numCols
+
63
)
/
64
;
// Get the global scaling factor, which will be applied to the SF.
// Note SFScale is the same as next GEMM's alpha, which is
// (448.f / (Alpha_A / 6.f)).
...
...
@@ -101,7 +104,7 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
auto
sf_out
=
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx
,
colIdx
,
num
Col
s
,
SFout
);
rowIdx
,
colIdx
,
num
KTile
s
,
SFout
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
out_silu_mul
,
SFScaleVal
,
sf_out
);
...
...
csrc/quantization/fp4/nvfp4_experts_quant.cu
View file @
06d49028
...
...
@@ -25,6 +25,7 @@
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"
#include "nvfp4_utils.cuh"
#include "launch_bounds_utils.h"
...
...
@@ -44,6 +45,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
"Vec size is not matched."
);
// Precompute SF layout parameter (constant for entire kernel).
int32_t
const
numKTiles
=
(
numCols
+
63
)
/
64
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
colsPerRow
=
numCols
/
CVT_FP4_ELTS_PER_THREAD
;
...
...
@@ -112,17 +116,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
// (448.f / (Alpha_A / 6.f)).
float
const
SFScaleVal
=
SFScale
==
nullptr
?
1.0
f
:
SFScale
[
expert_idx
];
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
// The actual output_scales dim is computed from the padded numCols.
int32_t
numCols_padded
=
(
numCols
+
factor
-
1
)
/
factor
*
factor
;
int
numCols_SFout
=
numCols_padded
/
CVT_FP4_SF_VEC_SIZE
/
4
;
uint32_t
*
SFout_in_expert
=
SFout
+
output_scale_offset_by_experts
[
expert_idx
]
*
num
Cols_SFout
;
SFout
+
output_scale_offset_by_experts
[
expert_idx
]
*
num
KTiles
;
auto
sf_out
=
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx_in_expert
,
colIdx
,
num
Col
s
,
SFout_in_expert
);
rowIdx_in_expert
,
colIdx
,
num
KTile
s
,
SFout_in_expert
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
}
...
...
@@ -140,6 +140,10 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
(
CVT_FP4_SF_VEC_SIZE
/
CVT_FP4_ELTS_PER_THREAD
);
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
"Vec size is not matched."
);
// Precompute SF layout parameter (constant for entire kernel).
int32_t
const
numKTiles
=
(
numCols
+
63
)
/
64
;
extern
__shared__
uint32_t
shared_input_offsets
[];
// Load input offsets into shared memory.
...
...
@@ -202,16 +206,13 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
float
const
SFScaleVal
=
SFScale
==
nullptr
?
1.0
f
:
SFScale
[
expert_idx
];
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
int32_t
numCols_padded
=
(
numCols
+
factor
-
1
)
/
factor
*
factor
;
int
numCols_SFout
=
numCols_padded
/
CVT_FP4_SF_VEC_SIZE
/
4
;
uint32_t
*
SFout_in_expert
=
SFout
+
output_scale_offset_by_experts
[
expert_idx
]
*
num
Cols_SFout
;
SFout
+
output_scale_offset_by_experts
[
expert_idx
]
*
num
KTiles
;
auto
sf_out
=
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx_in_expert
,
colIdx
,
num
Col
s
,
SFout_in_expert
);
rowIdx_in_expert
,
colIdx
,
num
KTile
s
,
SFout_in_expert
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
SFScaleVal
,
sf_out
);
}
...
...
@@ -222,12 +223,8 @@ void quant_impl(void* output, void* output_scale, void* input,
void
*
input_global_scale
,
void
*
input_offset_by_experts
,
void
*
output_scale_offset_by_experts
,
int
m_topk
,
int
k
,
int
n_experts
,
cudaStream_t
stream
)
{
// TODO: this multiProcessorCount should be cached.
int
device
;
cudaGetDevice
(
&
device
);
int
multiProcessorCount
;
cudaDeviceGetAttribute
(
&
multiProcessorCount
,
cudaDevAttrMultiProcessorCount
,
device
);
int
multiProcessorCount
=
get_device_attribute
(
cudaDevAttrMultiProcessorCount
,
-
1
);
// Grid, Block size.
// Each thread converts 8 values.
...
...
csrc/quantization/fp4/nvfp4_quant_kernels.cu
View file @
06d49028
...
...
@@ -38,6 +38,12 @@ __host__ __device__ inline Int round_up(Int x, Int y) {
return
(
x
+
y
-
1
)
/
y
*
y
;
}
// Compute effective rows for grid configuration with swizzled SF layouts:
// m is rounded up (via round_up) to the next multiple of the 128-row tile.
inline int computeEffectiveRows(int m) {
  constexpr int kRowTile = 128;
  int const paddedRows = round_up(m, kRowTile);
  return paddedRows;
}
// Use UE4M3 by default.
template
<
class
Type
,
bool
UE8M0_SF
=
false
>
__global__
void
__launch_bounds__
(
512
,
VLLM_BLOCKS_PER_SM
(
512
))
...
...
@@ -49,6 +55,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
static_assert
(
sizeof
(
PackedVec
)
==
sizeof
(
Type
)
*
CVT_FP4_ELTS_PER_THREAD
,
"Vec size is not matched."
);
// Precompute SF layout parameter (constant for entire kernel).
int32_t
const
numKTiles
=
(
numCols
+
63
)
/
64
;
int
sf_m
=
round_up
<
int
>
(
numRows
,
128
);
int
sf_n_unpadded
=
numCols
/
CVT_FP4_SF_VEC_SIZE
;
int
sf_n_int
=
round_up
<
int
>
(
sf_n_unpadded
,
4
)
/
4
;
...
...
@@ -79,7 +88,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
auto
sf_out
=
cvt_quant_to_fp4_get_sf_out_offset
<
uint32_t
,
CVT_FP4_NUM_THREADS_PER_SF
>
(
rowIdx
,
colIdx
,
num
Col
s
,
SFout
);
rowIdx
,
colIdx
,
num
KTile
s
,
SFout
);
out_pos
=
cvt_warp_fp16_to_fp4
<
Type
,
UE8M0_SF
>
(
in_vec
,
global_scale
,
sf_out
);
...
...
@@ -87,43 +96,6 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
}
}
template
<
typename
T
>
void
invokeFP4Quantization
(
int
m
,
int
n
,
T
const
*
input
,
float
const
*
SFScale
,
int64_t
*
output
,
int32_t
*
SFOuput
,
bool
useUE8M0
,
int
multiProcessorCount
,
cudaStream_t
stream
)
{
// Grid, Block size.
// Each thread converts 8 values.
dim3
block
(
std
::
min
(
int
(
n
/
ELTS_PER_THREAD
),
512
));
// Get number of blocks per SM
int
const
numBlocksPerSM
=
vllm_runtime_blocks_per_sm
(
static_cast
<
int
>
(
block
.
x
));
dim3
grid
(
std
::
min
(
int
(
m
),
multiProcessorCount
*
numBlocksPerSM
));
// Launch the cvt kernel.
if
(
useUE8M0
)
{
cvt_fp16_to_fp4
<
T
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
m
,
n
,
input
,
SFScale
,
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
SFOuput
));
}
else
{
cvt_fp16_to_fp4
<
T
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
m
,
n
,
input
,
SFScale
,
reinterpret_cast
<
uint32_t
*>
(
output
),
reinterpret_cast
<
uint32_t
*>
(
SFOuput
));
}
}
// Instantiate the function.
template
void
invokeFP4Quantization
(
int
m
,
int
n
,
half
const
*
input
,
float
const
*
SFScale
,
int64_t
*
output
,
int32_t
*
SFOuput
,
bool
useUE8M0
,
int
multiProcessorCount
,
cudaStream_t
stream
);
template
void
invokeFP4Quantization
(
int
m
,
int
n
,
__nv_bfloat16
const
*
input
,
float
const
*
SFScale
,
int64_t
*
output
,
int32_t
*
SFOuput
,
bool
useUE8M0
,
int
multiProcessorCount
,
cudaStream_t
stream
);
}
// namespace vllm
void
scaled_fp4_quant_sm1xxa
(
torch
::
Tensor
const
&
output
,
...
...
@@ -147,13 +119,19 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
input
.
get_device
());
// We don't support e8m0 scales at this moment.
bool
useUE8M0
=
false
;
// Grid, Block size. Each thread converts 8 values.
dim3
block
(
std
::
min
(
int
(
n
/
ELTS_PER_THREAD
),
512
));
int
const
numBlocksPerSM
=
vllm_runtime_blocks_per_sm
(
static_cast
<
int
>
(
block
.
x
));
int
effectiveRows
=
vllm
::
computeEffectiveRows
(
m
);
dim3
grid
(
std
::
min
(
effectiveRows
,
multiProcessorCount
*
numBlocksPerSM
));
VLLM_DISPATCH_HALF_TYPES
(
input
.
scalar_type
(),
"nvfp4_quant_kernel"
,
[
&
]
{
using
cuda_type
=
vllm
::
CUDATypeConverter
<
scalar_t
>::
Type
;
auto
input_ptr
=
static_cast
<
cuda_type
const
*>
(
input
.
data_ptr
());
vllm
::
invokeFP4Quantization
(
m
,
n
,
input_ptr
,
input_sf_ptr
,
output_ptr
,
sf_out
,
useUE8M0
,
multiProcessorCount
,
stream
);
// NOTE: We don't support e8m0 scales at this moment.
vllm
::
cvt_fp16_to_fp4
<
cuda_type
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
m
,
n
,
input_ptr
,
input_sf_ptr
,
reinterpret_cast
<
uint32_t
*>
(
output_ptr
),
reinterpret_cast
<
uint32_t
*>
(
sf_out
));
});
}
csrc/quantization/fp4/nvfp4_utils.cuh
View file @
06d49028
...
...
@@ -128,51 +128,42 @@ inline __device__ float reciprocal_approximate_ftz(float a) {
return
b
;
}
// Compute SF output offset for swizzled tensor core layout.
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
// Caller must precompute: numKTiles = (numCols + 63) / 64
template
<
class
SFType
,
int
CVT_FP4_NUM_THREADS_PER_SF
>
__device__
uint8_t
*
cvt_quant_to_fp4_get_sf_out_offset
(
int
rowIdx
,
int
colIdx
,
int
numCols
,
SFType
*
SFout
)
{
__device__
__forceinline__
uint8_t
*
cvt_quant_to_fp4_get_sf_out_offset
(
int
rowIdx
,
int
colIdx
,
int32_t
numKTiles
,
SFType
*
SFout
)
{
static_assert
(
CVT_FP4_NUM_THREADS_PER_SF
==
1
||
CVT_FP4_NUM_THREADS_PER_SF
==
2
);
// One pair of threads write one SF to global memory.
// TODO: stage through smem for packed STG.32
// is it better than STG.8 from 4 threads ?
if
(
threadIdx
.
x
%
CVT_FP4_NUM_THREADS_PER_SF
==
0
)
{
// SF vector index (16 elements share one SF in the K dimension).
int32_t
kIdx
=
colIdx
/
CVT_FP4_NUM_THREADS_PER_SF
;
int32_t
mIdx
=
rowIdx
;
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
// --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
int32_t
mTileIdx
=
mIdx
/
(
32
*
4
);
// SF vector size 16.
int
factor
=
CVT_FP4_SF_VEC_SIZE
*
4
;
int32_t
numKTiles
=
(
numCols
+
factor
-
1
)
/
factor
;
int64_t
mTileStride
=
numKTiles
*
32
*
4
*
4
;
int32_t
kTileIdx
=
(
kIdx
/
4
);
int64_t
kTileStride
=
32
*
4
*
4
;
// M tile layout [32, 4] is column-major.
int32_t
outerMIdx
=
(
mIdx
%
32
);
int64_t
outerMStride
=
4
*
4
;
int32_t
innerMIdx
=
(
mIdx
%
(
32
*
4
))
/
32
;
int64_t
innerMStride
=
4
;
int32_t
innerKIdx
=
(
kIdx
%
4
);
int64_t
innerKStride
=
1
;
// Compute the global offset.
int64_t
SFOffset
=
mTileIdx
*
mTileStride
+
kTileIdx
*
kTileStride
+
outerMIdx
*
outerMStride
+
innerMIdx
*
innerMStride
+
innerKIdx
*
innerKStride
;
return
reinterpret_cast
<
uint8_t
*>
(
SFout
)
+
SFOffset
;
if
(
threadIdx
.
x
%
CVT_FP4_NUM_THREADS_PER_SF
!=
0
)
{
return
nullptr
;
}
return
nullptr
;
// SF vector index (16 elements share one SF in the K dimension).
int32_t
kIdx
=
colIdx
/
CVT_FP4_NUM_THREADS_PER_SF
;
int32_t
mIdx
=
rowIdx
;
// Decompose indices using bitwise ops (all divisors are powers of 2).
// SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
int32_t
mTileIdx
=
mIdx
>>
7
;
// mIdx / 128
int32_t
outerMIdx
=
mIdx
&
31
;
// mIdx % 32
int32_t
innerMIdx
=
(
mIdx
>>
5
)
&
3
;
// (mIdx / 32) % 4
int32_t
kTileIdx
=
kIdx
>>
2
;
// kIdx / 4
int32_t
innerKIdx
=
kIdx
&
3
;
// kIdx % 4
// Compute global SF offset: mTileIdx * (numKTiles * 512) + kTileIdx * 512 +
// outerMIdx * 16 + innerMIdx * 4 + innerKIdx
// Use bitwise OR for non-overlapping lower bits.
int64_t
SFOffset
=
(
static_cast
<
int64_t
>
(
mTileIdx
)
*
numKTiles
+
kTileIdx
)
<<
9
|
(
outerMIdx
<<
4
)
|
(
innerMIdx
<<
2
)
|
innerKIdx
;
return
reinterpret_cast
<
uint8_t
*>
(
SFout
)
+
SFOffset
;
}
// Quantizes the provided PackedVec into the uint32_t output
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment