Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad385667
Commit
ad385667
authored
Oct 23, 2024
by
zhuwenwen
Browse files
Merge branch 'v0.6.3.post1-dev'
parents
be0967c1
903593d3
Changes
364
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
638 additions
and
65 deletions
+638
-65
tests/kernels/test_blocksparse_attention.py
tests/kernels/test_blocksparse_attention.py
+5
-11
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+56
-45
tests/kernels/test_causal_conv1d.py
tests/kernels/test_causal_conv1d.py
+427
-0
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+150
-9
No files found.
Too many changes to show.
To preserve performance only
364 of 364+
files are displayed.
Plain diff
Email patch
tests/kernels/test_blocksparse_attention.py
View file @
ad385667
...
...
@@ -7,7 +7,7 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.ops.blocksparse_attention.interface
import
(
LocalStridedBlockSparseAttn
)
from
vllm.utils
import
get_max_shared_memory_bytes
,
is_hip
from
vllm.utils
import
get_max_shared_memory_bytes
,
is_hip
,
seed_everything
from
.allclose_default
import
get_default_atol
,
get_default_rtol
...
...
@@ -172,10 +172,7 @@ def test_paged_attention(
blocksparse_block_size
:
int
,
blocksparse_head_sliding_step
:
int
,
)
->
None
:
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
num_query_heads
,
num_kv_heads
=
num_heads
...
...
@@ -327,7 +324,7 @@ def test_paged_attention(
atol
,
rtol
=
1e-3
,
1e-5
if
kv_cache_dtype
==
"fp8"
:
atol
,
rtol
=
1e-2
,
1e-5
assert
torch
.
all
close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_
close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
def
ref_multi_query_kv_attention
(
...
...
@@ -386,10 +383,7 @@ def test_varlen_blocksparse_attention_prefill(
seed
:
int
,
device
:
str
,
)
->
None
:
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
...
...
@@ -441,4 +435,4 @@ def test_varlen_blocksparse_attention_prefill(
scale
,
dtype
,
)
assert
torch
.
all
close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_
close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
)
tests/kernels/test_cache.py
View file @
ad385667
...
...
@@ -4,7 +4,9 @@ from typing import List, Tuple
import
pytest
import
torch
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
seed_everything
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -55,10 +57,7 @@ def test_copy_blocks(
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
pytest
.
skip
()
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
# Generate random block mappings where each source block is mapped to two
# destination blocks.
...
...
@@ -88,6 +87,11 @@ def test_copy_blocks(
block_mapping_tensor
=
torch
.
tensor
(
block_mapping
,
dtype
=
torch
.
int64
,
device
=
device
).
view
(
-
1
,
2
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
copy_blocks
,
(
key_caches
,
value_caches
,
block_mapping_tensor
),
test_utils
=
DEFAULT_OPCHECK_TEST_UTILS
,
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
copy_blocks
(
key_caches
,
value_caches
,
block_mapping_tensor
)
# Run the reference implementation.
...
...
@@ -99,10 +103,10 @@ def test_copy_blocks(
# Compare the results.
for
key_cache
,
cloned_key_cache
in
zip
(
key_caches
,
cloned_key_caches
):
assert
torch
.
all
close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_
close
(
key_cache
,
cloned_key_cache
)
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
cloned_value_caches
):
assert
torch
.
all
close
(
value_cache
,
cloned_value_cache
)
torch
.
testing
.
assert_
close
(
value_cache
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
...
@@ -129,10 +133,7 @@ def test_reshape_and_cache(
)
->
None
:
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
pytest
.
skip
()
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
# Create a random slot mapping.
num_slots
=
block_size
*
num_blocks
...
...
@@ -163,6 +164,10 @@ def test_reshape_and_cache(
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
...
...
@@ -185,17 +190,17 @@ def test_reshape_and_cache(
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
assert
torch
.
all
close
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
assert
torch
.
all
close
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
torch
.
testing
.
assert_
close
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
torch
.
testing
.
assert_
close
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
assert
torch
.
all
close
(
key_cache
,
cloned_key_cache
)
assert
torch
.
all
close
(
value_cache
,
cloned_value_cache
)
torch
.
testing
.
assert_
close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_
close
(
value_cache
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
...
@@ -220,9 +225,7 @@ def test_reshape_and_cache_flash(
device
:
str
,
kv_cache_dtype
:
str
,
)
->
None
:
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
seed_everything
(
seed
)
torch
.
set_default_device
(
device
)
# Create a random slot mapping.
...
...
@@ -270,6 +273,10 @@ def test_reshape_and_cache_flash(
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
),
cond
=
(
head_size
==
HEAD_SIZES
[
0
]))
ops
.
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
)
...
...
@@ -291,17 +298,17 @@ def test_reshape_and_cache_flash(
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
assert
torch
.
all
close
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
assert
torch
.
all
close
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
torch
.
testing
.
assert_
close
(
result_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
rtol
=
0.1
)
torch
.
testing
.
assert_
close
(
result_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
rtol
=
0.1
)
else
:
assert
torch
.
all
close
(
key_cache
,
cloned_key_cache
)
assert
torch
.
all
close
(
value_cache
,
cloned_value_cache
)
torch
.
testing
.
assert_
close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_
close
(
value_cache
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
...
...
@@ -332,10 +339,8 @@ def test_swap_blocks(
pytest
.
skip
()
if
kv_cache_dtype
==
"fp8"
and
head_size
%
16
:
pytest
.
skip
()
random
.
seed
(
seed
)
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
seed_everything
(
seed
)
src_device
=
device
if
direction
[
0
]
==
"cuda"
else
'cpu'
dst_device
=
device
if
direction
[
1
]
==
"cuda"
else
'cpu'
...
...
@@ -367,16 +372,24 @@ def test_swap_blocks(
src_value_caches_clone
=
src_value_caches
[
0
].
clone
()
# Call the swap_blocks kernel.
do_opcheck
=
(
head_size
==
HEAD_SIZES
[
0
])
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
),
cond
=
do_opcheck
)
opcheck
(
torch
.
ops
.
_C_cache_ops
.
swap_blocks
,
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
),
cond
=
do_opcheck
)
ops
.
swap_blocks
(
src_key_caches
[
0
],
dist_key_caches
[
0
],
block_mapping_tensor
)
ops
.
swap_blocks
(
src_value_caches
[
0
],
dist_value_caches
[
0
],
block_mapping_tensor
)
for
src
,
dst
in
block_mapping
:
assert
torch
.
all
close
(
src_key_caches_clone
[
src
].
cpu
(),
dist_key_caches
[
0
][
dst
].
cpu
())
assert
torch
.
all
close
(
src_value_caches_clone
[
src
].
cpu
(),
dist_value_caches
[
0
][
dst
].
cpu
())
torch
.
testing
.
assert_
close
(
src_key_caches_clone
[
src
].
cpu
(),
dist_key_caches
[
0
][
dst
].
cpu
())
torch
.
testing
.
assert_
close
(
src_value_caches_clone
[
src
].
cpu
(),
dist_value_caches
[
0
][
dst
].
cpu
())
# @pytest.mark.parametrize("num_heads", NUM_HEADS)
...
...
@@ -396,9 +409,7 @@ def test_swap_blocks(
# seed: int,
# device: str,
# ) -> None:
# random.seed(seed)
# torch.random.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# seed_everything(seed)
# low = -224.0
# high = 224.0
...
...
@@ -412,4 +423,4 @@ def test_swap_blocks(
# converted_cache = torch.empty_like(cache)
# ops.convert_fp8(converted_cache, cache_fp8)
# assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
\ No newline at end of file
# torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
\ No newline at end of file
tests/kernels/test_causal_conv1d.py
0 → 100644
View file @
ad385667
from
typing
import
Optional
import
pytest
import
torch
import
torch.nn.functional
as
F
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
# noqa: F401
from
vllm.attention.backends.utils
import
PAD_SLOT_ID
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.utils
import
seed_everything
def
causal_conv1d_ref
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
initial_states
:
Optional
[
torch
.
Tensor
]
=
None
,
return_final_states
:
bool
=
False
,
final_states_out
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)
out: (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
dtype_in
=
x
.
dtype
x
=
x
.
to
(
weight
.
dtype
)
seqlen
=
x
.
shape
[
-
1
]
dim
,
width
=
weight
.
shape
if
initial_states
is
None
:
out
=
F
.
conv1d
(
x
,
weight
.
unsqueeze
(
1
),
bias
,
padding
=
width
-
1
,
groups
=
dim
)
else
:
x
=
torch
.
cat
([
initial_states
,
x
],
dim
=-
1
)
out
=
F
.
conv1d
(
x
,
weight
.
unsqueeze
(
1
),
bias
,
padding
=
0
,
groups
=
dim
)
out
=
out
[...,
:
seqlen
]
if
return_final_states
:
final_states
=
F
.
pad
(
x
,
(
width
-
1
-
x
.
shape
[
-
1
],
0
)).
to
(
dtype_in
)
# (batch, dim, width - 1)
if
final_states_out
is
not
None
:
final_states_out
.
copy_
(
final_states
)
else
:
final_states_out
=
final_states
out
=
(
out
if
activation
is
None
else
F
.
silu
(
out
)).
to
(
dtype
=
dtype_in
)
return
(
out
,
None
)
if
not
return_final_states
else
(
out
,
final_states_out
)
def
causal_conv1d_update_ref
(
x
,
conv_state
,
weight
,
bias
=
None
,
activation
=
None
,
cache_seqlens
=
None
):
"""
x: (batch, dim) or (batch, dim, seqlen)
conv_state: (batch, dim, state_len), where state_len >= width - 1
weight: (dim, width)
bias: (dim,)
cache_seqlens: (batch,), dtype int32.
If not None, the conv_state is treated as a circular buffer.
The conv_state will be updated by copying x to the
conv_state starting at the index
@cache_seqlens % state_len before performing the convolution.
out: (batch, dim) or (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
dtype_in
=
x
.
dtype
unsqueeze
=
x
.
dim
()
==
2
if
unsqueeze
:
x
=
x
.
unsqueeze
(
-
1
)
batch
,
dim
,
seqlen
=
x
.
shape
width
=
weight
.
shape
[
1
]
state_len
=
conv_state
.
shape
[
-
1
]
assert
conv_state
.
shape
==
(
batch
,
dim
,
state_len
)
assert
weight
.
shape
==
(
dim
,
width
)
if
cache_seqlens
is
None
:
x_new
=
torch
.
cat
([
conv_state
,
x
],
dim
=-
1
).
to
(
weight
.
dtype
)
# (batch, dim, state_len + seqlen)
conv_state
.
copy_
(
x_new
[:,
:,
-
state_len
:])
else
:
width_idx
=
torch
.
arange
(
-
(
width
-
1
),
0
,
dtype
=
torch
.
long
,
device
=
x
.
device
).
unsqueeze
(
0
)
+
cache_seqlens
.
unsqueeze
(
1
)
width_idx
=
torch
.
remainder
(
width_idx
,
state_len
).
unsqueeze
(
1
).
expand
(
-
1
,
dim
,
-
1
)
x_new
=
torch
.
cat
([
conv_state
.
gather
(
2
,
width_idx
),
x
],
dim
=-
1
).
to
(
weight
.
dtype
)
copy_idx
=
torch
.
arange
(
seqlen
,
dtype
=
torch
.
long
,
device
=
x
.
device
).
unsqueeze
(
0
)
+
cache_seqlens
.
unsqueeze
(
1
)
copy_idx
=
torch
.
remainder
(
copy_idx
,
state_len
).
unsqueeze
(
1
).
expand
(
-
1
,
dim
,
-
1
)
conv_state
.
scatter_
(
2
,
copy_idx
,
x
)
out
=
F
.
conv1d
(
x_new
,
weight
.
unsqueeze
(
1
),
bias
,
padding
=
0
,
groups
=
dim
)[:,
:,
-
seqlen
:]
if
unsqueeze
:
out
=
out
.
squeeze
(
-
1
)
return
(
out
if
activation
is
None
else
F
.
silu
(
out
)).
to
(
dtype
=
dtype_in
)
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
,
torch
.
float
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
def
causal_conv1d_opcheck_fn
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
cu_seq_len
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
has_initial_state
:
Optional
[
torch
.
Tensor
]
=
None
,
conv_states
:
Optional
[
torch
.
Tensor
]
=
None
,
activation
:
Optional
[
str
]
=
"silu"
,
pad_slot_id
:
int
=
PAD_SLOT_ID
):
"""
x: (batch, dim, seqlen)
weight: (dim, width)
bias: (dim,)
seq_idx: (batch, seqlen)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1), to be written to
activation: either None or "silu" or "swish"
out: (batch, dim, seqlen)
"""
if
activation
not
in
[
None
,
"silu"
,
"swish"
]:
raise
NotImplementedError
(
"activation must be None, silu, or swish"
)
if
x
.
stride
(
-
1
)
!=
1
:
x
=
x
.
contiguous
()
bias
=
bias
.
contiguous
()
if
bias
is
not
None
else
None
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_fwd
,
(
x
,
weight
,
bias
,
conv_states
,
cu_seq_len
,
cache_indices
,
has_initial_state
,
activation
in
[
"silu"
,
"swish"
],
pad_slot_id
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
,
torch
.
float
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'seqlen'
,
[
1
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
2048
,
4096
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
])
@
pytest
.
mark
.
parametrize
(
'batch'
,
[
1
])
def
test_causal_conv1d
(
batch
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
# set seed
seed_everything
(
0
)
x
=
torch
.
randn
(
batch
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
).
contiguous
()
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
initial_states
=
torch
.
randn
(
batch
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
x_ref
=
x
.
clone
()
weight_ref
=
weight
.
clone
()
bias_ref
=
bias
.
clone
()
if
bias
is
not
None
else
None
initial_states_ref
=
initial_states
.
clone
(
)
if
initial_states
is
not
None
else
None
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_fn
(
x
,
weight
,
bias
,
activation
=
activation
,
conv_states
=
initial_states
,
has_initial_state
=
torch
.
ones
(
batch
,
dtype
=
torch
.
bool
,
device
=
x
.
device
))
out_ref
,
final_states_ref
=
causal_conv1d_ref
(
x_ref
,
weight_ref
,
bias_ref
,
initial_states
=
initial_states_ref
,
return_final_states
=
True
,
activation
=
activation
)
assert
initial_states
is
not
None
and
final_states_ref
is
not
None
assert
torch
.
allclose
(
initial_states
,
final_states_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
causal_conv1d_opcheck_fn
(
x
,
weight
,
bias
,
activation
=
activation
,
conv_states
=
initial_states
,
has_initial_state
=
torch
.
ones
(
batch
,
dtype
=
torch
.
bool
,
device
=
x
.
device
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
def
test_causal_conv1d_update
(
dim
,
width
,
seqlen
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
# set seed
seed_everything
(
0
)
batch
=
2
x
=
torch
.
randn
(
batch
,
dim
,
seqlen
,
device
=
device
,
dtype
=
itype
)
x_ref
=
x
.
clone
()
conv_state
=
torch
.
randn
(
batch
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
conv_state_ref
=
conv_state
.
detach
().
clone
()
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation
=
activation
)
out_ref
=
causal_conv1d_update_ref
(
x_ref
,
conv_state_ref
,
weight
,
bias
,
activation
=
activation
)
assert
torch
.
equal
(
conv_state
,
conv_state_ref
)
assert
torch
.
allclose
(
out
,
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
(
x
,
conv_state
,
weight
,
bias
,
activation
in
[
"silu"
,
"swish"
],
None
,
None
,
PAD_SLOT_ID
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"seqlen"
,
[
1
,
4
,
5
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
2
,
3
,
4
])
@
pytest
.
mark
.
parametrize
(
"dim"
,
[
2048
,
2048
+
16
,
4096
])
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
"with_padding"
,
[
True
,
False
])
def
test_causal_conv1d_update_with_batch_gather
(
with_padding
,
dim
,
width
,
seqlen
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
# set seed
seed_everything
(
0
)
batch_size
=
3
padding
=
5
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
total_entries
=
10
*
batch_size
x
=
torch
.
randn
(
padded_batch_size
,
dim
,
1
,
device
=
device
,
dtype
=
itype
)
x_ref
=
x
.
clone
()
conv_state_indices
=
torch
.
randperm
(
total_entries
)[:
batch_size
].
to
(
dtype
=
torch
.
int32
,
device
=
device
)
unused_states_bool
=
torch
.
ones
(
total_entries
,
dtype
=
torch
.
bool
,
device
=
device
)
unused_states_bool
[
conv_state_indices
]
=
False
padded_state_indices
=
torch
.
concat
([
conv_state_indices
,
torch
.
as_tensor
(
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
)
],
dim
=
0
)
conv_state
=
torch
.
randn
(
total_entries
,
dim
,
width
-
1
,
device
=
device
,
dtype
=
itype
)
conv_state_for_padding_test
=
conv_state
.
clone
()
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
conv_state_ref
=
conv_state
[
conv_state_indices
,
:].
detach
().
clone
()
activation
=
None
if
not
silu_activation
else
"silu"
out
=
causal_conv1d_update
(
x
,
conv_state
,
weight
,
bias
,
activation
=
activation
,
conv_state_indices
=
padded_state_indices
,
pad_slot_id
=
PAD_SLOT_ID
)
out_ref
=
causal_conv1d_update_ref
(
x_ref
[:
batch_size
],
conv_state_ref
,
weight
,
bias
,
activation
=
activation
)
assert
torch
.
equal
(
conv_state
[
conv_state_indices
,
:],
conv_state_ref
)
assert
torch
.
allclose
(
out
[:
batch_size
],
out_ref
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
equal
(
conv_state
[
unused_states_bool
],
conv_state_for_padding_test
[
unused_states_bool
])
opcheck
(
torch
.
ops
.
_C
.
causal_conv1d_update
,
(
x
,
conv_state
,
weight
,
bias
,
activation
in
[
"silu"
,
"swish"
],
None
,
padded_state_indices
,
PAD_SLOT_ID
))
@
pytest
.
mark
.
parametrize
(
"itype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"silu_activation"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"has_bias"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"width"
,
[
4
])
@
pytest
.
mark
.
parametrize
(
'seqlen'
,
[
8
,
16
,
32
,
64
,
128
,
256
,
512
,
784
,
1024
,
2048
,
2049
,
4096
])
@
pytest
.
mark
.
parametrize
(
'dim'
,
[
64
,
4096
])
# tests correctness in case subset of the sequences are padded
@
pytest
.
mark
.
parametrize
(
'with_padding'
,
[
True
,
False
])
def
test_causal_conv1d_varlen
(
with_padding
,
dim
,
seqlen
,
width
,
has_bias
,
silu_activation
,
itype
):
device
=
"cuda"
torch
.
cuda
.
empty_cache
()
rtol
,
atol
=
(
3e-4
,
1e-3
)
if
itype
==
torch
.
float32
else
(
3e-3
,
5e-3
)
if
itype
==
torch
.
bfloat16
:
rtol
,
atol
=
1e-2
,
5e-2
# set seed
seed_everything
(
0
)
seqlens
=
[]
batch_size
=
4
if
seqlen
<
10
:
batch_size
=
1
padding
=
3
if
with_padding
else
0
padded_batch_size
=
batch_size
+
padding
nsplits
=
padded_batch_size
-
1
eos_pos
=
torch
.
randperm
(
seqlen
-
1
)[:
nsplits
].
sort
().
values
seqlens
.
append
(
torch
.
diff
(
torch
.
cat
(
[
torch
.
tensor
([
-
1
]),
eos_pos
,
torch
.
tensor
([
seqlen
-
1
])])).
tolist
())
assert
sum
(
seqlens
[
-
1
])
==
seqlen
assert
all
(
s
>
0
for
s
in
seqlens
[
-
1
])
total_entries
=
batch_size
*
10
cumsum
=
torch
.
cumsum
(
torch
.
tensor
(
seqlens
[
0
]),
dim
=
0
).
to
(
torch
.
int32
)
cumsum
=
torch
.
concat
([
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
),
cumsum
],
dim
=
0
)
x
=
torch
.
randn
(
1
,
4096
+
dim
+
64
,
seqlen
,
device
=
device
,
dtype
=
itype
)[:,
4096
:
4096
+
dim
,
:]
weight
=
torch
.
randn
(
dim
,
width
,
device
=
device
,
dtype
=
itype
)
bias
=
torch
.
randn
(
dim
,
device
=
device
,
dtype
=
itype
)
if
has_bias
else
None
x_ref
=
x
.
clone
()
weight_ref
=
weight
.
clone
()
bias_ref
=
bias
.
clone
()
if
bias
is
not
None
else
None
activation
=
None
if
not
silu_activation
else
"silu"
final_states
=
torch
.
randn
(
total_entries
,
dim
,
width
-
1
,
device
=
x
.
device
,
dtype
=
x
.
dtype
)
final_states_ref
=
final_states
.
clone
()
has_initial_states
=
torch
.
randint
(
0
,
2
,
(
cumsum
.
shape
[
0
]
-
1
,
),
dtype
=
torch
.
bool
,
device
=
x
.
device
)
state_indices
=
torch
.
randperm
(
total_entries
,
dtype
=
torch
.
int32
,
device
=
x
.
device
)[:
batch_size
]
padded_state_indices
=
torch
.
concat
([
state_indices
,
torch
.
as_tensor
(
[
PAD_SLOT_ID
]
*
padding
,
dtype
=
torch
.
int32
,
device
=
device
),
],
dim
=-
1
)
out
=
causal_conv1d_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
padded_state_indices
,
has_initial_states
,
final_states
,
activation
,
PAD_SLOT_ID
)
out_ref
=
[]
out_ref_b
=
[]
splits
=
[
torch
.
split
(
var
,
seqlens
[
0
],
dim
=-
1
)
for
var
in
(
x_ref
)]
for
i
in
range
(
len
(
seqlens
[
0
])):
x_s
=
[
v
[
i
].
unsqueeze
(
0
)
for
v
in
splits
][
0
]
if
padded_state_indices
[
i
]
==
PAD_SLOT_ID
:
continue
out_ref_b
.
append
(
causal_conv1d_ref
(
x_s
,
weight_ref
,
bias_ref
,
activation
=
activation
,
return_final_states
=
True
,
final_states_out
=
final_states_ref
[
padded_state_indices
[
i
]].
unsqueeze
(
0
),
initial_states
=
final_states_ref
[
padded_state_indices
[
i
]].
unsqueeze
(
0
)
if
has_initial_states
[
i
]
else
None
))
out_ref
.
append
(
torch
.
cat
([
t
[
0
]
for
t
in
out_ref_b
],
dim
=
2
))
out_ref_tensor
=
torch
.
cat
(
out_ref
,
dim
=
0
)
unpadded_out
=
out
[:,
:
out_ref_tensor
.
shape
[
-
1
]]
assert
torch
.
allclose
(
unpadded_out
,
out_ref_tensor
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
allclose
(
final_states
,
final_states_ref
,
rtol
=
rtol
,
atol
=
atol
)
causal_conv1d_opcheck_fn
(
x
.
squeeze
(
0
),
weight
,
bias
,
cumsum
.
cuda
(),
padded_state_indices
,
has_initial_states
,
final_states
,
activation
)
tests/kernels/test_cutlass.py
View file @
ad385667
...
...
@@ -7,6 +7,7 @@ from typing import Optional, Type
import
pytest
import
torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
...
...
@@ -28,13 +29,16 @@ def to_int8(tensor: torch.Tensor):
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
def
rand_int8
(
shape
:
tuple
,
device
:
str
=
"cuda"
):
return
to_int8
(
torch
.
rand
(
shape
,
device
=
device
)
*
255
-
128
)
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
output
=
(
scale_a
*
(
scale_b
*
(
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))))).
to
(
out_dtype
)
if
bias
is
not
None
:
...
...
@@ -71,7 +75,10 @@ def cutlass_fp8_gemm_helper(m: int,
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5e-2
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5e-2
)
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
def
cutlass_int8_gemm_helper
(
m
:
int
,
...
...
@@ -103,7 +110,10 @@ def cutlass_int8_gemm_helper(m: int,
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
100
,
33
])
...
...
@@ -112,7 +122,7 @@ def cutlass_int8_gemm_helper(m: int,
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_
capability
(
89
)
,
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
...
...
@@ -150,7 +160,7 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_
capability
(
89
)
,
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
],
...
...
@@ -168,7 +178,7 @@ def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_
capability
(
89
)
,
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
,
device
:
str
):
...
...
@@ -200,7 +210,7 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_
capability
(
89
)
,
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
...
...
@@ -221,6 +231,133 @@ def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
64
,
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
skip
def
test_cutlass_int8_azp_bias_fold
(
m
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
):
# Currently, the test is failing because folding azp into
# 16-bit bias loses too much precision
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
aq_i8
=
rand_int8
((
m
,
k
))
bq_i8
=
rand_int8
((
n
,
k
)).
t
()
aq_i32
=
aq_i8
.
to
(
dtype
=
torch
.
int32
)
bq_i32
=
bq_i8
.
to
(
dtype
=
torch
.
int32
)
aq_f32
=
aq_i8
.
to
(
dtype
=
torch
.
float32
)
bq_f32
=
bq_i8
.
to
(
dtype
=
torch
.
float32
)
b_dq
=
scale_b
*
bq_f32
azp_a
=
torch
.
rand
((
1
,
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
*
10
+
1.5
azp_aq_i8
=
(
azp_a
/
scale_a
).
to
(
dtype
=
torch
.
int8
)
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
a_dq
=
scale_a
*
(
aq_i32
+
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
torch
.
testing
.
assert_close
(
a_dq
,
scale_a
*
aq_f32
+
azp_a
)
baseline_dq
=
torch
.
mm
(
a_dq
,
b_dq
).
to
(
out_dtype
)
J
=
torch
.
ones
((
1
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
azp_bias
=
(
azp_a
*
scale_b
*
(
J
@
bq_f32
)).
to
(
out_dtype
)
assert
azp_bias
.
shape
==
(
1
,
n
)
assert
azp_bias
[
0
,
:].
shape
==
(
n
,
)
baseline_q
=
(
scale_a
.
to
(
device
=
'cpu'
)
*
scale_b
.
to
(
device
=
'cpu'
)
*
(
(
aq_i32
+
azp_aq_i8
).
to
(
device
=
'cpu'
)
@
bq_i32
.
to
(
device
=
'cpu'
))).
to
(
dtype
=
out_dtype
,
device
=
'cuda'
)
out
=
ops
.
cutlass_scaled_mm
(
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
out_dtype
=
out_dtype
,
bias
=
azp_bias
[
0
,
:])
torch
.
testing
.
assert_close
(
out
,
baseline_dq
,
rtol
=
1e-2
,
atol
=
1e0
)
torch
.
testing
.
assert_close
(
out
,
baseline_q
,
rtol
=
1e-2
,
atol
=
1e0
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
64
,
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"azp_per_token"
,
[
True
,
False
])
def
test_cutlass_int8_azp
(
m
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
,
use_bias
:
bool
,
azp_per_token
:
bool
):
m_azp
=
m
if
azp_per_token
else
1
scale_a
=
torch
.
randn
((
m_azp
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
aq_i8
=
rand_int8
((
m
,
k
))
aq_i32
=
aq_i8
.
to
(
dtype
=
torch
.
int32
)
aq_f32
=
aq_i8
.
to
(
dtype
=
torch
.
float32
)
bq_i8
=
rand_int8
((
n
,
k
)).
t
()
bq_i32
=
bq_i8
.
to
(
dtype
=
torch
.
int32
)
bq_f32
=
bq_i8
.
to
(
dtype
=
torch
.
float32
)
b_dq
=
scale_b
*
bq_f32
azp_a
=
torch
.
rand
(
(
m_azp
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
*
10
+
1.5
azp_aq_i8
=
(
azp_a
/
scale_a
).
to
(
dtype
=
torch
.
int8
)
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
a_dq
=
scale_a
*
(
aq_i32
-
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
torch
.
testing
.
assert_close
(
a_dq
,
scale_a
*
aq_f32
-
azp_a
,
rtol
=
1e-4
,
atol
=
1e-3
)
if
use_bias
:
bias
=
torch
.
rand
((
1
,
n
),
device
=
"cuda"
,
dtype
=
out_dtype
)
*
10
+
2.5
else
:
bias
=
torch
.
zeros
((
1
,
n
),
device
=
"cuda"
,
dtype
=
out_dtype
)
baseline_dq
=
(
torch
.
mm
(
a_dq
,
b_dq
)
+
bias
).
to
(
out_dtype
)
# int32 mm not supported on CUDA
a_noazp_i32_cpu
=
(
aq_i32
-
azp_aq_i8
).
to
(
device
=
'cpu'
)
cq
=
(
a_noazp_i32_cpu
@
bq_i32
.
to
(
device
=
'cpu'
)).
to
(
device
=
'cuda'
)
baseline_q
=
(
scale_a
*
scale_b
*
cq
+
bias
).
to
(
dtype
=
out_dtype
)
# Hadamard is just the sum of the cols
azp_adj_i32
=
bq_i32
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
int32
)
azp_i32
=
azp_aq_i8
.
to
(
dtype
=
torch
.
int32
)
func_bias
=
bias
if
use_bias
else
None
if
azp_per_token
:
out
=
ops
.
cutlass_scaled_mm_azp
(
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
out_dtype
,
azp_adj_i32
,
azp_i32
,
func_bias
)
else
:
azp_with_adj_i32
=
azp_i32
*
azp_adj_i32
out
=
ops
.
cutlass_scaled_mm_azp
(
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
out_dtype
,
azp_with_adj_i32
,
None
,
func_bias
)
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
rtol
=
1e-2
if
out_dtype
==
torch
.
bfloat16
else
1e-3
atol
=
1e-3
torch
.
testing
.
assert_close
(
out
,
baseline_dq
,
rtol
=
rtol
,
atol
=
atol
)
torch
.
testing
.
assert_close
(
out
,
baseline_q
,
rtol
=
rtol
,
atol
=
atol
)
if
azp_per_token
:
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
,
(
out
,
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
azp_adj_i32
,
azp_i32
,
func_bias
))
else
:
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
,
(
out
,
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
azp_with_adj_i32
,
None
,
func_bias
))
# Test working with a subset of A and B
def
test_cutlass_subset
():
big_m
,
big_n
,
big_k
=
1024
,
1024
,
1024
...
...
@@ -245,7 +382,7 @@ def test_cutlass_subset():
scale_b
,
out_dtype
=
torch
.
bfloat16
)
assert
torch
.
all
close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
torch
.
testing
.
assert_
close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
# Test to make sure cuda graphs work
...
...
@@ -293,4 +430,8 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
torch
.
bfloat16
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
def
test_cutlass_support_opcheck
():
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_supports_fp8
,
(
capability
,
))
Prev
1
…
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment