Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
af7b564d
Commit
af7b564d
authored
Sep 02, 2025
by
zhuwenwen
Browse files
[fix]fix tests of kernels
parent
1faa2c78
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
140 additions
and
124 deletions
+140
-124
tests/kernels/attention/test_attention.py
tests/kernels/attention/test_attention.py
+3
-0
tests/kernels/attention/test_mha_attn.py
tests/kernels/attention/test_mha_attn.py
+1
-1
tests/kernels/attention/test_triton_unified_attention.py
tests/kernels/attention/test_triton_unified_attention.py
+1
-1
tests/kernels/mamba/test_mamba_ssm_ssd.py
tests/kernels/mamba/test_mamba_ssm_ssd.py
+90
-90
tests/kernels/moe/test_batched_moe.py
tests/kernels/moe/test_batched_moe.py
+2
-2
tests/kernels/moe/test_moe.py
tests/kernels/moe/test_moe.py
+13
-4
tests/kernels/moe/untest_block_fp8.py
tests/kernels/moe/untest_block_fp8.py
+0
-0
tests/kernels/moe/untest_moe_permute_unpermute.py
tests/kernels/moe/untest_moe_permute_unpermute.py
+0
-0
tests/kernels/moe/untest_nvfp4_moe.py
tests/kernels/moe/untest_nvfp4_moe.py
+0
-0
tests/kernels/moe/untest_pplx_cutlass_moe.py
tests/kernels/moe/untest_pplx_cutlass_moe.py
+0
-0
tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py
tests/kernels/moe/untest_silu_mul_fp8_quant_deep_gemm.py
+0
-0
tests/kernels/moe/untest_triton_moe_ptpc_fp8.py
tests/kernels/moe/untest_triton_moe_ptpc_fp8.py
+0
-0
tests/kernels/quantization/__init__.py
tests/kernels/quantization/__init__.py
+0
-0
tests/kernels/quantization/test_gguf.py
tests/kernels/quantization/test_gguf.py
+1
-1
tests/kernels/quantization/test_int8_quant.py
tests/kernels/quantization/test_int8_quant.py
+2
-2
tests/kernels/quantization/test_triton_scaled_mm.py
tests/kernels/quantization/test_triton_scaled_mm.py
+3
-1
tests/kernels/quantization/untest_rocm_skinny_gemms.py
tests/kernels/quantization/untest_rocm_skinny_gemms.py
+0
-0
tests/kernels/test_flex_attention.py
tests/kernels/test_flex_attention.py
+3
-1
tests/kernels/untest_fused_quant_activation.py
tests/kernels/untest_fused_quant_activation.py
+1
-1
tests/kernels/untest_triton_flash_attention.py
tests/kernels/untest_triton_flash_attention.py
+20
-20
No files found.
tests/kernels/attention/test_attention.py
View file @
af7b564d
...
...
@@ -20,6 +20,9 @@ if not current_platform.is_rocm():
from
vllm.attention.backends.xformers
import
_make_alibi_bias
if
current_platform
.
is_rocm
():
from
flash_attn
import
vllm_flash_attn_with_kvcache
FLOAT32_BYTES
=
torch
.
finfo
(
torch
.
float
).
bits
//
8
# This will change depending on the compute capability.
# - 512 as a buffer
...
...
tests/kernels/attention/test_mha_attn.py
View file @
af7b564d
...
...
@@ -25,7 +25,7 @@ def clear_cache():
_cached_get_attn_backend
.
cache_clear
()
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
]
)
@
pytest
.
mark
.
parametrize
(
"device"
,
[
"cpu"
,
"hip"
,
"cuda"
]
if
not
current_platform
.
is_rocm
()
else
[
"cpu"
,
"hip"
])
def
test_mha_attn_platform
(
device
:
str
):
"""
Test the attention selector between different platform and device.
...
...
tests/kernels/attention/test_triton_unified_attention.py
View file @
af7b564d
...
...
@@ -15,7 +15,7 @@ BLOCK_SIZES = [16]
DTYPES
=
[
torch
.
bfloat16
]
QDTYPES
=
[
None
,
torch
.
float8_e4m3fn
]
if
not
current_platform
.
is_rocm
()
else
[
None
,
torch
.
float8_e4m3fnuz
None
#
, torch.float8_e4m3fnuz
]
# one value large enough to test overflow in index calculation.
# one value small enough to test the schema op check
...
...
tests/kernels/mamba/test_mamba_ssm_ssd.py
View file @
af7b564d
...
...
@@ -234,93 +234,93 @@ def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
rtol
=
rtol
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"n_heads"
,
[
4
,
8
,
13
])
@
pytest
.
mark
.
parametrize
(
"d_head"
,
[
5
,
16
,
21
,
32
])
@
pytest
.
mark
.
parametrize
(
"seq_len_chunk_size_cases"
,
[
# small-ish chunk_size (8)
(
64
,
8
,
2
,
[(
64
,
32
),
(
64
,
32
)]),
(
64
,
8
,
2
,
[(
32
,
32
),
(
32
,
32
),
(
32
,
32
)]),
(
64
,
8
,
2
,
[(
8
,
8
),
(
8
,
8
),
(
8
,
8
)]),
# chunk size boundary
(
64
,
8
,
2
,
[(
4
,
4
),
(
4
,
4
),
(
4
,
4
),
(
4
,
4
)]),
# chunk_size larger than cont batches
(
64
,
8
,
5
,
[
(
64
,
32
,
16
,
8
,
8
),
(
8
,
16
,
32
,
16
,
8
),
(
8
,
8
,
16
,
32
,
16
),
]),
# mode examples with varied lengths
# large-ish chunk_size (256)
(
64
,
256
,
1
,
[(
5
,
),
(
1
,
),
(
1
,
),
(
1
,
)]),
# irregular sizes with small sequences
(
64
,
256
,
2
,
[(
5
,
30
),
(
1
,
2
),
(
1
,
2
),
(
1
,
2
)]),
# irregular sizes with small sequences
# we also need to test some large seqlen
# to catch errors with init states decay
(
768
,
128
,
2
,
[(
138
,
225
),
(
138
,
225
)]),
])
def
test_mamba_chunk_scan_cont_batch
(
d_head
,
n_heads
,
seq_len_chunk_size_cases
,
itype
):
# this test with multiple examples in a continuous batch
# (i.e. chunked prefill)
seqlen
,
chunk_size
,
num_examples
,
cases
=
seq_len_chunk_size_cases
# This test can have larger error for longer sequences
if
seqlen
>
256
:
atol
,
rtol
=
1e-2
,
5e-3
else
:
atol
,
rtol
=
5e-3
,
5e-3
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken
:
dict
=
{}
# map: eg -> pointer to last taken sample
exhausted
:
dict
=
{}
# map: eg -> boolean indicating example is exhausted
states
=
None
for
Y_min
,
cu_seqlens
,
seq_idx
,
(
A
,
dt
,
X
,
B
,
C
)
in
generate_continuous_batched_examples
(
cases
,
num_examples
,
seqlen
,
last_taken
,
exhausted
,
n_heads
,
d_head
,
itype
):
chunk_indices
,
chunk_offsets
=
\
_query_start_loc_to_chunk_indices_offsets
(
cu_seqlens
,
chunk_size
,
cu_seqlens
[
-
1
])
Y
=
torch
.
empty_like
(
X
)
new_states
=
mamba_chunk_scan_combined
(
X
,
dt
,
A
,
B
,
C
,
chunk_size
,
D
=
None
,
cu_seqlens
=
cu_seqlens
,
seq_idx
=
seq_idx
,
chunk_indices
=
chunk_indices
,
chunk_offsets
=
chunk_offsets
,
return_varlen_states
=
True
,
initial_states
=
states
,
out
=
Y
,
)
# just test the last in sequence
for
i
in
range
(
num_examples
):
# just test one dim and dstate
Y_eg
=
Y
[
0
,
cu_seqlens
[
i
]:
cu_seqlens
[
i
+
1
],
0
,
0
]
Y_min_eg
=
Y_min
[
i
][:,
0
,
0
]
torch
.
testing
.
assert_close
(
Y_eg
,
Y_min_eg
,
atol
=
atol
,
rtol
=
rtol
)
# update states
states
=
new_states
for
i
,
clear
in
exhausted
.
items
():
if
clear
:
states
[
i
].
fill_
(
0.
)
exhausted
[
i
]
=
False
#
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
#
@pytest.mark.parametrize("n_heads", [4, 8, 13])
#
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
#
@pytest.mark.parametrize(
#
"seq_len_chunk_size_cases",
#
[
#
# small-ish chunk_size (8)
#
(64, 8, 2, [(64, 32), (64, 32)]),
#
(64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
#
(64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary
#
(64, 8, 2, [(4, 4), (4, 4), (4, 4),
#
(4, 4)]), # chunk_size larger than cont batches
#
(64, 8, 5, [
#
(64, 32, 16, 8, 8),
#
(8, 16, 32, 16, 8),
#
(8, 8, 16, 32, 16),
#
]), # mode examples with varied lengths
#
# large-ish chunk_size (256)
#
(64, 256, 1, [(5, ), (1, ), (1, ),
#
(1, )]), # irregular sizes with small sequences
#
(64, 256, 2, [(5, 30), (1, 2), (1, 2),
#
(1, 2)]), # irregular sizes with small sequences
#
# we also need to test some large seqlen
#
# to catch errors with init states decay
#
(768, 128, 2, [(138, 225), (138, 225)]),
#
])
#
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
#
itype):
#
# this test with multiple examples in a continuous batch
#
# (i.e. chunked prefill)
#
seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases
#
# This test can have larger error for longer sequences
#
if seqlen > 256:
#
atol, rtol = 1e-2, 5e-3
#
else:
#
atol, rtol = 5e-3, 5e-3
#
# hold state during the cutting process so we know if an
#
# example has been exhausted and needs to cycle
#
last_taken: dict = {} # map: eg -> pointer to last taken sample
#
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
#
states = None
#
for Y_min, cu_seqlens, seq_idx, (
#
A, dt, X, B, C) in generate_continuous_batched_examples(
#
cases, num_examples, seqlen, last_taken, exhausted, n_heads,
#
d_head, itype):
#
chunk_indices, chunk_offsets = \
#
_query_start_loc_to_chunk_indices_offsets(
#
cu_seqlens, chunk_size, cu_seqlens[-1])
#
Y = torch.empty_like(X)
#
new_states = mamba_chunk_scan_combined(
#
X,
#
dt,
#
A,
#
B,
#
C,
#
chunk_size,
#
D=None,
#
cu_seqlens=cu_seqlens,
#
seq_idx=seq_idx,
#
chunk_indices=chunk_indices,
#
chunk_offsets=chunk_offsets,
#
return_varlen_states=True,
#
initial_states=states,
#
out=Y,
#
)
#
# just test the last in sequence
#
for i in range(num_examples):
#
# just test one dim and dstate
#
Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
#
Y_min_eg = Y_min[i][:, 0, 0]
#
torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol)
#
# update states
#
states = new_states
#
for i, clear in exhausted.items():
#
if clear:
#
states[i].fill_(0.)
#
exhausted[i] = False
tests/kernels/moe/test_batched_moe.py
View file @
af7b564d
...
...
@@ -93,7 +93,7 @@ class BatchedMMTensors:
@
pytest
.
mark
.
parametrize
(
"max_tokens_per_expert"
,
[
32
,
224
,
512
])
@
pytest
.
mark
.
parametrize
(
"K"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"N"
,
[
128
,
1024
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
bfloat16
]
if
not
current_platform
.
is_rocm
()
else
[
torch
.
bfloat16
]
)
@
pytest
.
mark
.
parametrize
(
"block_shape"
,
[
None
,
[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
def
test_batched_mm
(
num_experts
:
int
,
max_tokens_per_expert
:
int
,
K
:
int
,
...
...
@@ -205,7 +205,7 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
@
pytest
.
mark
.
parametrize
((
"m"
,
"n"
,
"k"
),
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"e"
,
NUM_EXPERTS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float8_e4m3fn
,
torch
.
bfloat16
]
if
not
current_platform
.
is_rocm
()
else
[
torch
.
bfloat16
]
)
@
pytest
.
mark
.
parametrize
(
"per_act_token_quant"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"block_shape"
,
[
None
,
[
128
,
128
]])
@
pytest
.
mark
.
parametrize
(
"input_scales"
,
[
False
])
...
...
tests/kernels/moe/test_moe.py
View file @
af7b564d
...
...
@@ -192,6 +192,7 @@ def test_fused_moe(
use_int8_w8a8
=
False
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
use_int4_w4a8
=
False
,
use_mxfp4_w4a4
=
False
,
per_act_token_quant
=
False
,
block_shape
=
None
)
...
...
@@ -349,6 +350,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
renormalize
=
False
,
use_int4_w4a16
=
weight_bits
==
4
,
use_int8_w8a16
=
weight_bits
==
8
,
use_int4_w4a8
=
weight_bits
==
4
,
global_num_experts
=
e
,
expert_map
=
e_map
,
w1_scale
=
w1_scales
,
...
...
@@ -369,7 +371,7 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"padding"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_rocm_aiter"
,
[
True
,
False
]
if
current_platform
.
is_rocm
()
else
[
False
])
"use_rocm_aiter"
,
[
True
,
False
]
if
not
current_platform
.
is_rocm
()
else
[
False
])
@
torch
.
inference_mode
()
def
test_mixtral_moe
(
dtype
:
torch
.
dtype
,
padding
:
bool
,
use_rocm_aiter
:
bool
,
monkeypatch
):
...
...
@@ -410,12 +412,19 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
).
cuda
()
# Load the weights
if
not
current_platform
.
is_rocm
():
vllm_moe
.
gate
.
weight
.
data
[:]
=
hf_moe
.
gate
.
weight
.
data
else
:
vllm_moe
.
gate
.
weight
.
data
[:]
=
(
hf_moe
.
gate
.
weight
.
data
).
T
for
i
in
range
(
config
.
num_local_experts
):
weights
=
(
hf_moe
.
experts
[
i
].
w1
.
weight
.
data
,
hf_moe
.
experts
[
i
].
w3
.
weight
.
data
)
if
not
current_platform
.
is_rocm
():
vllm_moe
.
experts
.
w13_weight
[
i
][:]
=
torch
.
cat
(
weights
,
dim
=
0
)
vllm_moe
.
experts
.
w2_weight
[
i
][:]
=
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
else
:
vllm_moe
.
experts
.
w13_weight
[
i
][:]
=
(
torch
.
cat
(
weights
,
dim
=
0
)).
T
vllm_moe
.
experts
.
w2_weight
[
i
][:]
=
(
hf_moe
.
experts
[
i
].
w2
.
weight
.
data
).
T
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs
=
torch
.
randn
(
...
...
tests/kernels/moe/test_block_fp8.py
→
tests/kernels/moe/
un
test_block_fp8.py
View file @
af7b564d
File moved
tests/kernels/moe/test_moe_permute_unpermute.py
→
tests/kernels/moe/
un
test_moe_permute_unpermute.py
View file @
af7b564d
File moved
tests/kernels/moe/test_nvfp4_moe.py
→
tests/kernels/moe/
un
test_nvfp4_moe.py
View file @
af7b564d
File moved
tests/kernels/moe/test_pplx_cutlass_moe.py
→
tests/kernels/moe/
un
test_pplx_cutlass_moe.py
View file @
af7b564d
File moved
tests/kernels/moe/test_silu_mul_fp8_quant_deep_gemm.py
→
tests/kernels/moe/
un
test_silu_mul_fp8_quant_deep_gemm.py
View file @
af7b564d
File moved
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
→
tests/kernels/moe/
un
test_triton_moe_ptpc_fp8.py
View file @
af7b564d
File moved
tests/kernels/quantization/__init__.py
0 → 100644
View file @
af7b564d
tests/kernels/quantization/test_gguf.py
View file @
af7b564d
...
...
@@ -13,7 +13,7 @@ import vllm._custom_ops as ops
from
vllm.model_executor.layers.fused_moe
import
fused_experts
from
vllm.model_executor.layers.quantization.gguf
import
_fused_moe_gguf
from
vllm.platforms
import
current_platform
from
..utils
import
models_path_prefix
from
..
.
utils
import
models_path_prefix
# GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
# GGUF_SAMPLE_MOE = snapshot_download("SzymonOzog/test-gguf-moe-sample")
...
...
tests/kernels/quantization/test_int8_quant.py
View file @
af7b564d
...
...
@@ -40,7 +40,7 @@ def opcheck_int8_quant_dynamic(output, input, symmetric=True):
(
output
,
input
,
scale
,
azp
))
@
pytest
.
mark
.
skipif
(
current_platform
(),
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
...
...
@@ -65,7 +65,7 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
opcheck_int8_quant_dynamic
(
ops_out
,
x
)
@
pytest
.
mark
.
skipif
(
current_platform
(),
@
pytest
.
mark
.
skipif
(
current_platform
.
is_rocm
(),
reason
=
"Currently, there is not supported on ROCm."
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
HIDDEN_SIZES
)
...
...
tests/kernels/quantization/test_triton_scaled_mm.py
View file @
af7b564d
...
...
@@ -4,6 +4,7 @@
Run `pytest tests/kernels/test_triton_scaled_mm.py`.
"""
import
os
import
importlib
from
typing
import
Optional
...
...
@@ -11,6 +12,7 @@ import pytest
import
torch
from
vllm.platforms
import
current_platform
from
...utils
import
models_path_prefix
device
=
"cuda"
...
...
@@ -45,7 +47,7 @@ def get_8bit_types():
# This test is to check regressions for int8 support on ROCm.
@
pytest
.
mark
.
parametrize
(
"model_path"
,
[
"neuralmagic/Llama-3.2-1B-quantized.w8a8"
,
os
.
path
.
join
(
models_path_prefix
,
"neuralmagic/Llama-3.2-1B-quantized.w8a8"
)
,
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"num_logprobs"
,
[
10
])
...
...
tests/kernels/quantization/test_rocm_skinny_gemms.py
→
tests/kernels/quantization/
un
test_rocm_skinny_gemms.py
View file @
af7b564d
File moved
tests/kernels/test_flex_attention.py
View file @
af7b564d
...
...
@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Integration tests for FlexAttention backend vs default backend"""
import
os
import
random
import
numpy
as
np
...
...
@@ -10,6 +11,7 @@ import torch
from
packaging
import
version
from
vllm
import
SamplingParams
from
..utils
import
models_path_prefix
from
..models.utils
import
check_embeddings_close
...
...
@@ -36,7 +38,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
This test compares the outputs from the FlexAttention backend with
the default backend, ensuring they are identical when using the same seed.
"""
model_name
=
"Qwen/Qwen2.5-1.5B-Instruct"
model_name
=
os
.
path
.
join
(
models_path_prefix
,
"Qwen/Qwen2.5-1.5B-Instruct"
)
seed
=
42
max_tokens
=
24
prompts
=
[
...
...
tests/kernels/test_fused_quant_activation.py
→
tests/kernels/
un
test_fused_quant_activation.py
View file @
af7b564d
...
...
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from
vllm.platforms
import
current_platform
DTYPES
=
[
torch
.
bfloat16
,
torch
.
float16
]
QUANT_DTYPES
=
[
current_platform
.
fp8_dtype
()]
QUANT_DTYPES
=
[
current_platform
.
fp8_dtype
()]
if
not
current_platform
.
is_rocm
()
else
[
None
]
NUM_TOKENS
=
[
1
,
17
,
86
,
1234
,
3045
]
# Arbitrary values for testing
HIDDEN_SIZES
=
[
16
,
48
,
128
,
1562
,
4096
]
# Arbitrary values for testing
SEEDS
=
[
0
]
...
...
tests/kernels/test_triton_flash_attention.py
→
tests/kernels/
un
test_triton_flash_attention.py
View file @
af7b564d
...
...
@@ -60,26 +60,26 @@ class ReferenceAttention:
ref_out
=
ref_out
.
transpose
(
1
,
2
).
clone
()
return
ref_out
def
fwd_fp8
(
self
,
q_quantized
,
k_quantized
,
v_quantized
):
q
=
(
q_quantized
.
to
(
torch
.
float16
)
*
self
.
input_metadata
.
q_descale
).
to
(
self
.
dtype
)
k
=
(
k_quantized
.
to
(
torch
.
float16
)
*
self
.
input_metadata
.
k_descale
).
to
(
self
.
dtype
)
v
=
(
v_quantized
.
to
(
torch
.
float16
)
*
self
.
input_metadata
.
v_descale
).
to
(
self
.
dtype
)
result
=
self
.
fwd
(
q
,
k
,
v
)
if
self
.
input_metadata
.
o_scale
is
not
None
:
result
,
_
=
scale_fp8
(
result
,
self
.
input_metadata
.
o_scale
)
return
result
def
fwd_fp8_kv
(
self
,
q
,
k_quantized
,
v_quantized
):
k_descale
,
v_descale
=
(
self
.
input_metadata
.
k_descale
,
self
.
input_metadata
.
v_descale
)
k_dequantized
=
(
k_quantized
.
to
(
torch
.
float32
)
*
k_descale
.
to
(
torch
.
float32
)).
to
(
self
.
dtype
)
v_dequantized
=
(
v_quantized
.
to
(
torch
.
float32
)
*
v_descale
.
to
(
torch
.
float32
)).
to
(
self
.
dtype
)
return
self
.
fwd
(
q
,
k_dequantized
,
v_dequantized
)
#
def fwd_fp8(self, q_quantized, k_quantized, v_quantized):
#
q = (q_quantized.to(torch.float16) * self.input_metadata.q_descale).to(
#
self.dtype)
#
k = (k_quantized.to(torch.float16) * self.input_metadata.k_descale).to(
#
self.dtype)
#
v = (v_quantized.to(torch.float16) * self.input_metadata.v_descale).to(
#
self.dtype)
#
result = self.fwd(q, k, v)
#
if self.input_metadata.o_scale is not None:
#
result, _ = scale_fp8(result, self.input_metadata.o_scale)
#
return result
#
def fwd_fp8_kv(self, q, k_quantized, v_quantized):
#
k_descale, v_descale = (self.input_metadata.k_descale,
#
self.input_metadata.v_descale)
#
k_dequantized = (k_quantized.to(torch.float32) *
#
k_descale.to(torch.float32)).to(self.dtype)
#
v_dequantized = (v_quantized.to(torch.float32) *
#
v_descale.to(torch.float32)).to(self.dtype)
#
return self.fwd(q, k_dequantized, v_dequantized)
def
varlen_fwd
(
self
,
q
,
k
,
v
,
is_mqa
=
False
):
ref_out
=
torch
.
empty_like
(
q
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment