Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7a985548
Commit
7a985548
authored
May 22, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.0' into v0.9.0-ori
parents
45d3785c
dc1440cf
Changes
486
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1383 additions
and
464 deletions
+1383
-464
tests/kernels/moe/test_nvfp4_moe.py
tests/kernels/moe/test_nvfp4_moe.py
+144
-0
tests/kernels/moe/test_pplx_moe.py
tests/kernels/moe/test_pplx_moe.py
+691
-0
tests/kernels/moe/test_rocm_aiter_topk.py
tests/kernels/moe/test_rocm_aiter_topk.py
+122
-0
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
+20
-14
tests/kernels/quantization/nvfp4_utils.py
tests/kernels/quantization/nvfp4_utils.py
+66
-0
tests/kernels/quantization/test_awq_marlin.py
tests/kernels/quantization/test_awq_marlin.py
+0
-163
tests/kernels/quantization/test_block_fp8.py
tests/kernels/quantization/test_block_fp8.py
+14
-12
tests/kernels/quantization/test_block_int8.py
tests/kernels/quantization/test_block_int8.py
+4
-1
tests/kernels/quantization/test_cutlass_scaled_mm.py
tests/kernels/quantization/test_cutlass_scaled_mm.py
+3
-1
tests/kernels/quantization/test_ggml.py
tests/kernels/quantization/test_ggml.py
+6
-0
tests/kernels/quantization/test_gguf.py
tests/kernels/quantization/test_gguf.py
+5
-15
tests/kernels/quantization/test_marlin_gemm.py
tests/kernels/quantization/test_marlin_gemm.py
+67
-164
tests/kernels/quantization/test_nvfp4_quant.py
tests/kernels/quantization/test_nvfp4_quant.py
+1
-1
tests/kernels/quantization/test_nvfp4_scaled_mm.py
tests/kernels/quantization/test_nvfp4_scaled_mm.py
+14
-84
tests/kernels/quantization/test_rocm_skinny_gemms.py
tests/kernels/quantization/test_rocm_skinny_gemms.py
+4
-3
tests/kernels/test_fused_quant_activation.py
tests/kernels/test_fused_quant_activation.py
+70
-0
tests/kv_transfer/test_disagg.py
tests/kv_transfer/test_disagg.py
+2
-2
tests/lora/conftest.py
tests/lora/conftest.py
+15
-3
tests/lora/test_lora_allowed_token_ids.py
tests/lora/test_lora_allowed_token_ids.py
+134
-0
tests/lora/test_lora_huggingface.py
tests/lora/test_lora_huggingface.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
486 of 486+
files are displayed.
Plain diff
Email patch
tests/kernels/moe/test_nvfp4_moe.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
from
tests.kernels.quantization.nvfp4_utils
import
(
FLOAT4_E2M1_MAX
,
FLOAT8_E4M3_MAX
,
dequantize_nvfp4_to_dtype
)
from
tests.kernels.utils
import
torch_moe
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.fused_moe.cutlass_moe
import
cutlass_moe_fp4
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.platforms
import
current_platform
if
not
current_platform
.
has_device_capability
(
100
):
pytest
.
skip
(
reason
=
"Nvfp4 Requires compute capability of 10 or above."
,
allow_module_level
=
True
)
MNK_FACTORS
=
[
(
2
,
1024
,
1024
),
(
2
,
1024
,
1536
),
(
2
,
3072
,
1024
),
(
2
,
3072
,
1536
),
(
64
,
1024
,
1024
),
(
64
,
1024
,
1536
),
(
64
,
3072
,
1024
),
(
64
,
2048
,
1536
),
(
224
,
1024
,
1024
),
(
224
,
1024
,
1536
),
]
@
pytest
.
mark
.
parametrize
(
"m,n,k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
[
40
,
64
,
256
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
1
,
6
,
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
half
,
torch
.
bfloat16
])
@
torch
.
inference_mode
()
def
test_cutlass_fp4_moe_no_graph
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
):
current_platform
.
seed_everything
(
7
)
with
set_current_vllm_config
(
VllmConfig
(
parallel_config
=
ParallelConfig
(
pipeline_parallel_size
=
1
))):
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
quant_blocksize
=
16
round_up
=
lambda
x
,
y
:
(
x
+
y
-
1
)
//
y
*
y
sf_w1_2n
=
round_up
(
2
*
n
,
128
)
sf_w1_k
=
round_up
(
k
//
quant_blocksize
,
4
)
w1_blockscale
=
torch
.
empty
((
e
,
sf_w1_2n
,
sf_w1_k
),
device
=
"cuda"
,
dtype
=
torch
.
float8_e4m3fn
)
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
sf_w2_k
=
round_up
(
k
,
128
)
sf_w2_n
=
round_up
(
n
//
quant_blocksize
,
4
)
w2_blockscale
=
torch
.
empty
((
e
,
sf_w2_k
,
sf_w2_n
),
device
=
"cuda"
,
dtype
=
torch
.
float8_e4m3fn
)
w1_q
=
torch
.
empty
((
e
,
2
*
n
,
k
//
2
),
device
=
"cuda"
,
dtype
=
torch
.
uint8
)
w2_q
=
torch
.
empty
((
e
,
k
,
n
//
2
),
device
=
"cuda"
,
dtype
=
torch
.
uint8
)
w1_gs
=
torch
.
empty
((
e
,
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
w2_gs
=
torch
.
empty
((
e
,
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
for
expert
in
range
(
e
):
w1_amax
=
torch
.
abs
(
w1
).
max
().
to
(
torch
.
float32
)
w2_amax
=
torch
.
abs
(
w2
).
max
().
to
(
torch
.
float32
)
w1_gs
[
expert
]
=
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
/
w1_amax
w2_gs
[
expert
]
=
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
/
w2_amax
w1_q
[
expert
],
w1_blockscale
[
expert
]
=
ops
.
scaled_fp4_quant
(
w1
[
expert
],
w1_gs
[
expert
])
w2_q
[
expert
],
w2_blockscale
[
expert
]
=
ops
.
scaled_fp4_quant
(
w2
[
expert
],
w2_gs
[
expert
])
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
=
fused_topk
(
a
,
score
,
topk
,
renormalize
=
False
)
a1_gs
=
torch
.
ones
((
e
,
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
a2_gs
=
torch
.
ones
((
e
,
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
cutlass_output
=
cutlass_moe_fp4
(
a
=
a
,
a1_gscale
=
a1_gs
,
w1_fp4
=
w1_q
,
w1_blockscale
=
w1_blockscale
,
w1_alphas
=
(
1
/
w1_gs
),
a2_gscale
=
a2_gs
,
w2_fp4
=
w2_q
,
w2_blockscale
=
w2_blockscale
,
w2_alphas
=
(
1
/
w2_gs
),
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
m
=
m
,
n
=
n
,
k
=
k
,
e
=
e
,
device
=
a
.
device
,
)
# Reference check:
a_global_scale
=
((
FLOAT8_E4M3_MAX
*
FLOAT4_E2M1_MAX
)
/
torch
.
amax
(
a
.
flatten
(),
dim
=-
1
)).
to
(
torch
.
float32
)
a_fp4
,
a_scale_interleaved
=
ops
.
scaled_fp4_quant
(
a
,
a_global_scale
)
_
,
m_k
=
a_fp4
.
shape
a_in_dtype
=
dequantize_nvfp4_to_dtype
(
a_fp4
,
a_scale_interleaved
,
a_global_scale
,
dtype
=
a
.
dtype
,
device
=
a
.
device
,
block_size
=
quant_blocksize
)
w1_d
=
torch
.
empty
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
w2_d
=
torch
.
empty
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
for
idx
in
range
(
0
,
e
):
w1_d
[
idx
]
=
dequantize_nvfp4_to_dtype
(
w1_q
[
idx
],
w1_blockscale
[
idx
],
w1_gs
[
idx
],
dtype
=
w1
.
dtype
,
device
=
w1
.
device
,
block_size
=
quant_blocksize
)
w2_d
[
idx
]
=
dequantize_nvfp4_to_dtype
(
w2_q
[
idx
],
w2_blockscale
[
idx
],
w2_gs
[
idx
],
dtype
=
w2
.
dtype
,
device
=
w2
.
device
,
block_size
=
quant_blocksize
)
torch_output
=
torch_moe
(
a_in_dtype
,
w1_d
,
w2_d
,
score
,
topk
,
None
)
torch
.
testing
.
assert_close
(
torch_output
,
cutlass_output
,
atol
=
1e-1
,
rtol
=
1e-1
)
if
__name__
==
"__main__"
:
test_cutlass_fp4_moe_no_graph
((
2
,
1024
,
1024
),
40
,
1
,
torch
.
half
)
tests/kernels/moe/test_pplx_moe.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE layers.
Run `pytest tests/kernels/test_pplx_moe.py`.
"""
import
dataclasses
import
os
import
traceback
from
typing
import
Callable
,
Optional
import
pytest
import
torch
try
:
from
pplx_kernels
import
AllToAll
from
pplx_kernels.nvshmem
import
(
nvshmem_alloc_empty_unique_id
,
nvshmem_finalize
,
nvshmem_get_unique_id
,
nvshmem_init
)
has_pplx
=
True
except
ImportError
:
has_pplx
=
False
from
torch.multiprocessing
import
(
spawn
)
# pyright: ignore[reportPrivateImportUsage]
from
typing_extensions
import
Concatenate
,
ParamSpec
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
override_config
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
BatchedExperts
,
BatchedPrepareAndFinalize
,
BatchedTritonExperts
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
get_default_config
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
)
from
vllm.platforms
import
current_platform
PPLX_PREPARE_COMBOS
=
[(
4
,
128
,
128
),
(
32
,
1024
,
512
),
(
64
,
1024
,
512
),
(
222
,
2048
,
1024
)]
PPLX_MOE_COMBOS
=
[
(
1
,
128
,
128
),
(
2
,
128
,
512
),
(
3
,
1024
,
2048
),
(
32
,
128
,
1024
),
(
45
,
512
,
2048
),
(
64
,
1024
,
1024
),
(
222
,
1024
,
2048
),
]
NUM_EXPERTS
=
[
8
,
64
]
EP_SIZE
=
[
1
,
4
]
TOP_KS
=
[
1
,
2
,
6
]
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
P
=
ParamSpec
(
"P"
)
requires_pplx
=
pytest
.
mark
.
skipif
(
not
has_pplx
,
reason
=
"Requires PPLX kernels"
,
)
@
dataclasses
.
dataclass
class
ProcessGroupInfo
:
world_size
:
int
world_local_size
:
int
rank
:
int
node_rank
:
int
local_rank
:
int
device
:
torch
.
device
def
_worker_parallel_launch
(
local_rank
:
int
,
world_size
:
int
,
world_local_size
:
int
,
node_rank
:
int
,
init_method
:
str
,
worker
:
Callable
[
Concatenate
[
ProcessGroupInfo
,
P
],
None
],
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
,
)
->
None
:
rank
=
node_rank
*
world_local_size
+
local_rank
torch
.
cuda
.
set_device
(
local_rank
)
device
=
torch
.
device
(
"cuda"
,
local_rank
)
torch
.
distributed
.
init_process_group
(
backend
=
"cpu:gloo,cuda:nccl"
,
init_method
=
init_method
,
rank
=
rank
,
world_size
=
world_size
,
device_id
=
device
,
)
barrier
=
torch
.
tensor
([
rank
],
device
=
device
)
torch
.
distributed
.
all_reduce
(
barrier
)
try
:
worker
(
ProcessGroupInfo
(
world_size
=
world_size
,
world_local_size
=
world_local_size
,
rank
=
rank
,
node_rank
=
node_rank
,
local_rank
=
local_rank
,
device
=
device
,
),
*
args
,
**
kwargs
,
)
except
Exception
as
ex
:
print
(
ex
)
traceback
.
print_exc
()
raise
finally
:
torch
.
distributed
.
destroy_process_group
()
def
parallel_launch
(
world_size
:
int
,
worker
:
Callable
[
Concatenate
[
ProcessGroupInfo
,
P
],
None
],
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
,
)
->
None
:
assert
not
kwargs
spawn
(
_worker_parallel_launch
,
args
=
(
world_size
,
world_size
,
0
,
"tcp://localhost:29500"
,
worker
,
)
+
args
,
nprocs
=
world_size
,
join
=
True
,
)
def
parallel_launch_from_env
(
worker
:
Callable
[
Concatenate
[
ProcessGroupInfo
,
P
],
None
],
*
args
:
P
.
args
,
**
kwargs
:
P
.
kwargs
,
)
->
None
:
"""
Launches a worker function in parallel across all processes in the current
environment. The environment must have the following variables set:
- WORLD_SIZE: The total number of processes.
- WORLD_LOCAL_SIZE: The number of processes on the current node.
- NODE_RANK: The rank of the current
- MASTER_ADDR: The address of the master process.
- MASTER_PORT: The port of the master process.
"""
assert
not
kwargs
world_size
=
int
(
os
.
environ
[
"WORLD_SIZE"
])
world_local_size
=
int
(
os
.
environ
[
"WORLD_LOCAL_SIZE"
])
node_rank
=
int
(
os
.
environ
[
"NODE_RANK"
])
assert
"MASTER_ADDR"
in
os
.
environ
assert
"MASTER_PORT"
in
os
.
environ
spawn
(
_worker_parallel_launch
,
args
=
(
world_size
,
world_local_size
,
node_rank
,
"env://"
,
worker
,
)
+
args
,
nprocs
=
world_local_size
,
join
=
True
,
)
def
torch_prepare
(
a
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
max_num_tokens
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
topk_ids
.
dim
()
==
2
assert
topk_ids
.
shape
[
0
]
==
a
.
shape
[
0
]
num_tokens
,
hidden_dim
=
a
.
shape
topk
=
topk_ids
.
shape
[
1
]
tokens_per_expert
=
torch
.
bincount
(
topk_ids
.
view
(
-
1
),
minlength
=
num_experts
)
assert
tokens_per_expert
.
numel
()
==
num_experts
if
max_num_tokens
is
None
:
max_num_tokens
=
int
(
tokens_per_expert
.
max
().
item
())
b_a
=
torch
.
zeros
((
num_experts
,
max_num_tokens
,
hidden_dim
),
dtype
=
a
.
dtype
,
device
=
a
.
device
)
token_counts
=
torch
.
zeros
(
num_experts
,
dtype
=
torch
.
int
,
device
=
a
.
device
)
for
token
in
range
(
num_tokens
):
for
j
in
range
(
topk
):
expert_id
=
topk_ids
[
token
,
j
]
idx
=
token_counts
[
expert_id
]
b_a
[
expert_id
,
idx
:
idx
+
1
,
:]
=
a
[
token
,
:]
token_counts
[
expert_id
]
=
token_counts
[
expert_id
]
+
1
return
b_a
,
tokens_per_expert
def
torch_finalize
(
b_out
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
=
topk_ids
.
shape
[
0
]
num_experts
=
b_out
.
shape
[
0
]
K
=
b_out
.
shape
[
-
1
]
out
=
torch
.
zeros
((
num_tokens
,
K
),
dtype
=
b_out
.
dtype
,
device
=
b_out
.
device
)
expert_counts
=
torch
.
zeros
(
num_experts
,
dtype
=
torch
.
int
,
device
=
b_out
.
device
)
for
token
in
range
(
num_tokens
):
expert_ids
=
topk_ids
[
token
]
for
i
in
range
(
expert_ids
.
numel
()):
expert_id
=
expert_ids
[
i
]
idx
=
expert_counts
[
expert_id
]
out
[
token
,
:]
=
out
[
token
,
:]
+
b_out
[
expert_id
,
idx
:
idx
+
1
,
:]
*
topk_weight
[
token
,
i
]
expert_counts
[
expert_id
]
=
expert_counts
[
expert_id
]
+
1
return
out
def
torch_batched_moe
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
num_experts
=
w1
.
shape
[
0
]
b_a
,
tokens_per_expert
=
torch_prepare
(
a
,
topk_ids
,
num_experts
)
assert
b_a
.
dim
()
==
3
num_tokens
,
topk
=
topk_ids
.
shape
_
,
max_num_tokens
,
K
=
b_a
.
shape
assert
num_experts
==
b_a
.
shape
[
0
]
and
w2
.
shape
[
1
]
==
K
out
=
torch
.
zeros
((
num_experts
,
max_num_tokens
,
K
),
dtype
=
b_a
.
dtype
,
device
=
b_a
.
device
)
tmp
=
torch
.
empty
((
max_num_tokens
,
w1
.
shape
[
1
]
//
2
),
dtype
=
b_a
.
dtype
,
device
=
b_a
.
device
)
for
expert
in
range
(
num_experts
):
num
=
tokens_per_expert
[
expert
]
if
num
>
0
:
torch
.
ops
.
_C
.
silu_and_mul
(
tmp
[:
num
],
b_a
[
expert
,
:
num
,
:]
@
w1
[
expert
].
transpose
(
0
,
1
))
out
[
expert
,
:
num
,
:]
=
tmp
[:
num
]
@
w2
[
expert
].
transpose
(
0
,
1
)
return
torch_finalize
(
out
,
topk_weight
,
topk_ids
)
def
batched_moe
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
num_experts
=
w1
.
shape
[
0
]
fused_experts
=
FusedMoEModularKernel
(
BatchedPrepareAndFinalize
(
a
.
shape
[
0
],
world_size
=
1
,
dp_size
=
1
,
rank
=
0
),
BatchedExperts
(
max_num_tokens
=
a
.
shape
[
0
],
dp_size
=
1
,
world_size
=
1
))
return
fused_experts
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
,
num_experts
)
# Note: same as torch_moe but with fused_topk factored out.
def
torch_moe2
(
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
M
,
K
=
a
.
shape
topk
=
topk_ids
.
shape
[
1
]
a
=
a
.
view
(
M
,
-
1
,
K
).
repeat
(
1
,
topk
,
1
).
reshape
(
-
1
,
K
)
out
=
torch
.
zeros
(
M
*
topk
,
w2
.
shape
[
1
],
dtype
=
a
.
dtype
,
device
=
a
.
device
)
num_experts
=
w1
.
shape
[
0
]
for
i
in
range
(
num_experts
):
mask
=
(
topk_ids
==
i
).
view
(
-
1
)
if
mask
.
sum
():
out
[
mask
]
=
SiluAndMul
()(
a
[
mask
]
@
w1
[
i
].
transpose
(
0
,
1
))
@
w2
[
i
].
transpose
(
0
,
1
)
return
(
out
.
view
(
M
,
-
1
,
w2
.
shape
[
1
])
*
topk_weight
.
view
(
M
,
-
1
,
1
).
to
(
out
.
dtype
)).
sum
(
dim
=
1
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
1024
,
2048
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
512
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
def
test_fused_moe_batched_experts
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
):
current_platform
.
seed_everything
(
7
)
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
with
set_current_vllm_config
(
vllm_config
):
topk_weight
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
baseline_output
=
torch_moe2
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
torch_output
=
torch_batched_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
batched_output
=
batched_moe
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
torch
.
testing
.
assert_close
(
baseline_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
baseline_output
,
batched_output
,
atol
=
2e-2
,
rtol
=
0
)
def
rank_chunk
(
num
:
int
,
r
:
int
,
w
:
int
)
->
int
:
rem
=
num
%
w
return
(
num
//
w
)
+
(
1
if
r
<
rem
else
0
)
def
chunk_by_rank
(
t
:
torch
.
Tensor
,
r
:
int
,
w
:
int
)
->
torch
.
Tensor
:
chunk
=
rank_chunk
(
t
.
shape
[
0
],
r
,
w
)
return
t
[(
r
*
chunk
):(
r
+
1
)
*
chunk
]
def
pplx_prepare_finalize
(
pgi
:
ProcessGroupInfo
,
dp_size
:
int
,
a
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
)
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
topk
=
topk_ids
.
shape
[
1
]
num_tokens
,
hidden_dim
=
a
.
shape
block_size
=
128
device
=
pgi
.
device
rank
=
pgi
.
rank
world_size
=
pgi
.
world_size
max_num_tokens
=
rank_chunk
(
num_tokens
,
0
,
world_size
)
ata
=
AllToAll
.
internode
(
max_num_tokens
=
max_num_tokens
,
num_experts
=
num_experts
,
experts_per_token
=
topk
,
rank
=
rank
,
world_size
=
world_size
,
dp_size
=
dp_size
,
hidden_dim
=
hidden_dim
,
hidden_dim_bytes
=
hidden_dim
*
a
.
dtype
.
itemsize
,
hidden_dim_scale_bytes
=
(
0
if
a
.
dtype
.
itemsize
!=
1
else
((
hidden_dim
+
block_size
-
1
)
//
block_size
*
torch
.
float32
.
itemsize
)),
)
topk_ids
=
topk_ids
.
to
(
dtype
=
torch
.
uint32
)
prepare_finalize
=
PplxPrepareAndFinalize
(
ata
,
max_num_tokens
,
world_size
,
rank
,
dp_size
,
a
.
dtype
,
)
a_chunk
=
chunk_by_rank
(
a
,
rank
,
world_size
).
to
(
device
)
chunk_topk_weight
=
chunk_by_rank
(
topk_weight
,
rank
,
world_size
).
to
(
device
)
chunk_topk_ids
=
chunk_by_rank
(
topk_ids
,
rank
,
world_size
).
to
(
device
)
b_a
,
b_a_scale
,
expert_num_tokens
=
prepare_finalize
.
prepare
(
a_chunk
,
None
,
None
,
chunk_topk_weight
,
chunk_topk_ids
,
num_experts
,
None
,
False
,
)
b_a
=
b_a
*
1.5
out
=
torch
.
full
(
(
max_num_tokens
,
hidden_dim
),
torch
.
nan
,
dtype
=
a
.
dtype
,
device
=
device
,
)
prepare_finalize
.
finalize
(
out
,
b_a
,
chunk_topk_weight
,
chunk_topk_ids
,
False
,
)
torch
.
cuda
.
synchronize
()
ata
.
destroy
()
num_tokens
=
a_chunk
.
shape
[
0
]
return
out
[:
num_tokens
]
def
_pplx_prepare_finalize
(
pgi
:
ProcessGroupInfo
,
dp_size
:
int
,
a
:
torch
.
Tensor
,
score
:
torch
.
Tensor
,
topk
:
torch
.
Tensor
,
num_experts
:
int
,
):
uid
=
nvshmem_get_unique_id
(
)
if
pgi
.
rank
==
0
else
nvshmem_alloc_empty_unique_id
()
torch
.
distributed
.
broadcast
(
uid
,
src
=
0
)
nvshmem_init
(
uid
,
pgi
.
rank
,
pgi
.
world_size
)
device
=
pgi
.
device
topk_weight
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
k
=
a
.
shape
[
1
]
a_rep
=
torch
.
repeat_interleave
(
a
,
topk
,
dim
=
0
).
to
(
device
)
torch_output
=
(
a_rep
.
view
(
-
1
,
topk
,
k
)
*
1.5
*
topk_weight
.
view
(
-
1
,
topk
,
1
).
to
(
device
)).
sum
(
dim
=
1
).
to
(
a
.
dtype
)
pplx_output
=
pplx_prepare_finalize
(
pgi
,
dp_size
,
a
,
topk_weight
,
topk_ids
,
num_experts
)
torch_output
=
chunk_by_rank
(
torch_output
,
pgi
.
rank
,
pgi
.
world_size
).
to
(
pplx_output
.
device
)
torch
.
testing
.
assert_close
(
pplx_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
nvshmem_finalize
()
# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
@
pytest
.
mark
.
parametrize
(
"mnk"
,
PPLX_PREPARE_COMBOS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
requires_pplx
def
test_pplx_prepare_finalize
(
mnk
:
tuple
[
int
,
int
,
int
],
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
world_dp_size
:
tuple
[
int
,
int
],
):
current_platform
.
seed_everything
(
7
)
m
,
n
,
k
=
mnk
world_size
,
dp_size
=
world_dp_size
device
=
"cuda"
a
=
torch
.
randn
((
m
,
k
),
device
=
device
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
device
,
dtype
=
dtype
)
parallel_launch
(
world_size
,
_pplx_prepare_finalize
,
dp_size
,
a
,
score
,
topk
,
e
)
def
pplx_moe
(
rank
:
int
,
world_size
:
int
,
dp_size
:
int
,
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weight
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_compile
:
bool
=
True
,
use_cudagraphs
:
bool
=
True
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.pplx_prepare_finalize
import
(
PplxPrepareAndFinalize
)
device
=
torch
.
device
(
"cuda"
,
rank
)
hidden_dim
=
a
.
shape
[
1
]
num_experts
=
w1
.
shape
[
0
]
block_size
=
128
topk
=
topk_ids
.
shape
[
1
]
max_num_tokens
=
rank_chunk
(
a
.
shape
[
0
],
0
,
world_size
)
ata
=
AllToAll
.
internode
(
max_num_tokens
=
max_num_tokens
,
num_experts
=
num_experts
,
experts_per_token
=
topk
,
rank
=
rank
,
world_size
=
world_size
,
dp_size
=
dp_size
,
hidden_dim
=
hidden_dim
,
hidden_dim_bytes
=
hidden_dim
*
a
.
dtype
.
itemsize
,
hidden_dim_scale_bytes
=
(
0
if
a
.
dtype
.
itemsize
!=
1
else
((
hidden_dim
+
block_size
-
1
)
//
block_size
*
torch
.
float32
.
itemsize
)),
)
topk_ids
=
topk_ids
.
to
(
dtype
=
torch
.
uint32
)
prepare_finalize
=
PplxPrepareAndFinalize
(
ata
,
max_num_tokens
,
world_size
,
rank
,
dp_size
,
)
experts
=
BatchedTritonExperts
(
max_num_tokens
=
a
.
shape
[
0
],
world_size
=
world_size
,
dp_size
=
dp_size
)
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk
=
chunk_by_rank
(
a
,
rank
,
world_size
).
to
(
device
)
chunk_topk_weight
=
chunk_by_rank
(
topk_weight
,
rank
,
world_size
).
to
(
device
)
chunk_topk_ids
=
chunk_by_rank
(
topk_ids
,
rank
,
world_size
).
to
(
device
)
# Chunking weights like this only works for batched format
w1_chunk
=
chunk_by_rank
(
w1
,
rank
,
world_size
).
to
(
device
)
w2_chunk
=
chunk_by_rank
(
w2
,
rank
,
world_size
).
to
(
device
)
if
use_compile
:
_fused_experts
=
torch
.
compile
(
fused_experts
,
backend
=
'inductor'
,
fullgraph
=
True
)
else
:
_fused_experts
=
fused_experts
out
=
_fused_experts
(
a_chunk
,
w1_chunk
,
w2_chunk
,
chunk_topk_weight
,
chunk_topk_ids
,
global_num_experts
=
num_experts
)
if
use_cudagraphs
:
out
.
fill_
(
0
)
stream
=
torch
.
cuda
.
Stream
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
,
stream
=
stream
):
out
=
_fused_experts
(
a_chunk
,
w1_chunk
,
w2_chunk
,
chunk_topk_weight
,
chunk_topk_ids
,
global_num_experts
=
num_experts
)
torch
.
cuda
.
synchronize
()
graph
.
replay
()
torch
.
cuda
.
synchronize
()
ata
.
destroy
()
return
out
def
_batched_moe
(
pgi
,
dp_size
,
a
,
w1
,
w2
,
topk_weight
,
topk_ids
):
assert
torch
.
cuda
.
current_device
()
==
pgi
.
local_rank
num_experts
=
w1
.
shape
[
0
]
device
=
pgi
.
device
rank
=
pgi
.
rank
world_size
=
pgi
.
world_size
max_num_tokens
=
rank_chunk
(
a
.
shape
[
0
],
0
,
world_size
)
prepare_finalize
=
BatchedPrepareAndFinalize
(
max_num_tokens
=
max_num_tokens
,
world_size
=
world_size
,
dp_size
=
dp_size
,
rank
=
rank
,
)
experts
=
BatchedExperts
(
max_num_tokens
=
a
.
shape
[
0
],
world_size
=
1
,
dp_size
=
1
)
fused_experts
=
FusedMoEModularKernel
(
prepare_finalize
,
experts
,
)
# Note: workers with the same dp_rank must use the exact same inputs.
a_chunk
=
chunk_by_rank
(
a
,
rank
,
world_size
).
to
(
device
)
chunk_topk_weight
=
chunk_by_rank
(
topk_weight
,
rank
,
world_size
).
to
(
device
)
chunk_topk_ids
=
chunk_by_rank
(
topk_ids
,
rank
,
world_size
).
to
(
device
)
out
=
fused_experts
(
a_chunk
,
# Chunking weights like this only works for batched format
chunk_by_rank
(
w1
,
rank
,
world_size
).
to
(
device
),
chunk_by_rank
(
w2
,
rank
,
world_size
).
to
(
device
),
chunk_topk_weight
,
chunk_topk_ids
,
global_num_experts
=
num_experts
)
return
out
def
_pplx_moe
(
pgi
:
ProcessGroupInfo
,
dp_size
:
int
,
a
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
score
:
torch
.
Tensor
,
topk
:
int
,
):
uid
=
nvshmem_get_unique_id
(
)
if
pgi
.
rank
==
0
else
nvshmem_alloc_empty_unique_id
()
torch
.
distributed
.
broadcast
(
uid
,
src
=
0
)
nvshmem_init
(
uid
,
pgi
.
rank
,
pgi
.
world_size
)
m
,
k
=
a
.
shape
e
,
_
,
n
=
w2
.
shape
moe_config
=
get_default_config
(
m
,
e
,
n
,
k
,
topk
,
a
.
dtype
,
False
)
with
set_current_vllm_config
(
vllm_config
),
override_config
(
moe_config
):
topk_weight
,
topk_ids
,
_
=
fused_topk
(
a
,
score
,
topk
,
False
)
torch_output
=
torch_moe2
(
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
pplx_output
=
pplx_moe
(
pgi
.
rank
,
pgi
.
world_size
,
dp_size
,
a
,
w1
,
w2
,
topk_weight
,
topk_ids
)
# TODO (bnell): fix + re-enable
#batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
# topk_ids)
torch_output
=
chunk_by_rank
(
torch_output
,
pgi
.
rank
,
pgi
.
world_size
).
to
(
pplx_output
.
device
)
torch
.
testing
.
assert_close
(
pplx_output
,
torch_output
,
atol
=
2e-2
,
rtol
=
0
)
#torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0)
nvshmem_finalize
()
@
pytest
.
mark
.
parametrize
(
"mnk"
,
PPLX_MOE_COMBOS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"world_dp_size"
,
[[
2
,
1
]])
@
requires_pplx
def
test_pplx_moe
(
mnk
:
tuple
[
int
,
int
,
int
],
e
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
world_dp_size
:
tuple
[
int
,
int
],
):
current_platform
.
seed_everything
(
7
)
m
,
n
,
k
=
mnk
world_size
,
dp_size
=
world_dp_size
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
parallel_launch
(
world_size
,
_pplx_moe
,
dp_size
,
a
,
w1
,
w2
,
score
,
topk
)
tests/kernels/moe/test_rocm_aiter_topk.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
# This is a test for the AITER ops.
# It tests if the AITER ops are
# 1. correctly registered as custom ops
# 2. correctly defined the relationship between
# implementation and fake function
# 3. can be used with torch.compile
# This file will be skipped if AITER is not installed
# and the platform is not ROCm.
import
importlib.util
import
pytest
import
torch
# this import statement is needed to ensure the ops are registered
import
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
# noqa: F401
from
vllm.platforms
import
current_platform
# need to import once to ensure the ops are registered
# Check if aiter package is installed
aiter_available
=
importlib
.
util
.
find_spec
(
"aiter"
)
is
not
None
pytestmark
=
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_rocm
()
and
aiter_available
),
reason
=
"AITER ops are only available on ROCm with aiter package installed"
)
def
test_rocm_aiter_biased_grouped_topk_custom_op_registration
():
"""Test that the custom op is correctly registered."""
# Check if the op exists in torch.ops.vllm
assert
hasattr
(
torch
.
ops
.
vllm
,
'rocm_aiter_biased_grouped_topk'
)
# Check if the op is callable
assert
callable
(
torch
.
ops
.
vllm
.
rocm_aiter_biased_grouped_topk
)
def
test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility
():
"""Test that the op can be used with torch.compile."""
# Create test tensors
token
=
64
expert
=
256
num_expert_group
=
8
topk
=
8
topk_group
=
4
renormalize
=
True
scale_factor
=
1.0
gating_output
=
torch
.
randn
((
token
,
expert
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
e_score_correction_bias
=
torch
.
randn
((
expert
,
),
dtype
=
torch
.
bfloat16
,
device
=
"cuda"
)
device
=
gating_output
.
device
topk_ids
=
torch
.
empty
((
token
,
topk
),
dtype
=
torch
.
int32
,
device
=
device
)
topk_weights
=
torch
.
empty
((
token
,
topk
),
dtype
=
torch
.
float32
,
device
=
device
)
# Define a function that uses the op
def
biased_grouped_topk_fn
(
gating_output
,
e_score_correction_bias
,
topk_weights
,
topk_ids
):
return
torch
.
ops
.
vllm
.
rocm_aiter_biased_grouped_topk
(
gating_output
,
e_score_correction_bias
,
topk_weights
,
topk_ids
,
num_expert_group
,
topk_group
,
renormalize
,
scale_factor
)
# Verify the op's fake implementation
torch
.
library
.
opcheck
(
torch
.
ops
.
vllm
.
rocm_aiter_biased_grouped_topk
,
(
gating_output
,
e_score_correction_bias
,
topk_weights
,
topk_ids
),
kwargs
=
{
"num_expert_group"
:
num_expert_group
,
"topk_group"
:
topk_group
,
"need_renorm"
:
renormalize
,
"routed_scaling_factor"
:
scale_factor
},
test_utils
=
(
"test_faketensor"
))
# Compile the function with appropriate settings
compiled_fn
=
torch
.
compile
(
biased_grouped_topk_fn
,
fullgraph
=
True
,
backend
=
"inductor"
,
mode
=
"reduce-overhead"
,
dynamic
=
False
)
topk_weights_original
=
torch
.
empty
((
token
,
topk
),
dtype
=
torch
.
float32
,
device
=
device
)
topk_ids_original
=
torch
.
empty
((
token
,
topk
),
dtype
=
torch
.
int32
,
device
=
device
)
topk_weights_compiled
=
torch
.
empty
((
token
,
topk
),
dtype
=
torch
.
float32
,
device
=
device
)
topk_ids_compiled
=
torch
.
empty
((
token
,
topk
),
dtype
=
torch
.
int32
,
device
=
device
)
# Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode)
biased_grouped_topk_fn
(
gating_output
,
e_score_correction_bias
,
topk_weights_original
,
topk_ids_original
)
compiled_fn
(
gating_output
,
e_score_correction_bias
,
topk_weights_compiled
,
topk_ids_compiled
)
# Sort the results for comparison since the order might not be deterministic
topk_ids_original
,
indices_original
=
torch
.
sort
(
topk_ids_original
)
topk_weights_original
=
torch
.
gather
(
topk_weights_original
,
1
,
indices_original
)
topk_ids_compiled
,
indices_compiled
=
torch
.
sort
(
topk_ids_compiled
)
topk_weights_compiled
=
torch
.
gather
(
topk_weights_compiled
,
1
,
indices_compiled
)
# Verify results match
assert
torch
.
allclose
(
topk_weights_original
,
topk_weights_compiled
,
rtol
=
1e-2
,
atol
=
1e-2
)
assert
torch
.
allclose
(
topk_ids_original
,
topk_ids_compiled
)
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
View file @
7a985548
...
@@ -7,6 +7,7 @@ import pytest
...
@@ -7,6 +7,7 @@ import pytest
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
...
@@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
allow_module_level
=
True
)
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
def
native_w8a8_per_token_matmul
(
A
,
B
,
As
,
Bs
,
output_dtype
=
torch
.
float16
):
def
native_w8a8_per_token_matmul
(
A
,
B
,
As
,
Bs
,
output_dtype
=
torch
.
float16
):
"""Matrix multiplication function that supports per-token input
"""Matrix multiplication function that supports per-token input
...
@@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
...
@@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
w2_s
=
torch
.
rand
(
E
,
K
,
device
=
w2_fp32
.
device
)
*
factor_for_scale
w2_s
=
torch
.
rand
(
E
,
K
,
device
=
w2_fp32
.
device
)
*
factor_for_scale
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
ref_out
=
torch_w8a8_per_column_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
)
with
set_current_vllm_config
(
vllm_config
):
out
=
fused_moe
(
ref_out
=
torch_w8a8_per_column_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
)
a
,
out
=
fused_moe
(
w1
,
a
,
w2
,
w1
,
score
,
w2
,
topk
,
score
,
renormalize
=
False
,
topk
,
use_fp8_w8a8
=
True
,
# using fp8
renormalize
=
False
,
per_channel_quant
=
True
,
use_fp8_w8a8
=
True
,
# using fp8
w1_scale
=
w1_s
,
per_channel_quant
=
True
,
w2_scale
=
w2_s
,
w1_scale
=
w1_s
,
block_shape
=
None
,
# Not using block quantization
w2_scale
=
w2_s
,
)
block_shape
=
None
,
# Not using block quantization
)
# Check results
# Check results
rel_diff
=
(
torch
.
mean
(
rel_diff
=
(
torch
.
mean
(
...
...
tests/kernels/quantization/nvfp4_utils.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
import
torch
from
vllm.scalar_type
import
scalar_types
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
kE2M1ToFloat
=
torch
.
tensor
([
0.
,
0.5
,
1.
,
1.5
,
2.
,
3.
,
4.
,
6.
],
dtype
=
torch
.
float32
)
def
convert_swizzled_to_linear
(
a_sf_swizzled
:
torch
.
Tensor
,
m
,
k
,
block_size
):
m_tiles
=
(
m
+
128
-
1
)
//
128
f
=
block_size
*
4
k_tiles
=
(
k
+
f
-
1
)
//
f
tmp
=
torch
.
reshape
(
a_sf_swizzled
,
(
1
,
m_tiles
,
k_tiles
,
32
,
4
,
4
))
tmp
=
torch
.
permute
(
tmp
,
(
0
,
1
,
4
,
3
,
2
,
5
))
out
=
tmp
.
reshape
(
m_tiles
*
128
,
k_tiles
*
f
//
block_size
)
return
out
[
0
:
m
,
0
:
k
]
def
dequantize_nvfp4_to_dtype
(
tensor_fp4
,
tensor_sf
,
global_scale
,
dtype
,
device
,
block_size
=
16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert
tensor_fp4
.
dtype
==
torch
.
uint8
m
,
packed_k
=
tensor_fp4
.
shape
k
=
packed_k
*
2
tensor_f32
=
break_fp4_bytes
(
tensor_fp4
,
dtype
)
tensor_f32
=
tensor_f32
.
reshape
(
m
,
k
//
block_size
,
block_size
)
tensor_sf
=
tensor_sf
.
view
(
torch
.
float8_e4m3fn
)
tensor_sf
=
convert_swizzled_to_linear
(
tensor_sf
,
m
,
k
,
block_size
)
tensor_sf_dtype
=
tensor_sf
.
to
(
torch
.
float32
)
/
global_scale
# scale the tensor
out
=
(
tensor_f32
*
tensor_sf_dtype
.
unsqueeze
(
-
1
)).
reshape
(
m
,
k
)
return
out
.
to
(
dtype
=
dtype
)
def
break_fp4_bytes
(
a
,
dtype
):
assert
a
.
dtype
==
torch
.
uint8
m
,
n
=
a
.
shape
# Vectorized nibble processing
a_flat
=
a
.
flatten
()
high
=
(
a_flat
&
0xF0
)
>>
4
# Upper nibbles
low
=
a_flat
&
0x0F
# Lower nibbles
# Combine nibbles for batch processing
combined
=
torch
.
stack
((
low
,
high
),
dim
=
1
).
flatten
()
# Vectorized sign and magnitude extraction
signs
=
(
combined
&
0x08
).
to
(
torch
.
bool
)
# Sign bits
abs_vals
=
(
combined
&
0x07
).
to
(
torch
.
long
)
# Magnitude indices
# Device-aware lookup and sign application
kE2M1
=
kE2M1ToFloat
.
to
(
device
=
a
.
device
)
values
=
kE2M1
[
abs_vals
]
*
torch
.
where
(
signs
,
-
1.0
,
1.0
)
# Reshape to final form
return
values
.
reshape
(
m
,
n
*
2
).
to
(
dtype
=
dtype
)
tests/kernels/quantization/test_awq_marlin.py
deleted
100644 → 0
View file @
45d3785c
# SPDX-License-Identifier: Apache-2.0
"""Test AWQ with fused MoE Marlin kernels.
Run `pytest tests/kernels/test_awq_marlin.py`.
"""
import
pytest
import
torch
import
vllm.model_executor.layers.fused_moe
# noqa
from
tests.kernels.utils
import
(
compute_max_diff
,
stack_and_dev
,
torch_moe
,
torch_moe_single
)
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
awq_marlin_quantize
)
from
vllm.scalar_type
import
scalar_types
NUM_EXPERTS
=
[
8
,
64
]
TOP_KS
=
[
2
,
6
]
GROUP_SIZES
=
[
-
1
,
32
,
128
]
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
33
,
64
,
222
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
2048
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
GROUP_SIZES
)
@
pytest
.
mark
.
skipif
(
not
(
ops
.
supports_moe_ops
and
hasattr
(
torch
.
ops
.
_moe_C
,
"marlin_gemm_moe"
)),
reason
=
"Marlin is not supported on this GPU type."
)
def
test_fused_marlin_moe_awq
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
group_size
:
int
,
):
torch
.
manual_seed
(
7
)
num_bits
=
4
quant_type
=
scalar_types
.
uint4
dtype
=
torch
.
float16
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
2
*
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w_ref1_l
=
[]
qweights1_l
=
[]
scales1_l
=
[]
zp1_l
=
[]
for
i
in
range
(
w1
.
shape
[
0
]):
w_ref1
,
qweight1
,
scales1
,
zp1
=
awq_marlin_quantize
(
w1
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
w_ref1_l
.
append
(
w_ref1
)
qweights1_l
.
append
(
qweight1
)
scales1_l
.
append
(
scales1
)
zp1_l
.
append
(
zp1
)
w_ref1
=
stack_and_dev
(
w_ref1_l
)
qweight1
=
stack_and_dev
(
qweights1_l
).
contiguous
()
scales1
=
stack_and_dev
(
scales1_l
)
zp1
=
stack_and_dev
(
zp1_l
)
w_ref2_l
=
[]
qweights2_l
=
[]
scales2_l
=
[]
zp2_l
=
[]
for
i
in
range
(
w2
.
shape
[
0
]):
w_ref2
,
qweight2
,
scales2
,
zp2
=
awq_marlin_quantize
(
w2
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
w_ref2_l
.
append
(
w_ref2
)
qweights2_l
.
append
(
qweight2
)
scales2_l
.
append
(
scales2
)
zp2_l
.
append
(
zp2
)
w_ref2
=
stack_and_dev
(
w_ref2_l
)
qweight2
=
stack_and_dev
(
qweights2_l
).
contiguous
()
scales2
=
stack_and_dev
(
scales2_l
)
zp2
=
stack_and_dev
(
zp2_l
)
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
,
topk_ids
=
fused_topk
(
a
,
score
,
topk
,
False
)
marlin_output
=
torch
.
ops
.
vllm
.
fused_marlin_moe
(
a
,
qweight1
,
qweight2
,
scales1
,
scales2
,
score
,
topk_weights
,
topk_ids
,
w1_zeros
=
zp1
,
w2_zeros
=
zp2
,
num_bits
=
num_bits
,
)
torch_output
=
torch_moe
(
a
,
w_ref1
.
transpose
(
1
,
2
),
w_ref2
.
transpose
(
1
,
2
),
score
,
topk
,
None
)
assert
compute_max_diff
(
marlin_output
,
torch_output
)
<
4e-2
@
pytest
.
mark
.
skip
(
"This test is here for the sake of debugging, "
"don't run it in automated tests."
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
64
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
128
,
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
1024
,
512
])
@
pytest
.
mark
.
parametrize
(
"e"
,
[
8
,
64
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
2
,
6
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
-
1
,
32
,
64
,
128
])
def
test_single_marlin_moe_multiply_awq
(
m
:
int
,
n
:
int
,
k
:
int
,
e
:
int
,
topk
:
int
,
group_size
:
int
,
):
torch
.
manual_seed
(
7
)
num_bits
=
4
quant_type
=
scalar_types
.
uint4
dtype
=
torch
.
float16
a
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w
=
torch
.
randn
((
e
,
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w_ref_l
=
[]
qweights_l
=
[]
scales_l
=
[]
zp_l
=
[]
for
i
in
range
(
w
.
shape
[
0
]):
w_ref
,
qweight
,
scales
,
zp
=
awq_marlin_quantize
(
w
[
i
].
transpose
(
1
,
0
),
quant_type
,
group_size
)
w_ref_l
.
append
(
w_ref
)
qweights_l
.
append
(
qweight
)
scales_l
.
append
(
scales
)
zp_l
.
append
(
zp
)
w_ref
=
stack_and_dev
(
w_ref_l
)
qweight
=
stack_and_dev
(
qweights_l
).
contiguous
()
scales
=
stack_and_dev
(
scales_l
).
contiguous
()
zp
=
stack_and_dev
(
zp_l
).
contiguous
()
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
marlin_output
=
torch
.
ops
.
vllm
.
single_marlin_moe
(
a
,
qweight
,
scales
,
score
,
topk
,
renormalize
=
False
,
w_zeros
=
zp
,
num_bits
=
num_bits
)
torch_output
=
torch_moe_single
(
a
,
w_ref
.
transpose
(
1
,
2
),
score
,
topk
)
assert
compute_max_diff
(
marlin_output
,
torch_output
)
<
1e-2
tests/kernels/quantization/test_block_fp8.py
View file @
7a985548
...
@@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
...
@@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe
import
fused_moe
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
deep_gemm_moe_fp8
)
_valid_deep_gemm_shape
,
deep_gemm_moe_fp8
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_topk
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
)
moe_align_block_size
)
...
@@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
...
@@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0):
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
pytest
.
skip
(
"FP8 Triton requires CUDA 9.0 or higher"
,
allow_module_level
=
True
)
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
# Test configurations
# Test configurations
DTYPES
=
[
torch
.
bfloat16
]
# [torch.half, torch.bfloat16, torch.float32]
DTYPES
=
[
torch
.
bfloat16
]
# [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS
=
[
7
,
83
,
2048
]
NUM_TOKENS
=
[
7
,
83
,
2048
]
...
@@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
...
@@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
# Set the context to avoid lots of warning spam.
# Set the context to avoid lots of warning spam.
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
out
=
fused_moe
(
out
=
fused_moe
(
a
,
a
,
...
@@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
...
@@ -258,6 +261,7 @@ def per_block_cast_to_fp8(
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"M,N,K,block_size,out_dtype,seed"
,
"M,N,K,block_size,out_dtype,seed"
,
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
itertools
.
product
(
M
,
N
,
K
,
BLOCK_SIZE
,
OUT_DTYPES
,
SEEDS
))
@
pytest
.
mark
.
skipif
(
not
dg_available
,
reason
=
"DeepGemm kernels not available."
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_w8a8_block_fp8_deep_gemm_matmul
(
M
,
N
,
K
,
block_size
,
out_dtype
,
seed
):
def
test_w8a8_block_fp8_deep_gemm_matmul
(
M
,
N
,
K
,
block_size
,
out_dtype
,
seed
):
# only aligned sizes
# only aligned sizes
...
@@ -338,7 +342,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
...
@@ -338,7 +342,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
M
,
K
=
a
.
shape
M
,
K
=
a
.
shape
N
=
w2
.
shape
[
-
1
]
N
=
w2
.
shape
[
-
1
]
topk_weight
,
topk_ids
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
topk_weight
,
topk_ids
,
token_expert_indices
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
block_m
=
deep_gemm
.
get_m_alignment_for_contiguous_layout
()
block_m
=
deep_gemm
.
get_m_alignment_for_contiguous_layout
()
...
@@ -380,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
...
@@ -380,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
block_size
=
[
block_m
,
block_m
]
block_size
=
[
block_m
,
block_m
]
dtype
=
torch
.
bfloat16
dtype
=
torch
.
bfloat16
# only aligned sizes
if
topk
>
E
:
if
(
N
%
block_m
!=
0
or
K
%
block_m
!=
0
or
topk
>
E
):
pytest
.
skip
(
f
"Skipping test: topk=
{
topk
}
> E=
{
E
}
"
)
pytest
.
skip
(
f
"Skipping test; bad size m=
{
M
}
, n=
{
N
}
, k=
{
K
}
, topk=
{
topk
}
, E=
{
E
}
"
)
if
N
<=
512
:
pytest
.
skip
(
"Skipping N <= 512 until performance issues solved."
)
vllm_config
=
VllmConfig
()
if
not
_valid_deep_gemm_shape
(
M
,
N
,
K
):
pytest
.
skip
(
f
"Skipping test: invalid size m=
{
M
}
, n=
{
N
}
, k=
{
K
}
"
)
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
...
@@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
...
@@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
ref_out
=
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
ref_out
=
torch_w8a8_block_fp8_moe
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
score
,
topk
,
block_size
)
topk
,
block_size
)
topk_weights
,
topk_ids
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
a
,
score
.
float
(),
topk
,
False
)
out
=
deep_gemm_moe_fp8
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
)
out
=
deep_gemm_moe_fp8
(
a
,
w1
,
w2
,
w1_s
,
w2_s
,
topk_weights
,
topk_ids
)
...
...
tests/kernels/quantization/test_block_int8.py
View file @
7a985548
...
@@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
...
@@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0):
pytest
.
skip
(
"INT8 Triton requires CUDA 7.0 or higher"
,
pytest
.
skip
(
"INT8 Triton requires CUDA 7.0 or higher"
,
allow_module_level
=
True
)
allow_module_level
=
True
)
vllm_config
=
VllmConfig
()
vllm_config
.
scheduler_config
.
max_num_seqs
=
128
vllm_config
.
scheduler_config
.
max_model_len
=
8192
# For test
# For test
def
native_per_token_group_quant_int8
(
x
,
def
native_per_token_group_quant_int8
(
x
,
...
@@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
...
@@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed):
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
score
=
torch
.
randn
((
M
,
E
),
dtype
=
dtype
)
# Set the context to avoid lots of warning spam.
# Set the context to avoid lots of warning spam.
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
out
=
fused_moe
(
out
=
fused_moe
(
a
,
a
,
...
...
tests/kernels/quantization/test_cutlass_scaled_mm.py
View file @
7a985548
...
@@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int,
...
@@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int,
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5e-
2
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
1.
5e-
1
)
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
...
@@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
...
@@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
return
return
if
m
%
a_scale_group_shape
[
0
]
!=
0
or
k
%
a_scale_group_shape
[
1
]
!=
0
:
if
m
%
a_scale_group_shape
[
0
]
!=
0
or
k
%
a_scale_group_shape
[
1
]
!=
0
:
return
return
if
m
%
4
!=
0
and
current_platform
.
has_device_capability
(
100
):
return
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
a_scale_group_shape
,
b_scale_group_shape
,
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
)
use_bias
)
...
...
tests/kernels/quantization/test_ggml.py
View file @
7a985548
...
@@ -36,3 +36,9 @@ def test_ggml_opcheck(quant_type):
...
@@ -36,3 +36,9 @@ def test_ggml_opcheck(quant_type):
opcheck
(
torch
.
ops
.
_C
.
ggml_moe_a8
,
opcheck
(
torch
.
ops
.
_C
.
ggml_moe_a8
,
(
x
,
qweight
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
(
x
,
qweight
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
quant_type
,
qweight
.
shape
[
0
],
1
,
x
.
shape
[
0
]))
quant_type
,
qweight
.
shape
[
0
],
1
,
x
.
shape
[
0
]))
topk_ids
=
torch
.
zeros
((
1
,
1
),
device
=
'cuda'
,
dtype
=
torch
.
int32
)
opcheck
(
torch
.
ops
.
_C
.
ggml_moe_a8_vec
,
(
x
,
qweight
,
topk_ids
,
1
,
quant_type
,
qweight
.
shape
[
0
],
x
.
shape
[
0
]))
tests/kernels/quantization/test_gguf.py
View file @
7a985548
...
@@ -151,20 +151,7 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
...
@@ -151,20 +151,7 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype,
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
512
])
@
pytest
.
mark
.
parametrize
(
"top_k"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"top_k"
,
[
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
QUANT_TYPES
)
"quant_type"
,
[
# k-quants
GGMLQuantizationType
.
Q2_K
,
GGMLQuantizationType
.
Q3_K
,
GGMLQuantizationType
.
Q4_K
,
GGMLQuantizationType
.
Q5_K
,
GGMLQuantizationType
.
Q6_K
,
# standard quants
GGMLQuantizationType
.
Q4_0
,
GGMLQuantizationType
.
Q5_0
,
GGMLQuantizationType
.
Q8_0
,
])
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_moe
(
num_tokens
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
def
test_moe
(
num_tokens
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
quant_type
:
GGMLQuantizationType
,
top_k
:
int
):
quant_type
:
GGMLQuantizationType
,
top_k
:
int
):
...
@@ -174,7 +161,10 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
...
@@ -174,7 +161,10 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype,
x
=
torch
.
rand
((
num_tokens
,
H
),
dtype
=
dtype
,
device
=
"cuda"
)
x
=
torch
.
rand
((
num_tokens
,
H
),
dtype
=
dtype
,
device
=
"cuda"
)
topk_weights
=
torch
.
rand
(
num_tokens
,
top_k
,
device
=
"cuda"
,
dtype
=
dtype
)
topk_weights
=
torch
.
rand
(
num_tokens
,
top_k
,
device
=
"cuda"
,
dtype
=
dtype
)
topk_ids
=
torch
.
randint
(
0
,
E
,
(
num_tokens
,
top_k
),
device
=
"cuda"
)
topk_ids
=
torch
.
randint
(
0
,
E
,
(
num_tokens
,
top_k
),
device
=
"cuda"
,
dtype
=
torch
.
int32
)
tensors
=
get_gguf_MoE_tensors
(
hidden_size
,
quant_type
)
tensors
=
get_gguf_MoE_tensors
(
hidden_size
,
quant_type
)
...
...
tests/kernels/quantization/test_marlin_gemm.py
View file @
7a985548
...
@@ -18,9 +18,12 @@ from vllm.model_executor.layers.quantization.qqq import (
...
@@ -18,9 +18,12 @@ from vllm.model_executor.layers.quantization.qqq import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
,
GPTQ_MARLIN_MIN_THREAD_N
,
MARLIN_SUPPORTED_GROUP_SIZES
,
marlin_make_empty_g_idx
,
MARLIN_SUPPORTED_GROUP_SIZES
,
marlin_make_empty_g_idx
,
marlin_permute_scales
,
query_marlin_supported_quant_types
)
marlin_make_workspace_new
,
marlin_permute_scales
,
query_marlin_supported_quant_types
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp4
import
(
FP4_MARLIN_SUPPORTED_GROUP_SIZES
,
rand_marlin_weight_fp4_like
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
pack
_fp8_to
_int32
)
marlin_quant
_fp8_to
rch
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_test
import
(
MarlinWorkspace
,
awq_marlin_quantize
,
get_weight_perm
,
marlin_quantize
,
MarlinWorkspace
,
awq_marlin_quantize
,
get_weight_perm
,
marlin_quantize
,
marlin_weights
)
marlin_weights
)
...
@@ -73,7 +76,7 @@ def rand_data(shape, dtype=torch.float16):
...
@@ -73,7 +76,7 @@ def rand_data(shape, dtype=torch.float16):
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
(
False
))
query_marlin_supported_quant_types
(
False
,
False
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
...
@@ -138,7 +141,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
...
@@ -138,7 +141,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
(
Fals
e
))
query_marlin_supported_quant_types
(
Tru
e
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
def
test_awq_marlin_repack
(
k_chunk
,
n_chunk
,
quant_type
,
group_size
,
def
test_awq_marlin_repack
(
k_chunk
,
n_chunk
,
quant_type
,
group_size
,
...
@@ -189,9 +192,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
...
@@ -189,9 +192,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
reason
=
"Marlin is not supported on this GPU type."
)
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
())
query_marlin_supported_quant_types
(
False
))
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
"group_size"
,
set
(
MARLIN_SUPPORTED_GROUP_SIZES
+
FP4_MARLIN_SUPPORTED_GROUP_SIZES
))
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"act_order"
,
ACT_ORDER_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
@
pytest
.
mark
.
parametrize
(
"is_k_full"
,
K_FULL_OPTS
)
...
@@ -209,6 +213,7 @@ def test_gptq_marlin_gemm(
...
@@ -209,6 +213,7 @@ def test_gptq_marlin_gemm(
use_fp32_reduce
,
use_fp32_reduce
,
):
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
m_factor
,
n_factor
,
k_factor
=
mnk_factors
has_zp
=
quant_type
in
[
scalar_types
.
uint4
,
scalar_types
.
uint8
]
size_m
=
m_factor
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
size_k
=
k_chunk
*
k_factor
...
@@ -219,39 +224,74 @@ def test_gptq_marlin_gemm(
...
@@ -219,39 +224,74 @@ def test_gptq_marlin_gemm(
return
return
if
group_size
==
size_k
:
if
group_size
==
size_k
:
return
return
if
has_zp
:
return
if
size_k
%
group_size
!=
0
:
return
a_input
=
rand_data
((
size_m
,
size_k
))
a_input
=
rand_data
((
size_m
,
size_k
))
b_weight
=
rand_data
((
size_k
,
size_n
))
b_weight
=
rand_data
((
size_k
,
size_n
))
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
if
quant_type
==
scalar_types
.
float4_e2m1f
:
b_weight
,
quant_type
,
group_size
,
act_order
)
if
group_size
!=
16
or
act_order
:
return
marlin_zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
w_ref
,
marlin_q_w
,
marlin_s
,
marlin_s2
=
rand_marlin_weight_fp4_like
(
b_weight
.
T
,
group_size
)
g_idx
=
None
sort_indices
=
None
marlin_zp
=
None
elif
quant_type
==
scalar_types
.
float8_e4m3fn
:
if
group_size
not
in
[
-
1
,
128
]:
return
if
act_order
:
return
w_ref
,
marlin_q_w
,
marlin_s
=
marlin_quant_fp8_torch
(
b_weight
.
T
,
group_size
)
g_idx
=
None
sort_indices
=
None
marlin_zp
=
None
marlin_s2
=
None
elif
has_zp
:
if
group_size
==
16
:
return
w_ref
,
marlin_q_w
,
marlin_s
,
marlin_zp
=
awq_marlin_quantize
(
b_weight
,
quant_type
,
group_size
)
g_idx
=
None
sort_indices
=
None
marlin_s2
=
None
else
:
if
group_size
==
16
:
return
w_ref
,
marlin_q_w
,
marlin_s
,
g_idx
,
sort_indices
,
_
=
marlin_quantize
(
b_weight
,
quant_type
,
group_size
,
act_order
)
marlin_zp
=
None
marlin_s2
=
None
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
workspace
=
marlin_make_workspace_new
(
w_ref
.
device
)
GPTQ_MARLIN_MAX_PARALLEL
)
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
opcheck
(
torch
.
ops
.
_C
.
gptq_marlin_gemm
,
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_
zp
,
g_idx
,
sort_indices
,
(
a_input
,
None
,
marlin_q_w
,
marlin_s
,
marlin_
s2
,
marlin_zp
,
g_idx
,
workspace
.
scratch
,
quant_type
.
id
,
a_input
.
shape
[
0
],
sort_indices
,
workspace
,
quant_type
.
id
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
False
,
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
,
use_atomic_add
,
use_atomic_add
,
use_fp32_reduce
,
False
),
use_fp32_reduce
,
False
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
)
output
=
ops
.
gptq_marlin_gemm
(
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
a_input
,
None
,
marlin_q_w
,
marlin_q_w
,
marlin_s
,
marlin_s
,
marlin_s2
,
marlin_zp
,
marlin_zp
,
g_idx
,
g_idx
,
sort_indices
,
sort_indices
,
workspace
.
scratch
,
workspace
,
quant_type
,
quant_type
,
a_input
.
shape
[
0
],
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
is_k_full
,
is_k_full
=
is_k_full
,
has_zp
=
False
,
use_atomic_add
=
use_atomic_add
,
use_atomic_add
=
use_atomic_add
,
use_fp32_reduce
=
use_fp32_reduce
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
,
is_zp_float
=
False
,
...
@@ -326,143 +366,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
...
@@ -326,143 +366,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
assert
max_diff
<
0.04
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"num_bits"
,
[
8
])
@
pytest
.
mark
.
parametrize
(
"group_size"
,
[
-
1
])
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
def
test_fp8_marlin_gemm
(
k_chunk
,
n_chunk
,
num_bits
,
group_size
,
mnk_factors
,
dtype
,
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
a_input
=
rand_data
((
size_m
,
size_k
),
dtype
=
dtype
)
b_weight
=
rand_data
((
size_k
,
size_n
),
dtype
=
dtype
)
# WEIGHTS
fp8_weight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
b_weight
,
scale
=
None
)
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight
=
pack_fp8_to_int32
(
fp8_weight
)
# Repack weights to marlin format
marlin_qweight
=
ops
.
gptq_marlin_repack
(
b_q_weight
=
packed_gptq_qweight
,
perm
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
"cuda"
),
size_k
=
size_k
,
size_n
=
size_n
,
num_bits
=
8
,
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales
=
weight_scale
.
repeat
(
1
,
size_n
).
to
(
a_input
.
dtype
).
to
(
"cuda"
)
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
size_k
,
size_n
=
size_n
,
group_size
=-
1
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
opcheck
(
torch
.
ops
.
_C
.
fp8_marlin_gemm
,
(
a_input
,
marlin_qweight
,
marlin_scales
,
workspace
.
scratch
,
num_bits
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
]))
output
=
ops
.
fp8_marlin_gemm
(
a
=
a_input
,
b_q_weight
=
marlin_qweight
,
b_scales
=
marlin_scales
,
workspace
=
workspace
.
scratch
,
num_bits
=
num_bits
,
size_m
=
a_input
.
shape
[
0
],
size_n
=
b_weight
.
shape
[
1
],
size_k
=
a_input
.
shape
[
1
],
)
output_ref
=
torch
.
matmul
(
a_input
,
b_weight
)
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"n_chunk"
,
MARLIN_N_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"quant_type"
,
query_marlin_supported_quant_types
(
True
))
@
pytest
.
mark
.
parametrize
(
"group_size"
,
MARLIN_SUPPORTED_GROUP_SIZES
)
@
pytest
.
mark
.
parametrize
(
"mnk_factors"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"use_fp32_reduce"
,
USE_FP32_REDUCE_OPTS
)
def
test_awq_marlin_gemm
(
k_chunk
,
n_chunk
,
quant_type
,
group_size
,
mnk_factors
,
use_fp32_reduce
,
):
m_factor
,
n_factor
,
k_factor
=
mnk_factors
size_m
=
m_factor
size_k
=
k_chunk
*
k_factor
size_n
=
n_chunk
*
n_factor
a_input
=
rand_data
((
size_m
,
size_k
))
b_weight
=
rand_data
((
size_k
,
size_n
))
w_ref
,
marlin_q_w
,
marlin_s
,
marlin_zp
=
awq_marlin_quantize
(
b_weight
,
quant_type
,
group_size
)
g_idx
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
marlin_q_w
.
device
)
sort_indices
=
torch
.
empty
(
0
,
dtype
=
torch
.
int
,
device
=
marlin_q_w
.
device
)
is_k_full
=
True
has_zp
=
True
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
GPTQ_MARLIN_MAX_PARALLEL
)
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
marlin_q_w
,
marlin_s
,
marlin_zp
,
g_idx
,
sort_indices
,
workspace
.
scratch
,
quant_type
,
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
is_k_full
,
has_zp
=
has_zp
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
False
,
)
output_ref
=
torch
.
matmul
(
a_input
,
w_ref
)
torch
.
cuda
.
synchronize
()
max_diff
=
compute_max_diff
(
output
,
output_ref
)
assert
max_diff
<
0.04
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
reason
=
"Marlin is not supported on this GPU type."
)
reason
=
"Marlin is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
@
pytest
.
mark
.
parametrize
(
"k_chunk"
,
MARLIN_K_CHUNKS
)
...
@@ -508,23 +411,23 @@ def test_hqq_marlin_gemm(
...
@@ -508,23 +411,23 @@ def test_hqq_marlin_gemm(
g_idx
=
marlin_make_empty_g_idx
(
dev
)
g_idx
=
marlin_make_empty_g_idx
(
dev
)
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
dev
)
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
dev
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
workspace
=
marlin_make_workspace_new
(
b_weight
.
device
)
GPTQ_MARLIN_MAX_PARALLEL
)
output
=
ops
.
gptq_marlin_gemm
(
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
a_input
,
None
,
marlin_w_q
,
marlin_w_q
,
marlin_s
,
marlin_s
,
None
,
marlin_zp
,
marlin_zp
,
g_idx
,
g_idx
,
g_idx_sort_indices
,
g_idx_sort_indices
,
workspace
.
scratch
,
workspace
,
quant_type
,
quant_type
,
a_input
.
shape
[
0
],
a_input
.
shape
[
0
],
b_weight
.
shape
[
0
],
b_weight
.
shape
[
0
],
a_input
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
True
,
is_k_full
=
True
,
has_zp
=
True
,
use_fp32_reduce
=
use_fp32_reduce
,
use_fp32_reduce
=
use_fp32_reduce
,
is_zp_float
=
True
,
is_zp_float
=
True
,
)
)
...
@@ -621,23 +524,23 @@ def test_marlin_gemm_subset_input():
...
@@ -621,23 +524,23 @@ def test_marlin_gemm_subset_input():
b_weight
,
quant_type
,
group_size
,
False
)
b_weight
,
quant_type
,
group_size
,
False
)
marlin_zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
marlin_zp
=
marlin_make_empty_g_idx
(
marlin_s
.
device
)
workspace
=
MarlinWorkspace
(
size_n
,
GPTQ_MARLIN_MIN_THREAD_N
,
workspace
=
marlin_make_workspace_new
(
a_input
.
device
)
GPTQ_MARLIN_MAX_PARALLEL
)
output
=
ops
.
gptq_marlin_gemm
(
output
=
ops
.
gptq_marlin_gemm
(
a_input
,
a_input
,
None
,
marlin_q_w
,
marlin_q_w
,
marlin_s
,
marlin_s
,
None
,
marlin_zp
,
marlin_zp
,
g_idx
,
g_idx
,
sort_indices
,
sort_indices
,
workspace
.
scratch
,
workspace
,
quant_type
,
quant_type
,
a_input
.
shape
[
0
],
a_input
.
shape
[
0
],
b_weight
.
shape
[
1
],
b_weight
.
shape
[
1
],
a_input
.
shape
[
1
],
a_input
.
shape
[
1
],
is_k_full
=
True
,
is_k_full
=
True
,
has_zp
=
False
,
use_atomic_add
=
False
,
use_atomic_add
=
False
,
use_fp32_reduce
=
True
,
use_fp32_reduce
=
True
,
is_zp_float
=
False
,
is_zp_float
=
False
,
...
...
tests/kernels/quantization/test_nvfp4_quant.py
View file @
7a985548
...
@@ -17,7 +17,7 @@ PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48),
...
@@ -17,7 +17,7 @@ PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48),
SEEDS
=
[
42
]
SEEDS
=
[
42
]
CUDA_DEVICES
=
[
'cuda:0'
]
CUDA_DEVICES
=
[
'cuda:0'
]
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
n
.
max
()
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1f
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
# E2M1 to float
# E2M1 to float
...
...
tests/kernels/quantization/test_nvfp4_scaled_mm.py
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
pytest
import
torch
import
torch
from
nvfp4_utils
import
(
FLOAT4_E2M1_MAX
,
FLOAT8_E4M3_MAX
,
dequantize_nvfp4_to_dtype
)
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
if
not
current_platform
.
has_device_capability
(
100
):
if
not
current_platform
.
has_device_capability
(
100
):
pytest
.
skip
(
reason
=
"Nvfp4 Requires compute capability of 10 or above."
,
pytest
.
skip
(
reason
=
"Nvfp4 Requires compute capability of 10 or above."
,
...
@@ -19,95 +20,24 @@ SHAPES.extend(PAD_SHAPES)
...
@@ -19,95 +20,24 @@ SHAPES.extend(PAD_SHAPES)
SEEDS
=
[
42
]
SEEDS
=
[
42
]
CUDA_DEVICES
=
[
'cuda:0'
]
CUDA_DEVICES
=
[
'cuda:0'
]
FLOAT4_E2M1_MAX
=
scalar_types
.
float4_e2m1fn
.
max
()
FLOAT8_E4M3_MAX
=
torch
.
finfo
(
torch
.
float8_e4m3fn
).
max
kE2M1ToFloatArray
=
[
0.
,
0.5
,
1.
,
1.5
,
2.
,
3.
,
4.
,
6.
,
]
def
e2m1_to_fp32
(
int4_value
):
signBit
=
(
int4_value
&
0x8
)
int4_absValue
=
int4_value
&
0x7
float_result
=
kE2M1ToFloatArray
[
int4_absValue
]
if
(
signBit
):
float_result
=
-
float_result
return
float_result
def
break_fp4_bytes
(
a
,
dtype
):
assert
(
a
.
dtype
==
torch
.
uint8
)
m
,
n
=
a
.
shape
a
=
a
.
flatten
()
# Get upper 4 bits
highHalfByte
=
(
a
&
0xF0
)
>>
4
# Get lower 4 bits
lowHalfByte
=
a
&
0x0F
fH
=
torch
.
tensor
([
e2m1_to_fp32
(
x
)
for
x
in
highHalfByte
]).
to
(
a
.
device
)
fL
=
torch
.
tensor
([
e2m1_to_fp32
(
x
)
for
x
in
lowHalfByte
]).
to
(
a
.
device
)
# [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
out
=
torch
.
stack
((
fL
,
fH
),
dim
=-
1
).
reshape
(
m
,
n
*
2
)
return
out
def
convert_swizzled_to_linear
(
a_sf_swizzled
:
torch
.
Tensor
,
m
,
k
,
block_size
):
sf_m
,
sf_k
=
a_sf_swizzled
.
shape
m_tiles
=
(
m
+
128
-
1
)
//
128
f
=
block_size
*
4
k_tiles
=
(
k
+
f
-
1
)
//
f
tmp
=
torch
.
reshape
(
a_sf_swizzled
,
(
1
,
m_tiles
,
k_tiles
,
32
,
4
,
4
))
tmp
=
torch
.
permute
(
tmp
,
(
0
,
1
,
4
,
3
,
2
,
5
))
out
=
tmp
.
reshape
(
m_tiles
*
128
,
k_tiles
*
f
//
block_size
)
return
out
[
0
:
m
,
0
:
k
]
def
dequantize_to_dtype
(
tensor_fp4
,
tensor_sf
,
global_scale
,
dtype
,
device
,
block_size
=
16
):
"""Dequantize the fp4 tensor back to high precision."""
# Two fp4 values are packed into one uint8.
assert
tensor_fp4
.
dtype
==
torch
.
uint8
m
,
packed_k
=
tensor_fp4
.
shape
k
=
packed_k
*
2
tensor_f32
=
break_fp4_bytes
(
tensor_fp4
,
dtype
)
tensor_f32
=
tensor_f32
.
reshape
(
m
,
k
//
block_size
,
block_size
)
tensor_sf
=
tensor_sf
.
view
(
torch
.
float8_e4m3fn
)
tensor_sf
=
convert_swizzled_to_linear
(
tensor_sf
,
m
,
k
,
block_size
)
tensor_sf_dtype
=
tensor_sf
.
to
(
torch
.
float32
)
/
global_scale
# scale the tensor
out
=
(
tensor_f32
*
tensor_sf_dtype
.
unsqueeze
(
-
1
)).
reshape
(
m
,
k
)
return
out
def
get_ref_results
(
a_fp4
,
b_fp4
,
a_sf
,
b_sf
,
a_global_scale
,
b_global_scale
,
def
get_ref_results
(
a_fp4
,
b_fp4
,
a_sf
,
b_sf
,
a_global_scale
,
b_global_scale
,
m
,
n
,
dtype
,
block_size
,
device
):
m
,
n
,
dtype
,
block_size
,
device
):
_
,
m_k
=
a_fp4
.
shape
_
,
m_k
=
a_fp4
.
shape
_
,
n_k
=
b_fp4
.
shape
_
,
n_k
=
b_fp4
.
shape
assert
(
m_k
==
n_k
)
assert
(
m_k
==
n_k
)
a_in_dtype
=
dequantize_to_dtype
(
a_fp4
,
a_in_dtype
=
dequantize_
nvfp4_
to_dtype
(
a_fp4
,
a_sf
,
a_sf
,
a_global_scale
,
a_global_scale
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
block_size
=
block_size
)
block_size
=
block_size
)
b_in_dtype
=
dequantize_to_dtype
(
b_fp4
,
b_in_dtype
=
dequantize_
nvfp4_
to_dtype
(
b_fp4
,
b_sf
,
b_sf
,
b_global_scale
,
b_global_scale
,
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
,
device
=
device
,
block_size
=
block_size
)
block_size
=
block_size
)
return
torch
.
matmul
(
a_in_dtype
,
b_in_dtype
.
t
())
return
torch
.
matmul
(
a_in_dtype
,
b_in_dtype
.
t
())
...
...
tests/kernels/quantization/test_rocm_skinny_gemms.py
View file @
7a985548
...
@@ -8,7 +8,7 @@ from vllm.platforms import current_platform
...
@@ -8,7 +8,7 @@ from vllm.platforms import current_platform
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float16
]
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float16
]
M
=
[
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
4096
,
8192
]
M
=
[
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
4096
,
8192
]
K
=
[
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
# k % 8 == 0
K
=
[
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
6144
,
8192
]
# k % 8 == 0
N
=
[
1
,
2
,
3
,
4
]
N
=
[
1
,
2
,
3
,
4
]
SEEDS
=
[
0
]
SEEDS
=
[
0
]
...
@@ -58,8 +58,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
...
@@ -58,8 +58,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed):
@
pytest
.
mark
.
parametrize
(
"m"
,
M
+
[
28672
])
# m >= 16
@
pytest
.
mark
.
parametrize
(
"m"
,
M
+
[
28672
])
# m >= 16
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_rocm
(),
@
pytest
.
mark
.
skipif
(
reason
=
"only test for rocm"
)
not
(
current_platform
.
is_rocm
()
and
current_platform
.
supports_fp8
()),
reason
=
"only test for rocm fp8"
)
def
test_rocm_wvsplitk_fp8_kernel
(
n
,
k
,
m
,
dtype
,
seed
):
def
test_rocm_wvsplitk_fp8_kernel
(
n
,
k
,
m
,
dtype
,
seed
):
torch
.
manual_seed
(
seed
)
torch
.
manual_seed
(
seed
)
...
...
tests/kernels/test_fused_quant_activation.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
import
pytest
import
torch
import
vllm._custom_ops
as
ops
from
tests.kernels.utils
import
opcheck
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.platforms
import
current_platform
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float16
]
QUANT_DTYPES
=
[
current_platform
.
fp8_dtype
()]
NUM_TOKENS
=
[
1
,
17
,
86
,
1234
,
3045
]
# Arbitrary values for testing
HIDDEN_SIZES
=
[
16
,
48
,
128
,
1562
,
4096
]
# Arbitrary values for testing
SEEDS
=
[
0
]
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
def
ref_impl
(
silu_and_mul
:
SiluAndMul
,
x
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
silu_and_mul_out
=
silu_and_mul
.
forward_native
(
x
)
out
,
scales
=
ops
.
scaled_fp8_quant
(
silu_and_mul_out
,
scale
)
return
out
def
ops_impl
(
x
:
torch
.
Tensor
,
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
out_shape
=
(
x
.
shape
[
0
],
x
.
shape
[
1
]
//
2
)
out
=
torch
.
empty
(
out_shape
,
dtype
=
current_platform
.
fp8_dtype
(),
device
=
x
.
device
)
torch
.
ops
.
_C
.
silu_and_mul_quant
(
out
,
x
,
scale
)
return
out
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"quant_dtype"
,
QUANT_DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_silu_and_mul
(
num_tokens
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
quant_dtype
:
torch
.
dtype
,
seed
:
int
,
device
:
str
,
)
->
None
:
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
set_default_device
(
device
)
layer
=
SiluAndMul
()
# Make inputs
scale
=
(
torch
.
randn
((
1
),
device
=
device
,
dtype
=
torch
.
float32
))
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
ref_out
=
ref_impl
(
layer
,
x
,
scale
)
ops_out
=
ops_impl
(
x
,
scale
)
assert
ref_out
.
dtype
==
quant_dtype
assert
ops_out
.
dtype
==
quant_dtype
assert
ref_out
.
shape
==
ops_out
.
shape
assert
torch
.
allclose
(
ref_out
.
to
(
dtype
=
torch
.
float32
),
ops_out
.
to
(
dtype
=
torch
.
float32
))
opcheck
(
torch
.
ops
.
_C
.
silu_and_mul_quant
,
(
ops_out
,
x
,
scale
))
tests/kv_transfer/test_disagg.py
View file @
7a985548
...
@@ -14,8 +14,8 @@ import torch
...
@@ -14,8 +14,8 @@ import torch
# Fixture to set up environment variables and teardown servers after tests
# Fixture to set up environment variables and teardown servers after tests
@
pytest
.
fixture
(
scope
=
"module"
,
autouse
=
True
)
@
pytest
.
fixture
(
scope
=
"module"
,
autouse
=
True
)
def
setup_servers
():
def
setup_servers
():
if
torch
.
cuda
.
device_count
()
<
4
:
if
torch
.
cuda
.
device_count
()
<
2
:
pytest
.
skip
(
"Skipping test: fewer than
4
GPUs available"
)
pytest
.
skip
(
"Skipping test: fewer than
2
GPUs available"
)
# Set up environment variables
# Set up environment variables
VLLM_HOST_IP
=
subprocess
.
check_output
(
"hostname -I | awk '{print $1}'"
,
VLLM_HOST_IP
=
subprocess
.
check_output
(
"hostname -I | awk '{print $1}'"
,
...
...
tests/lora/conftest.py
View file @
7a985548
...
@@ -47,7 +47,7 @@ def dist_init():
...
@@ -47,7 +47,7 @@ def dist_init():
temp_file
=
tempfile
.
mkstemp
()[
1
]
temp_file
=
tempfile
.
mkstemp
()[
1
]
backend
=
"nccl"
backend
=
"nccl"
if
current_platform
.
is_cpu
():
if
current_platform
.
is_cpu
()
or
current_platform
.
is_tpu
()
:
backend
=
"gloo"
backend
=
"gloo"
init_distributed_environment
(
world_size
=
1
,
init_distributed_environment
(
world_size
=
1
,
...
@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
...
@@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module:
return
model
return
model
@
pytest
.
fixture
(
scope
=
"session"
)
def
llama_2_7b_base_huggingface_id
():
# used as a base model for testing with sql lora adapter
return
"meta-llama/Llama-2-7b-hf"
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
sql_lora_huggingface_id
():
def
sql_lora_huggingface_id
():
# huggingface repo id is used to test lora runtime downloading.
# huggingface repo id is used to test lora runtime downloading.
...
@@ -198,6 +204,12 @@ def qwen2vl_lora_files():
...
@@ -198,6 +204,12 @@ def qwen2vl_lora_files():
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen2-vl-lora-pokemon"
)
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen2-vl-lora-pokemon"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
qwen25vl_base_huggingface_id
():
# used as a base model for testing with qwen25vl lora adapter
return
"Qwen/Qwen2.5-VL-3B-Instruct"
@
pytest
.
fixture
(
scope
=
"session"
)
@
pytest
.
fixture
(
scope
=
"session"
)
def
qwen25vl_lora_files
():
def
qwen25vl_lora_files
():
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen25-vl-lora-pokemon"
)
return
snapshot_download
(
repo_id
=
"jeeejeee/qwen25-vl-lora-pokemon"
)
...
@@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch):
...
@@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch):
@
pytest
.
fixture
@
pytest
.
fixture
def
reset_default_device
():
def
reset_default_device
():
"""
"""
Some tests, such as `test_punica_ops.py`, explicitly set the
Some tests, such as `test_punica_ops.py`, explicitly set the
default device, which can affect subsequent tests. Adding this fixture
default device, which can affect subsequent tests. Adding this fixture
helps avoid this problem.
helps avoid this problem.
"""
"""
original_device
=
torch
.
get_default_device
()
original_device
=
torch
.
get_default_device
()
...
...
tests/lora/test_lora_allowed_token_ids.py
0 → 100644
View file @
7a985548
# SPDX-License-Identifier: Apache-2.0
import
pytest
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoRAConfig
,
ModelConfig
,
VllmConfig
)
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.v1.engine.processor
import
Processor
def
test_allowed_token_ids_with_lora_vocab
(
llama_2_7b_base_huggingface_id
,
sql_lora_files
):
"""
Test that we properly resolve the range of allowed token ids for lora
adapters that define additional tokens.
"""
# Setup a base model compatible with the sql_lora_files adapter and
# a known number of tokens in the base model.
model_config
=
ModelConfig
(
model
=
llama_2_7b_base_huggingface_id
,
tokenizer
=
llama_2_7b_base_huggingface_id
,
tokenizer_mode
=
"auto"
,
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
CacheConfig
(),
device_config
=
DeviceConfig
(),
lora_config
=
LoRAConfig
(),
)
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
lora_config
=
vllm_config
.
lora_config
)
processor
=
Processor
(
vllm_config
,
tokenizer
)
lora_request
=
LoRARequest
(
"1"
,
1
,
str
(
sql_lora_files
))
request_id
=
"1"
prompt
=
"a prompt"
# tokens added in the lora adapter should not raise an error
lora_token_ids
=
[
32000
,
32001
,
32002
,
32003
]
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
lora_token_ids
),
lora_request
=
lora_request
)
# tokens in the base model should not raise an error
base_token_ids
=
[
1000
,
1001
,
1002
,
1003
]
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
base_token_ids
),
lora_request
=
lora_request
)
# tokens not in the lora adapter should raise an error
invalid_token_ids
=
[
35000
,
35001
,
35002
,
35003
]
with
pytest
.
raises
(
ValueError
):
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
invalid_token_ids
),
lora_request
=
lora_request
)
# tokens in the lora adapter with no lora request should raise an error
with
pytest
.
raises
(
ValueError
):
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
lora_token_ids
),
)
def
test_allowed_token_ids_with_lora_adapter_no_vocab
(
qwen25vl_base_huggingface_id
,
qwen25vl_lora_files
):
"""
Test that we properly resolve the range of allowed token ids for lora
adapters that do not define additional tokens.
"""
# Setup a base model compatible with the qwen25vl_lora_files adapter and
# a known number of tokens in the base model.
model_config
=
ModelConfig
(
model
=
qwen25vl_base_huggingface_id
,
tokenizer
=
qwen25vl_base_huggingface_id
,
tokenizer_mode
=
"auto"
,
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
CacheConfig
(),
device_config
=
DeviceConfig
(),
lora_config
=
LoRAConfig
(),
)
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
scheduler_config
=
vllm_config
.
scheduler_config
,
lora_config
=
vllm_config
.
lora_config
)
processor
=
Processor
(
vllm_config
,
tokenizer
)
lora_request
=
LoRARequest
(
"1"
,
1
,
str
(
qwen25vl_lora_files
))
request_id
=
"1"
prompt
=
"a prompt"
# tokens in the base model should not raise an error
base_token_ids
=
[
1000
,
1001
,
1002
,
1003
]
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
base_token_ids
),
lora_request
=
lora_request
)
# tokens in the base model with no lora request should not raise an error
base_token_ids
=
[
1000
,
1001
,
1002
,
1003
]
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
base_token_ids
),
)
# tokens not in the base model should raise an error
invalid_token_ids
=
[
200000
,
200001
,
200002
,
200003
]
with
pytest
.
raises
(
ValueError
):
processor
.
process_inputs
(
request_id
,
prompt
,
params
=
SamplingParams
(
allowed_token_ids
=
invalid_token_ids
),
lora_request
=
lora_request
)
tests/lora/test_lora_huggingface.py
View file @
7a985548
...
@@ -30,7 +30,7 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
...
@@ -30,7 +30,7 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
lora_path
=
get_adapter_absolute_path
(
lora_name
)
lora_path
=
get_adapter_absolute_path
(
lora_name
)
# lora loading should work for either absolute path and hugg
g
ingface id.
# lora loading should work for either absolute path and huggingface id.
peft_helper
=
PEFTHelper
.
from_local_dir
(
lora_path
,
4096
)
peft_helper
=
PEFTHelper
.
from_local_dir
(
lora_path
,
4096
)
lora_model
=
LoRAModel
.
from_local_checkpoint
(
lora_model
=
LoRAModel
.
from_local_checkpoint
(
lora_path
,
lora_path
,
...
...
Prev
1
…
14
15
16
17
18
19
20
21
22
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment