Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cf069aa8
Unverified
Commit
cf069aa8
authored
Mar 03, 2025
by
Harry Mellor
Committed by
GitHub
Mar 02, 2025
Browse files
Update deprecated Python 3.8 typing (#13971)
parent
bf33700e
Changes
300
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
114 additions
and
121 deletions
+114
-121
tests/entrypoints/openai/tool_parsers/utils.py
tests/entrypoints/openai/tool_parsers/utils.py
+5
-4
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+3
-3
tests/kernels/test_activation.py
tests/kernels/test_activation.py
+1
-2
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+8
-8
tests/kernels/test_blocksparse_attention.py
tests/kernels/test_blocksparse_attention.py
+6
-6
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+2
-3
tests/kernels/test_cascade_flash_attn.py
tests/kernels/test_cascade_flash_attn.py
+4
-4
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+5
-6
tests/kernels/test_cutlass_2of4_sparse.py
tests/kernels/test_cutlass_2of4_sparse.py
+2
-3
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+2
-2
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+8
-8
tests/kernels/test_flashinfer.py
tests/kernels/test_flashinfer.py
+11
-11
tests/kernels/test_fused_quant_layernorm.py
tests/kernels/test_fused_quant_layernorm.py
+6
-6
tests/kernels/test_gguf.py
tests/kernels/test_gguf.py
+1
-2
tests/kernels/test_machete_mm.py
tests/kernels/test_machete_mm.py
+7
-7
tests/kernels/test_mamba_mixer2.py
tests/kernels/test_mamba_mixer2.py
+1
-2
tests/kernels/test_mamba_ssm_ssd.py
tests/kernels/test_mamba_ssm_ssd.py
+3
-5
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+3
-3
tests/kernels/test_triton_scaled_mm.py
tests/kernels/test_triton_scaled_mm.py
+2
-2
tests/kernels/utils.py
tests/kernels/utils.py
+34
-34
No files found.
tests/entrypoints/openai/tool_parsers/utils.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Iterable
,
List
,
Tuple
,
Union
from
collections.abc
import
Iterable
from
typing
import
Union
from
vllm.entrypoints.openai.protocol
import
(
ChatCompletionRequest
,
DeltaMessage
,
...
...
@@ -12,7 +13,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser
class
StreamingToolReconstructor
:
def
__init__
(
self
,
assert_one_tool_per_delta
:
bool
=
True
):
self
.
tool_calls
:
L
ist
[
ToolCall
]
=
[]
self
.
tool_calls
:
l
ist
[
ToolCall
]
=
[]
self
.
other_content
:
str
=
""
self
.
_assert_one_tool_per_delta
=
assert_one_tool_per_delta
...
...
@@ -72,7 +73,7 @@ def run_tool_extraction(
request
:
Union
[
ChatCompletionRequest
,
None
]
=
None
,
streaming
:
bool
=
False
,
assert_one_tool_per_delta
:
bool
=
True
,
)
->
T
uple
[
Union
[
str
,
None
],
L
ist
[
ToolCall
]]:
)
->
t
uple
[
Union
[
str
,
None
],
l
ist
[
ToolCall
]]:
if
streaming
:
reconstructor
=
run_tool_extraction_streaming
(
tool_parser
,
...
...
@@ -106,7 +107,7 @@ def run_tool_extraction_streaming(
reconstructor
=
StreamingToolReconstructor
(
assert_one_tool_per_delta
=
assert_one_tool_per_delta
)
previous_text
=
""
previous_tokens
:
L
ist
[
int
]
=
[]
previous_tokens
:
l
ist
[
int
]
=
[]
for
delta
in
model_deltas
:
token_delta
=
[
tool_parser
.
vocab
.
get
(
token
)
...
...
tests/kernels/quant_utils.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Union
import
torch
...
...
@@ -19,7 +19,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def
ref_dynamic_per_token_quant
(
x
:
torch
.
tensor
,
quant_dtype
:
torch
.
dtype
,
scale_ub
:
Optional
[
torch
.
tensor
]
=
None
)
\
->
T
uple
[
torch
.
tensor
,
torch
.
tensor
]:
->
t
uple
[
torch
.
tensor
,
torch
.
tensor
]:
assert
quant_dtype
in
[
torch
.
int8
,
FP8_DTYPE
]
if
scale_ub
is
not
None
:
...
...
@@ -68,7 +68,7 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
# ref_dynamic_per_token_quant, when we have a dynamic_per_tensor int8 quant
# kernel
def
ref_dynamic_per_tensor_fp8_quant
(
x
:
torch
.
tensor
)
\
->
T
uple
[
torch
.
tensor
,
torch
.
tensor
]:
->
t
uple
[
torch
.
tensor
,
torch
.
tensor
]:
fp8_traits
=
torch
.
finfo
(
FP8_DTYPE
)
fp8_traits_max
=
ROCM_FP8_MAX
if
current_platform
.
is_rocm
()
\
...
...
tests/kernels/test_activation.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
random
from
typing
import
Type
import
pytest
import
torch
...
...
@@ -86,7 +85,7 @@ def test_act_and_mul(
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
torch
.
inference_mode
()
def
test_activation
(
activation
:
T
ype
[
torch
.
nn
.
Module
],
activation
:
t
ype
[
torch
.
nn
.
Module
],
num_tokens
:
int
,
d
:
int
,
dtype
:
torch
.
dtype
,
...
...
tests/kernels/test_attention.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
random
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -85,8 +85,8 @@ def ref_single_query_cached_kv_attention(
block_table
=
block_tables_lst
[
i
]
seq_len
=
int
(
seq_lens_lst
[
i
])
keys_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
keys_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
for
j
in
range
(
seq_len
):
block_number
=
int
(
block_table
[
j
//
block_size
])
block_offset
=
j
%
block_size
...
...
@@ -133,7 +133,7 @@ def test_paged_attention(
kv_cache_factory
,
version
:
str
,
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
use_alibi
:
bool
,
block_size
:
int
,
...
...
@@ -166,7 +166,7 @@ def test_paged_attention(
# Create the block tables.
max_num_blocks_per_seq
=
(
max_seq_len
+
block_size
-
1
)
//
block_size
block_tables_lst
:
L
ist
[
L
ist
[
int
]]
=
[]
block_tables_lst
:
l
ist
[
l
ist
[
int
]]
=
[]
for
_
in
range
(
num_seqs
):
block_table
=
[
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
...
...
@@ -334,7 +334,7 @@ def test_paged_attention(
def
ref_multi_query_kv_attention
(
cu_seq_lens
:
L
ist
[
int
],
cu_seq_lens
:
l
ist
[
int
],
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
...
@@ -342,7 +342,7 @@ def ref_multi_query_kv_attention(
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
num_seqs
=
len
(
cu_seq_lens
)
-
1
ref_outputs
:
L
ist
[
torch
.
Tensor
]
=
[]
ref_outputs
:
l
ist
[
torch
.
Tensor
]
=
[]
for
i
in
range
(
num_seqs
):
start_idx
=
cu_seq_lens
[
i
]
end_idx
=
cu_seq_lens
[
i
+
1
]
...
...
@@ -378,7 +378,7 @@ def ref_multi_query_kv_attention(
@
torch
.
inference_mode
()
def
test_multi_query_kv_attention
(
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
seed
:
int
,
...
...
tests/kernels/test_blocksparse_attention.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
random
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -87,8 +87,8 @@ def ref_single_query_cached_kv_attention(
block_table
=
block_tables_lst
[
i
]
seq_len
=
int
(
seq_lens_lst
[
i
])
keys_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
L
ist
[
torch
.
Tensor
]
=
[]
keys_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
values_lst
:
l
ist
[
torch
.
Tensor
]
=
[]
for
j
in
range
(
seq_len
):
block_number
=
int
(
block_table
[
j
//
block_size
])
block_offset
=
j
%
block_size
...
...
@@ -162,7 +162,7 @@ def test_paged_attention(
kv_cache_factory
,
version
:
str
,
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
use_alibi
:
bool
,
block_size
:
int
,
...
...
@@ -331,7 +331,7 @@ def test_paged_attention(
def
ref_multi_query_kv_attention
(
cu_seq_lens
:
L
ist
[
int
],
cu_seq_lens
:
l
ist
[
int
],
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
...
@@ -376,7 +376,7 @@ def ref_multi_query_kv_attention(
@
torch
.
inference_mode
()
def
test_varlen_blocksparse_attention_prefill
(
num_seqs
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
blocksparse_local_blocks
:
int
,
blocksparse_vert_stride
:
int
,
...
...
tests/kernels/test_cache.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
random
from
typing
import
List
,
Tuple
import
pytest
import
torch
...
...
@@ -74,7 +73,7 @@ def test_copy_blocks(
src_blocks
=
random
.
sample
(
range
(
num_blocks
),
num_mappings
)
remainig_blocks
=
list
(
set
(
range
(
num_blocks
))
-
set
(
src_blocks
))
dst_blocks
=
random
.
sample
(
remainig_blocks
,
2
*
num_mappings
)
block_mapping
:
L
ist
[
T
uple
[
int
,
int
]]
=
[]
block_mapping
:
l
ist
[
t
uple
[
int
,
int
]]
=
[]
for
i
in
range
(
num_mappings
):
src
=
src_blocks
[
i
]
dst1
=
dst_blocks
[
2
*
i
]
...
...
@@ -342,7 +341,7 @@ def test_reshape_and_cache_flash(
@
torch
.
inference_mode
()
def
test_swap_blocks
(
kv_cache_factory
,
direction
:
T
uple
[
str
,
str
],
direction
:
t
uple
[
str
,
str
],
num_mappings
:
int
,
num_heads
:
int
,
head_size
:
int
,
...
...
tests/kernels/test_cascade_flash_attn.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -25,7 +25,7 @@ DTYPES = [torch.float16, torch.bfloat16]
@
torch
.
inference_mode
()
def
test_merge_kernel
(
num_tokens
:
int
,
num_heads
:
T
uple
[
int
,
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
):
...
...
@@ -85,8 +85,8 @@ CASES = [
@
pytest
.
mark
.
parametrize
(
"fa_version"
,
[
2
,
3
])
@
torch
.
inference_mode
()
def
test_cascade
(
seq_lens_and_common_prefix
:
T
uple
[
L
ist
[
T
uple
[
int
,
int
]],
int
],
num_heads
:
T
uple
[
int
,
int
],
seq_lens_and_common_prefix
:
t
uple
[
l
ist
[
t
uple
[
int
,
int
]],
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
...
...
tests/kernels/test_cutlass.py
View file @
cf069aa8
...
...
@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_cutlass.py`.
"""
from
typing
import
Type
import
pytest
import
torch
...
...
@@ -71,7 +70,7 @@ def cutlass_fp8_gemm_helper(m: int,
a_scale_group_shape
:
tuple
,
b_scale_group_shape
:
tuple
,
use_bias
:
bool
,
out_dtype
:
T
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
t
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
...
...
@@ -109,7 +108,7 @@ def cutlass_int8_gemm_helper(m: int,
a_scale_group_shape
:
tuple
,
b_scale_group_shape
:
tuple
,
use_bias
:
bool
,
out_dtype
:
T
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
out_dtype
:
t
ype
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
...
...
@@ -187,7 +186,7 @@ def test_cutlass_int8_gemm(m: int, n: int, k: int, a_scale_group_shape,
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_output_dtype
(
a_scale_group_shape
,
b_scale_group_shape
,
out_dtype
:
T
ype
[
torch
.
dtype
],
out_dtype
:
t
ype
[
torch
.
dtype
],
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
512
,
512
,
...
...
@@ -208,7 +207,7 @@ def test_cutlass_int8_gemm_output_dtype(a_scale_group_shape,
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_output_dtype
(
a_scale_group_shape
,
b_scale_group_shape
,
out_dtype
:
T
ype
[
torch
.
dtype
],
out_dtype
:
t
ype
[
torch
.
dtype
],
use_bias
:
bool
):
cutlass_fp8_gemm_helper
(
512
,
512
,
...
...
@@ -227,7 +226,7 @@ def test_cutlass_fp8_gemm_output_dtype(a_scale_group_shape,
reason
=
"FP8 blockwise is not supported on this GPU type."
)
def
test_cutlass_fp8_blockwise_scale_gemm_dtype
(
a_scale_group_shape
,
b_scale_group_shape
,
out_dtype
:
T
ype
[
torch
.
dtype
],
out_dtype
:
t
ype
[
torch
.
dtype
],
use_bias
:
bool
):
cutlass_fp8_gemm_helper
(
512
,
512
,
...
...
tests/kernels/test_cutlass_2of4_sparse.py
View file @
cf069aa8
...
...
@@ -3,7 +3,6 @@
Run `pytest tests/kernels/test_cutlass_2of4_sparse.py`.
"""
from
typing
import
Tuple
,
Type
import
pytest
import
torch
...
...
@@ -79,7 +78,7 @@ def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
def
make_rand_sparse_tensors
(
dtype
:
torch
.
dtype
,
m
:
int
,
n
:
int
,
k
:
int
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
a
=
torch
.
randn
((
m
,
k
),
device
=
'cuda'
)
b
=
torch
.
randn
((
n
,
k
),
device
=
'cuda'
).
t
()
...
...
@@ -167,7 +166,7 @@ MNK_FACTORS = [
@
pytest
.
mark
.
parametrize
(
"m, n, k"
,
MNK_FACTORS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_sparse_gemm
(
m
:
int
,
k
:
int
,
n
:
int
,
dtype
:
T
ype
[
torch
.
dtype
],
def
test_cutlass_sparse_gemm
(
m
:
int
,
k
:
int
,
n
:
int
,
dtype
:
t
ype
[
torch
.
dtype
],
use_bias
:
bool
):
# Create tensors
...
...
tests/kernels/test_encoder_decoder_attn.py
View file @
cf069aa8
...
...
@@ -243,7 +243,7 @@ def _decoder_attn_setup(
test_pt
:
TestPoint
,
test_rsrcs
:
TestResources
,
block_base_addr
:
int
=
0
,
)
->
T
uple
[
QKVInputs
,
PhaseTestParameters
,
PhaseTestParameters
,
int
]:
)
->
t
uple
[
QKVInputs
,
PhaseTestParameters
,
PhaseTestParameters
,
int
]:
'''
Set up test vectors & data structures for self-attention test.
...
...
@@ -421,7 +421,7 @@ def _enc_dec_cross_attn_setup_reuses_query(
test_pt
:
TestPoint
,
test_rsrcs
:
TestResources
,
block_base_addr
:
int
=
0
,
)
->
T
uple
[
PhaseTestParameters
,
PhaseTestParameters
]:
)
->
t
uple
[
PhaseTestParameters
,
PhaseTestParameters
]:
'''
Set up test vectors & data structures for cross-attention test.
...
...
tests/kernels/test_flash_attn.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -24,8 +24,8 @@ def ref_paged_attn(
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
query_lens
:
L
ist
[
int
],
kv_lens
:
L
ist
[
int
],
query_lens
:
l
ist
[
int
],
kv_lens
:
l
ist
[
int
],
block_tables
:
torch
.
Tensor
,
scale
:
float
,
sliding_window
:
Optional
[
int
]
=
None
,
...
...
@@ -35,7 +35,7 @@ def ref_paged_attn(
block_tables
=
block_tables
.
cpu
().
numpy
()
_
,
block_size
,
num_kv_heads
,
head_size
=
key_cache
.
shape
outputs
:
L
ist
[
torch
.
Tensor
]
=
[]
outputs
:
l
ist
[
torch
.
Tensor
]
=
[]
start_idx
=
0
for
i
in
range
(
num_seqs
):
query_len
=
query_lens
[
i
]
...
...
@@ -88,8 +88,8 @@ def ref_paged_attn(
@
torch
.
inference_mode
()
def
test_flash_attn_with_paged_kv
(
use_out
:
bool
,
kv_lens
:
L
ist
[
int
],
num_heads
:
T
uple
[
int
,
int
],
kv_lens
:
l
ist
[
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
...
...
@@ -174,8 +174,8 @@ def test_flash_attn_with_paged_kv(
@
torch
.
inference_mode
()
def
test_varlen_with_paged_kv
(
use_out
:
bool
,
seq_lens
:
L
ist
[
T
uple
[
int
,
int
]],
num_heads
:
T
uple
[
int
,
int
],
seq_lens
:
l
ist
[
t
uple
[
int
,
int
]],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
sliding_window
:
Optional
[
int
],
dtype
:
torch
.
dtype
,
...
...
tests/kernels/test_flashinfer.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
flashinfer
import
pytest
...
...
@@ -19,8 +19,8 @@ def ref_paged_attn(
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
query_lens
:
L
ist
[
int
],
kv_lens
:
L
ist
[
int
],
query_lens
:
l
ist
[
int
],
kv_lens
:
l
ist
[
int
],
block_tables
:
torch
.
Tensor
,
scale
:
float
,
sliding_window
:
Optional
[
int
]
=
None
,
...
...
@@ -30,7 +30,7 @@ def ref_paged_attn(
block_tables
=
block_tables
.
cpu
().
numpy
()
_
,
block_size
,
num_kv_heads
,
head_size
=
key_cache
.
shape
outputs
:
L
ist
[
torch
.
Tensor
]
=
[]
outputs
:
l
ist
[
torch
.
Tensor
]
=
[]
start_idx
=
0
for
i
in
range
(
num_seqs
):
query_len
=
query_lens
[
i
]
...
...
@@ -78,8 +78,8 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
@
torch
.
inference_mode
def
test_flashinfer_decode_with_paged_kv
(
kv_lens
:
L
ist
[
int
],
num_heads
:
T
uple
[
int
,
int
],
kv_lens
:
l
ist
[
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
...
...
@@ -168,8 +168,8 @@ def test_flashinfer_decode_with_paged_kv(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
@
torch
.
inference_mode
def
test_flashinfer_prefill_with_paged_kv
(
seq_lens
:
L
ist
[
T
uple
[
int
,
int
]],
num_heads
:
T
uple
[
int
,
int
],
def
test_flashinfer_prefill_with_paged_kv
(
seq_lens
:
l
ist
[
t
uple
[
int
,
int
]],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
])
->
None
:
...
...
@@ -270,7 +270,7 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
def
test_flashinfer_prefill_with_paged_fp8_kv
(
seq_lens
:
L
ist
[
T
uple
[
int
,
int
]],
num_heads
:
T
uple
[
int
,
int
],
seq_lens
:
l
ist
[
t
uple
[
int
,
int
]],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
soft_cap
:
Optional
[
float
])
->
None
:
pytest
.
skip
(
"TODO: fix the accuracy issue"
)
...
...
@@ -378,8 +378,8 @@ def test_flashinfer_prefill_with_paged_fp8_kv(
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
30.0
,
50.0
])
@
torch
.
inference_mode
def
test_flashinfer_decode_with_paged_fp8_kv
(
kv_lens
:
L
ist
[
int
],
num_heads
:
T
uple
[
int
,
int
],
kv_lens
:
l
ist
[
int
],
num_heads
:
t
uple
[
int
,
int
],
head_size
:
int
,
dtype
:
torch
.
dtype
,
block_size
:
int
,
...
...
tests/kernels/test_fused_quant_layernorm.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Union
import
pytest
import
torch
...
...
@@ -39,7 +39,7 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
def
ref_rms_norm
(
rms_norm_layer
:
RMSNorm
,
x
:
torch
.
Tensor
,
residual
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
residual
is
not
None
:
residual
=
residual
.
clone
()
out
,
residual
=
rms_norm_layer
.
forward_native
(
x
,
residual
)
...
...
@@ -54,7 +54,7 @@ def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
scale_ub
is
not
None
:
assert
quant_dtype
==
torch
.
float8_e4m3fn
...
...
@@ -78,7 +78,7 @@ def ref_impl(rms_norm_layer: RMSNorm,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
return
ref_dynamic_per_token_quant
(
rms_norm_layer
,
x
,
quant_dtype
,
residual
,
scale_ub
)
...
...
@@ -88,7 +88,7 @@ def ops_dynamic_per_token_quant(weight: torch.Tensor,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
if
residual
is
not
None
:
residual
=
residual
.
clone
()
out
,
scales
=
ops
.
rms_norm_dynamic_per_token_quant
(
x
,
weight
,
EPS
,
...
...
@@ -102,7 +102,7 @@ def ops_impl(weight: torch.Tensor,
quant_dtype
:
torch
.
dtype
,
residual
:
Optional
[
torch
.
Tensor
],
scale_ub
:
Optional
[
torch
.
Tensor
])
\
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
return
ops_dynamic_per_token_quant
(
weight
,
x
,
quant_dtype
,
residual
,
scale_ub
)
...
...
tests/kernels/test_gguf.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
pathlib
import
Path
from
typing
import
List
import
pytest
import
torch
...
...
@@ -16,7 +15,7 @@ GGUF_SAMPLE = snapshot_download("Isotr0py/test-gguf-sample")
def
get_gguf_sample_tensors
(
hidden_size
:
int
,
quant_type
:
GGMLQuantizationType
)
->
L
ist
[
ReaderTensor
]:
quant_type
:
GGMLQuantizationType
)
->
l
ist
[
ReaderTensor
]:
sample_dir
=
GGUF_SAMPLE
filename
=
f
"Quant_
{
quant_type
.
name
}
_
{
hidden_size
}
.gguf"
sample_file
=
Path
(
sample_dir
)
/
filename
...
...
tests/kernels/test_machete_mm.py
View file @
cf069aa8
...
...
@@ -6,7 +6,7 @@ Run `pytest tests/kernels/test_machete_mm.py`.
import
math
from
dataclasses
import
dataclass
,
fields
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -45,7 +45,7 @@ MNK_SHAPES = [
(
1024
,
8192
,
4096
),
]
GROUP_SIZES_TO_TEST
:
L
ist
[
Optional
[
int
]]
=
[
128
,
-
1
]
GROUP_SIZES_TO_TEST
:
l
ist
[
Optional
[
int
]]
=
[
128
,
-
1
]
@
dataclass
...
...
@@ -75,7 +75,7 @@ class Tensors:
# Ch Scales Type, Tok Scales Type)
# NOTE: None "Scale Type" means the act type is floating point
# None "Output Type" means the output type is the same as the act type
TestTypeTuple
=
T
uple
[
L
ist
[
torch
.
dtype
],
ScalarType
,
Optional
[
torch
.
dtype
],
TestTypeTuple
=
t
uple
[
l
ist
[
torch
.
dtype
],
ScalarType
,
Optional
[
torch
.
dtype
],
Optional
[
torch
.
dtype
],
bool
]
TEST_TYPES
=
[
# GPTQ style
...
...
@@ -136,7 +136,7 @@ def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor):
return
zps
if
zps
is
None
else
-
1
*
s
*
(
zps
.
to
(
s
.
dtype
))
def
group_size_valid
(
shape
:
T
uple
[
int
,
int
,
int
],
def
group_size_valid
(
shape
:
t
uple
[
int
,
int
,
int
],
group_size
:
Optional
[
int
])
->
bool
:
return
group_size
is
None
or
group_size
==
-
1
or
group_size
%
shape
[
2
]
==
0
...
...
@@ -166,7 +166,7 @@ def machete_quantize_and_pack(atype: torch.dtype,
return
w_ref
,
w_q_machete
,
w_s
,
w_zp
def
create_test_tensors
(
shape
:
T
uple
[
int
,
int
,
int
],
def
create_test_tensors
(
shape
:
t
uple
[
int
,
int
,
int
],
types
:
TypeConfig
,
group_size
:
Optional
[
int
],
subset_stride_factor
:
Optional
[
int
]
=
None
)
->
Tensors
:
...
...
@@ -265,7 +265,7 @@ def machete_mm_test_helper(types: TypeConfig,
@
pytest
.
mark
.
parametrize
(
"types"
,
TEST_TYPES
)
def
test_machete_all_schedules
(
shape
,
types
:
TypeConfig
):
group_sizes
:
L
ist
[
Optional
[
int
]]
=
[]
group_sizes
:
l
ist
[
Optional
[
int
]]
=
[]
if
types
.
group_scale_type
is
None
:
group_sizes
=
[
None
]
else
:
...
...
@@ -294,7 +294,7 @@ def test_machete_all_schedules(shape, types: TypeConfig):
ids
=
lambda
x
:
"x"
.
join
(
str
(
v
)
for
v
in
x
))
@
pytest
.
mark
.
parametrize
(
"types"
,
TEST_TYPES
)
def
test_machete_heuristic
(
shape
,
types
:
TypeConfig
):
group_sizes
:
L
ist
[
Optional
[
int
]]
=
[]
group_sizes
:
l
ist
[
Optional
[
int
]]
=
[]
if
types
.
group_scale_type
is
None
:
group_sizes
=
[
None
]
else
:
...
...
tests/kernels/test_mamba_mixer2.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
import
unittest
from
typing
import
Tuple
import
pytest
import
torch
...
...
@@ -29,7 +28,7 @@ from vllm.utils import update_environment_variables
def
test_mixer2_gated_norm_multi_gpu
(
batch_size
:
int
,
seq_len
:
int
,
hidden_size_n_groups
:
T
uple
[
int
,
int
],
hidden_size_n_groups
:
t
uple
[
int
,
int
],
dtype
:
torch
.
dtype
,
device
:
str
=
'cuda'
,
):
...
...
tests/kernels/test_mamba_ssm_ssd.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Dict
,
Tuple
import
pytest
import
torch
import
torch.nn.functional
as
F
...
...
@@ -134,7 +132,7 @@ def generate_continous_batched_examples(example_lens_by_batch,
# given a tuple of lengths for each example in the batch
# e.g., example_lens=(8, 4) means take 8 samples from first eg,
# 4 examples from second eg, etc
def
get_continuous_batch
(
example_lens
:
T
uple
[
int
,
...]):
def
get_continuous_batch
(
example_lens
:
t
uple
[
int
,
...]):
indices
=
[]
for
i
,
x
in
enumerate
(
example_lens
):
...
...
@@ -264,8 +262,8 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken
:
D
ict
=
{}
# map: eg -> pointer to last taken sample
exhausted
:
D
ict
=
{}
# map: eg -> boolean indicating example is exhausted
last_taken
:
d
ict
=
{}
# map: eg -> pointer to last taken sample
exhausted
:
d
ict
=
{}
# map: eg -> boolean indicating example is exhausted
states
=
None
for
Y_min
,
cu_seqlens
,
sed_idx
,
(
A
,
dt
,
X
,
B
,
...
...
tests/kernels/test_pos_encoding.py
View file @
cf069aa8
# SPDX-License-Identifier: Apache-2.0
from
itertools
import
accumulate
,
product
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Callable
,
Optional
import
pytest
import
torch
...
...
@@ -179,7 +179,7 @@ def test_batched_rotary_embedding_multi_lora(
torch
.
set_default_device
(
device
)
if
rotary_dim
is
None
:
rotary_dim
=
head_size
scaling_factors
:
L
ist
[
int
]
=
[
1
,
2
,
4
]
scaling_factors
:
l
ist
[
int
]
=
[
1
,
2
,
4
]
rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
tuple
(
scaling_factors
)
...
...
@@ -234,7 +234,7 @@ def test_rope_module_cache():
})
settings
=
(
HEAD_SIZES
,
ROTARY_DIMS
,
MAX_POSITIONS
,
BASES
,
IS_NEOX_STYLE
,
ROPE_SCALINGS
,
DTYPES
)
rope_setting_id_map
:
D
ict
[
str
,
int
]
=
{}
rope_setting_id_map
:
d
ict
[
str
,
int
]
=
{}
for
setting
in
product
(
*
settings
):
head_size
,
rotary_dim
,
max_position
,
base
,
\
is_neox_stype
,
rope_scaling
,
dtype
=
setting
...
...
tests/kernels/test_triton_scaled_mm.py
View file @
cf069aa8
...
...
@@ -4,7 +4,7 @@
Run `pytest tests/kernels/test_triton_scaled_mm.py`.
"""
import
importlib
from
typing
import
Optional
,
Type
from
typing
import
Optional
import
pytest
import
torch
...
...
@@ -18,7 +18,7 @@ def scaled_mm_torch(a: torch.Tensor,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
T
ype
[
torch
.
dtype
],
out_dtype
:
t
ype
[
torch
.
dtype
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
out
=
torch
.
mm
(
a
.
to
(
torch
.
float32
),
b
.
to
(
torch
.
float32
))
out
=
scale_a
*
out
...
...
tests/kernels/utils.py
View file @
cf069aa8
...
...
@@ -4,9 +4,9 @@
import
itertools
import
random
import
unittest
from
collections.abc
import
Sequence
from
numbers
import
Number
from
typing
import
(
Any
,
Dict
,
List
,
NamedTuple
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
)
from
typing
import
Any
,
NamedTuple
,
Optional
,
Union
import
pytest
import
torch
...
...
@@ -20,13 +20,13 @@ from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS
:
T
uple
[
str
,
...]
=
(
DEFAULT_OPCHECK_TEST_UTILS
:
t
uple
[
str
,
...]
=
(
"test_schema"
,
"test_autograd_registration"
,
"test_faketensor"
,
)
ALL_OPCHECK_TEST_UTILS
:
T
uple
[
str
,
...]
=
(
ALL_OPCHECK_TEST_UTILS
:
t
uple
[
str
,
...]
=
(
"test_schema"
,
"test_autograd_registration"
,
"test_faketensor"
,
...
...
@@ -50,8 +50,8 @@ class QKVInputs(NamedTuple):
query
:
torch
.
Tensor
key
:
torch
.
Tensor
value
:
torch
.
Tensor
q_seq_lens
:
L
ist
[
int
]
kv_seq_lens
:
L
ist
[
int
]
q_seq_lens
:
l
ist
[
int
]
kv_seq_lens
:
l
ist
[
int
]
class
QKVO
(
NamedTuple
):
...
...
@@ -89,10 +89,10 @@ class PackedQKVInputs(NamedTuple):
query
:
torch
.
Tensor
key
:
torch
.
Tensor
value
:
torch
.
Tensor
q_start_loc_list
:
Optional
[
L
ist
[
int
]]
kv_start_loc_list
:
Optional
[
L
ist
[
int
]]
q_seq_lens
:
Optional
[
L
ist
[
int
]]
kv_seq_lens
:
Optional
[
L
ist
[
int
]]
q_start_loc_list
:
Optional
[
l
ist
[
int
]]
kv_start_loc_list
:
Optional
[
l
ist
[
int
]]
q_seq_lens
:
Optional
[
l
ist
[
int
]]
kv_seq_lens
:
Optional
[
l
ist
[
int
]]
class
PackedQKVO
(
NamedTuple
):
...
...
@@ -146,7 +146,7 @@ class PhaseTestParameters(NamedTuple):
def
maybe_make_int_tensor
(
_list
:
Optional
[
L
ist
[
int
]],
_list
:
Optional
[
l
ist
[
int
]],
device
:
Union
[
torch
.
device
,
str
],
)
->
torch
.
Tensor
:
'''
...
...
@@ -162,7 +162,7 @@ def maybe_make_int_tensor(
def
maybe_make_long_tensor
(
_list
:
Optional
[
L
ist
[
int
]],
_list
:
Optional
[
l
ist
[
int
]],
device
:
Union
[
torch
.
device
,
str
],
)
->
torch
.
Tensor
:
'''
...
...
@@ -177,7 +177,7 @@ def maybe_make_long_tensor(
_list
,
dtype
=
torch
.
long
,
device
=
device
)
def
maybe_max
(
_list
:
Optional
[
L
ist
])
->
Optional
[
Number
]:
def
maybe_max
(
_list
:
Optional
[
l
ist
])
->
Optional
[
Number
]:
'''
Returns:
...
...
@@ -232,8 +232,8 @@ def ref_masked_attention(query: torch.Tensor,
value
:
torch
.
Tensor
,
scale
:
float
,
custom_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
q_seq_lens
:
Optional
[
L
ist
]
=
None
,
kv_seq_lens
:
Optional
[
L
ist
]
=
None
)
->
torch
.
Tensor
:
q_seq_lens
:
Optional
[
l
ist
]
=
None
,
kv_seq_lens
:
Optional
[
l
ist
]
=
None
)
->
torch
.
Tensor
:
'''
"Golden" masked attention reference. Supports two types of masking:
...
...
@@ -295,10 +295,10 @@ def make_qkv(
num_heads
:
int
,
head_size
:
int
,
device
:
Union
[
torch
.
device
,
str
],
force_kv_seq_lens
:
Optional
[
L
ist
[
int
]]
=
None
,
force_kv_seq_lens
:
Optional
[
l
ist
[
int
]]
=
None
,
attn_type
:
AttentionType
=
AttentionType
.
ENCODER_DECODER
,
force_max_len
:
bool
=
False
,
)
->
T
uple
[
QKVInputs
,
QKVInputs
,
QKVInputs
]:
)
->
t
uple
[
QKVInputs
,
QKVInputs
,
QKVInputs
]:
'''
Construct QKV test tensors for self- and cross-attention.
...
...
@@ -429,8 +429,8 @@ def make_qkv(
def
pack_tensor
(
unpacked_tensor
:
torch
.
Tensor
,
seq_lens
:
L
ist
[
int
],
device
:
Union
[
torch
.
device
,
str
])
->
T
uple
[
torch
.
Tensor
,
L
ist
[
int
]]:
unpacked_tensor
:
torch
.
Tensor
,
seq_lens
:
l
ist
[
int
],
device
:
Union
[
torch
.
device
,
str
])
->
t
uple
[
torch
.
Tensor
,
l
ist
[
int
]]:
'''
Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
unpadded number_of_tokens x num_heads x head_size tensor, where
...
...
@@ -537,11 +537,11 @@ def make_backend(backend_name: str) -> AttentionBackend:
def
_make_metadata_tensors
(
seq_lens
:
Optional
[
L
ist
[
int
]],
context_lens
:
Optional
[
L
ist
[
int
]],
encoder_seq_lens
:
Optional
[
L
ist
[
int
]],
seq_lens
:
Optional
[
l
ist
[
int
]],
context_lens
:
Optional
[
l
ist
[
int
]],
encoder_seq_lens
:
Optional
[
l
ist
[
int
]],
device
:
Union
[
torch
.
device
,
str
],
)
->
T
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Any
,
Any
,
Optional
[
torch
.
Tensor
],
)
->
t
uple
[
torch
.
Tensor
,
torch
.
Tensor
,
Any
,
Any
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
int
]]:
'''
Build scalar & tensor values required to build attention metadata structure.
...
...
@@ -654,7 +654,7 @@ def make_empty_block_tables_tensor(device: Union[torch.device, str]):
return
torch
.
tensor
([],
device
=
device
)
def
split_slot_mapping
(
slot_mapping_list
:
torch
.
Tensor
,
seq_lens
:
L
ist
[
int
],
def
split_slot_mapping
(
slot_mapping_list
:
torch
.
Tensor
,
seq_lens
:
l
ist
[
int
],
device
:
Union
[
torch
.
device
,
str
]):
'''
Split a slot mapping into valid prefill- and decode-phase slot mappings.
...
...
@@ -682,9 +682,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
Arguments:
* slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
* slot_mapping_list: Length-P 1D slot mapping (as list) reflecting all N
post-decode sequences
* seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
* seq_lens: list of N post-decode sequence lengths (K_i + 1 in the
description above)
* device: cuda, cpu, etc.
...
...
@@ -712,9 +712,9 @@ def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
def
make_block_tables_slot_mapping
(
block_size
:
int
,
seq_lens
:
L
ist
[
int
],
seq_lens
:
l
ist
[
int
],
device
:
Union
[
torch
.
device
,
str
],
block_base_addr
:
int
=
0
)
->
T
uple
[
torch
.
Tensor
,
L
ist
[
int
],
int
]:
block_base_addr
:
int
=
0
)
->
t
uple
[
torch
.
Tensor
,
l
ist
[
int
],
int
]:
'''
Construct fake block tables & slot mappings.
...
...
@@ -794,7 +794,7 @@ def make_block_tables_slot_mapping(
def
make_test_metadata
(
attn_backend
:
_Backend
,
is_prompt
:
bool
,
seq_lens
:
Optional
[
L
ist
[
int
]],
seq_lens
:
Optional
[
l
ist
[
int
]],
decoder_test_params
:
Optional
[
PhaseTestParameters
],
device
:
Union
[
torch
.
device
,
str
],
encoder_test_params
:
Optional
[
PhaseTestParameters
]
=
None
,
...
...
@@ -1043,7 +1043,7 @@ def fp8_allclose(
# Marlin MoE test utils
def
stack_and_dev
(
tensors
:
L
ist
[
torch
.
Tensor
]):
def
stack_and_dev
(
tensors
:
l
ist
[
torch
.
Tensor
]):
dev
=
tensors
[
0
].
device
return
torch
.
stack
(
tensors
,
dim
=
0
).
to
(
dev
)
...
...
@@ -1090,12 +1090,12 @@ def torch_moe_single(a, w, score, topk):
# and a patched version of allclose that supports fp8 types.
def
opcheck
(
op
:
Union
[
torch
.
_ops
.
OpOverload
,
torch
.
_ops
.
OpOverloadPacket
,
torch
.
_library
.
custom_ops
.
CustomOpDef
],
args
:
T
uple
[
Any
,
...],
kwargs
:
Optional
[
D
ict
[
str
,
Any
]]
=
None
,
args
:
t
uple
[
Any
,
...],
kwargs
:
Optional
[
d
ict
[
str
,
Any
]]
=
None
,
*
,
test_utils
:
Union
[
str
,
Sequence
[
str
]]
=
ALL_OPCHECK_TEST_UTILS
,
raise_exception
:
bool
=
True
,
cond
:
bool
=
True
)
->
D
ict
[
str
,
str
]:
cond
:
bool
=
True
)
->
d
ict
[
str
,
str
]:
with
unittest
.
mock
.
patch
(
'torch.allclose'
,
new
=
fp8_allclose
):
return
torch
.
library
.
opcheck
(
op
,
...
...
@@ -1120,7 +1120,7 @@ def baseline_scaled_mm(a: torch.Tensor,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
T
ype
[
torch
.
dtype
],
out_dtype
:
t
ype
[
torch
.
dtype
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# We treat N-dimensional group scaling as extended numpy-style broadcasting
...
...
Prev
1
2
3
4
5
6
7
8
9
…
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment