Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
50b8d08d
"vllm/vscode:/vscode.git/clone" did not exist on "e0919f331d12dc5dbdefd0775bb6f94dd2fab4e2"
Unverified
Commit
50b8d08d
authored
Aug 15, 2024
by
jon-chuang
Committed by
GitHub
Aug 16, 2024
Browse files
[Misc/Testing] Use `torch.testing.assert_close` (#7324)
parent
e1655287
Changes
25
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
177 additions
and
172 deletions
+177
-172
tests/distributed/test_comm_ops.py
tests/distributed/test_comm_ops.py
+9
-9
tests/distributed/test_custom_all_reduce.py
tests/distributed/test_custom_all_reduce.py
+4
-4
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+1
-1
tests/kernels/test_activation.py
tests/kernels/test_activation.py
+5
-5
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+2
-2
tests/kernels/test_blocksparse_attention.py
tests/kernels/test_blocksparse_attention.py
+2
-2
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+27
-27
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+13
-10
tests/kernels/test_flash_attn.py
tests/kernels/test_flash_attn.py
+2
-2
tests/kernels/test_flashinfer.py
tests/kernels/test_flashinfer.py
+2
-2
tests/kernels/test_fp8_quant.py
tests/kernels/test_fp8_quant.py
+7
-7
tests/kernels/test_int8_quant.py
tests/kernels/test_int8_quant.py
+7
-5
tests/kernels/test_layernorm.py
tests/kernels/test_layernorm.py
+3
-3
tests/kernels/test_marlin_gemm.py
tests/kernels/test_marlin_gemm.py
+2
-2
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+5
-5
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+24
-24
tests/kernels/test_sampler.py
tests/kernels/test_sampler.py
+4
-4
tests/kernels/utils.py
tests/kernels/utils.py
+2
-2
tests/lora/test_layers.py
tests/lora/test_layers.py
+44
-44
tests/lora/test_lora_manager.py
tests/lora/test_lora_manager.py
+12
-12
No files found.
tests/distributed/test_comm_ops.py
View file @
50b8d08d
...
@@ -34,7 +34,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
...
@@ -34,7 +34,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
expected
=
torch
.
sum
(
torch
.
stack
(
all_tensors
,
dim
=
0
),
dim
=
0
)
expected
=
torch
.
sum
(
torch
.
stack
(
all_tensors
,
dim
=
0
),
dim
=
0
)
t
=
all_tensors
[
rank
%
tp_size
]
t
=
all_tensors
[
rank
%
tp_size
]
t
=
tensor_model_parallel_all_reduce
(
t
)
t
=
tensor_model_parallel_all_reduce
(
t
)
assert
torch
.
all
close
(
t
,
expected
)
torch
.
testing
.
assert_
close
(
t
,
expected
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
...
@@ -62,7 +62,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
...
@@ -62,7 +62,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
expected
=
torch
.
cat
(
all_tensors
,
dim
=
all_gather_dimension
)
expected
=
torch
.
cat
(
all_tensors
,
dim
=
all_gather_dimension
)
t
=
all_tensors
[
rank
%
tp_size
]
t
=
all_tensors
[
rank
%
tp_size
]
t
=
tensor_model_parallel_all_gather
(
t
,
all_gather_dimension
)
t
=
tensor_model_parallel_all_gather
(
t
,
all_gather_dimension
)
assert
torch
.
all
close
(
t
,
expected
)
torch
.
testing
.
assert_
close
(
t
,
expected
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
...
@@ -96,12 +96,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
...
@@ -96,12 +96,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
else
:
else
:
recv_dict
=
broadcast_tensor_dict
(
src
=
0
)
recv_dict
=
broadcast_tensor_dict
(
src
=
0
)
assert
len
(
recv_dict
)
==
len
(
test_dict
)
assert
len
(
recv_dict
)
==
len
(
test_dict
)
assert
torch
.
all
close
(
recv_dict
[
"a"
],
test_dict
[
"a"
])
torch
.
testing
.
assert_
close
(
recv_dict
[
"a"
],
test_dict
[
"a"
])
assert
torch
.
all
close
(
recv_dict
[
"b"
],
test_dict
[
"b"
])
torch
.
testing
.
assert_
close
(
recv_dict
[
"b"
],
test_dict
[
"b"
])
assert
recv_dict
[
"c"
]
==
test_dict
[
"c"
]
assert
recv_dict
[
"c"
]
==
test_dict
[
"c"
]
assert
recv_dict
[
"d"
]
==
test_dict
[
"d"
]
assert
recv_dict
[
"d"
]
==
test_dict
[
"d"
]
assert
recv_dict
[
"e"
]
==
test_dict
[
"e"
]
assert
recv_dict
[
"e"
]
==
test_dict
[
"e"
]
assert
torch
.
all
close
(
recv_dict
[
"f"
],
test_dict
[
"f"
])
torch
.
testing
.
assert_
close
(
recv_dict
[
"f"
],
test_dict
[
"f"
])
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
...
@@ -136,12 +136,12 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
...
@@ -136,12 +136,12 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
if
not
get_pp_group
().
is_first_rank
:
if
not
get_pp_group
().
is_first_rank
:
assert
len
(
recv_dict
)
==
len
(
test_dict
)
assert
len
(
recv_dict
)
==
len
(
test_dict
)
assert
torch
.
all
close
(
recv_dict
[
"a"
],
test_dict
[
"a"
])
torch
.
testing
.
assert_
close
(
recv_dict
[
"a"
],
test_dict
[
"a"
])
assert
torch
.
all
close
(
recv_dict
[
"b"
],
test_dict
[
"b"
])
torch
.
testing
.
assert_
close
(
recv_dict
[
"b"
],
test_dict
[
"b"
])
assert
recv_dict
[
"c"
]
==
test_dict
[
"c"
]
assert
recv_dict
[
"c"
]
==
test_dict
[
"c"
]
assert
recv_dict
[
"d"
]
==
test_dict
[
"d"
]
assert
recv_dict
[
"d"
]
==
test_dict
[
"d"
]
assert
recv_dict
[
"e"
]
==
test_dict
[
"e"
]
assert
recv_dict
[
"e"
]
==
test_dict
[
"e"
]
assert
torch
.
all
close
(
recv_dict
[
"f"
],
test_dict
[
"f"
])
torch
.
testing
.
assert_
close
(
recv_dict
[
"f"
],
test_dict
[
"f"
])
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
...
@@ -163,7 +163,7 @@ def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
...
@@ -163,7 +163,7 @@ def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
get_pp_group
().
send
(
test_tensor
)
get_pp_group
().
send
(
test_tensor
)
if
not
get_pp_group
().
is_first_rank
:
if
not
get_pp_group
().
is_first_rank
:
assert
torch
.
all
close
(
test_tensor
,
recv_tensor
)
torch
.
testing
.
assert_
close
(
test_tensor
,
recv_tensor
)
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
...
...
tests/distributed/test_custom_all_reduce.py
View file @
50b8d08d
...
@@ -72,8 +72,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
...
@@ -72,8 +72,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
out2
=
tensor_model_parallel_all_reduce
(
inp2
)
out2
=
tensor_model_parallel_all_reduce
(
inp2
)
dist
.
all_reduce
(
inp2
,
group
=
group
)
dist
.
all_reduce
(
inp2
,
group
=
group
)
graph
.
replay
()
graph
.
replay
()
assert
torch
.
all
close
(
out1
,
inp1
)
torch
.
testing
.
assert_
close
(
out1
,
inp1
)
assert
torch
.
all
close
(
out2
,
inp2
)
torch
.
testing
.
assert_
close
(
out2
,
inp2
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
...
@@ -96,13 +96,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
...
@@ -96,13 +96,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
out
=
inp
out
=
inp
for
_
in
range
(
num_communication
):
for
_
in
range
(
num_communication
):
out
=
fa
.
all_reduce_unreg
(
out
)
out
=
fa
.
all_reduce_unreg
(
out
)
assert
torch
.
all
close
(
out
,
inp
*
(
tp_size
**
num_communication
))
torch
.
testing
.
assert_
close
(
out
,
inp
*
(
tp_size
**
num_communication
))
inp
=
torch
.
ones
(
sz
*
4
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
inp
=
torch
.
ones
(
sz
*
4
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
out
=
inp
out
=
inp
for
_
in
range
(
num_communication
):
for
_
in
range
(
num_communication
):
out
=
fa
.
all_reduce_unreg
(
out
)
out
=
fa
.
all_reduce_unreg
(
out
)
assert
torch
.
all
close
(
out
,
inp
*
(
tp_size
**
num_communication
))
torch
.
testing
.
assert_
close
(
out
,
inp
*
(
tp_size
**
num_communication
))
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
...
...
tests/kernels/quant_utils.py
View file @
50b8d08d
...
@@ -69,4 +69,4 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
...
@@ -69,4 +69,4 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
ref_iscale
=
one
/
ref_scale
ref_iscale
=
one
/
ref_scale
ref_out
=
(
as_float32_tensor
(
x
)
*
ref_iscale
).
clamp
(
ref_out
=
(
as_float32_tensor
(
x
)
*
ref_iscale
).
clamp
(
fp8_traits
.
min
,
fp8_traits
.
max
).
to
(
dtype
=
torch
.
float8_e4m3fn
)
fp8_traits
.
min
,
fp8_traits
.
max
).
to
(
dtype
=
torch
.
float8_e4m3fn
)
return
ref_out
,
ref_scale
return
ref_out
,
ref_scale
.
view
((
1
,
))
tests/kernels/test_activation.py
View file @
50b8d08d
...
@@ -47,7 +47,7 @@ def test_act_and_mul(
...
@@ -47,7 +47,7 @@ def test_act_and_mul(
ref_out
=
layer
.
forward_native
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
# The SiLU and GELU implementations are equivalent to the native PyTorch
# The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison.
# implementations, so we can do exact comparison.
assert
torch
.
all
close
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
torch
.
testing
.
assert_
close
(
out
,
ref_out
,
atol
=
0.0
,
rtol
=
0.0
)
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
FastGELU
,
NewGELU
])
@
pytest
.
mark
.
parametrize
(
"activation"
,
[
FastGELU
,
NewGELU
])
...
@@ -73,7 +73,7 @@ def test_activation(
...
@@ -73,7 +73,7 @@ def test_activation(
layer
=
activation
()
layer
=
activation
()
out
=
layer
(
x
)
out
=
layer
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
ref_out
=
layer
.
forward_native
(
x
)
assert
torch
.
all
close
(
out
,
torch
.
testing
.
assert_
close
(
out
,
ref_out
,
ref_out
,
atol
=
get_default_atol
(
out
),
atol
=
get_default_atol
(
out
),
rtol
=
get_default_rtol
(
out
))
rtol
=
get_default_rtol
(
out
))
tests/kernels/test_attention.py
View file @
50b8d08d
...
@@ -276,7 +276,7 @@ def test_paged_attention(
...
@@ -276,7 +276,7 @@ def test_paged_attention(
atol
,
rtol
=
1e-3
,
1e-5
atol
,
rtol
=
1e-3
,
1e-5
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
atol
,
rtol
=
1e-2
,
1e-5
atol
,
rtol
=
1e-2
,
1e-5
assert
torch
.
all
close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_
close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
def
ref_multi_query_kv_attention
(
def
ref_multi_query_kv_attention
(
...
@@ -379,4 +379,4 @@ def test_multi_query_kv_attention(
...
@@ -379,4 +379,4 @@ def test_multi_query_kv_attention(
)
)
atol
=
get_default_atol
(
output
)
if
is_hip
()
else
1e-3
atol
=
get_default_atol
(
output
)
if
is_hip
()
else
1e-3
rtol
=
get_default_rtol
(
output
)
if
is_hip
()
else
1e-5
rtol
=
get_default_rtol
(
output
)
if
is_hip
()
else
1e-5
assert
torch
.
all
close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_
close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
tests/kernels/test_blocksparse_attention.py
View file @
50b8d08d
...
@@ -327,7 +327,7 @@ def test_paged_attention(
...
@@ -327,7 +327,7 @@ def test_paged_attention(
atol
,
rtol
=
1e-3
,
1e-5
atol
,
rtol
=
1e-3
,
1e-5
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
atol
,
rtol
=
1e-2
,
1e-5
atol
,
rtol
=
1e-2
,
1e-5
assert
torch
.
all
close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
torch
.
testing
.
assert_
close
(
output
,
ref_output
,
atol
=
atol
,
rtol
=
rtol
)
def
ref_multi_query_kv_attention
(
def
ref_multi_query_kv_attention
(
...
@@ -441,4 +441,4 @@ def test_varlen_blocksparse_attention_prefill(
...
@@ -441,4 +441,4 @@ def test_varlen_blocksparse_attention_prefill(
scale
,
scale
,
dtype
,
dtype
,
)
)
assert
torch
.
all
close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_
close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
)
tests/kernels/test_cache.py
View file @
50b8d08d
...
@@ -98,10 +98,10 @@ def test_copy_blocks(
...
@@ -98,10 +98,10 @@ def test_copy_blocks(
# Compare the results.
# Compare the results.
for
key_cache
,
cloned_key_cache
in
zip
(
key_caches
,
cloned_key_caches
):
for
key_cache
,
cloned_key_cache
in
zip
(
key_caches
,
cloned_key_caches
):
assert
torch
.
all
close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_
close
(
key_cache
,
cloned_key_cache
)
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
for
value_cache
,
cloned_value_cache
in
zip
(
value_caches
,
cloned_value_caches
):
cloned_value_caches
):
assert
torch
.
all
close
(
value_cache
,
cloned_value_cache
)
torch
.
testing
.
assert_
close
(
value_cache
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
@@ -184,17 +184,17 @@ def test_reshape_and_cache(
...
@@ -184,17 +184,17 @@ def test_reshape_and_cache(
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
cloned_value_cache
[
block_idx
,
:,
:,
block_offset
]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
assert
torch
.
all
close
(
result_key_cache
,
torch
.
testing
.
assert_
close
(
result_key_cache
,
cloned_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
atol
=
0.001
,
rtol
=
0.1
)
rtol
=
0.1
)
assert
torch
.
all
close
(
result_value_cache
,
torch
.
testing
.
assert_
close
(
result_value_cache
,
cloned_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
atol
=
0.001
,
rtol
=
0.1
)
rtol
=
0.1
)
else
:
else
:
assert
torch
.
all
close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_
close
(
key_cache
,
cloned_key_cache
)
assert
torch
.
all
close
(
value_cache
,
cloned_value_cache
)
torch
.
testing
.
assert_
close
(
value_cache
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
@@ -290,17 +290,17 @@ def test_reshape_and_cache_flash(
...
@@ -290,17 +290,17 @@ def test_reshape_and_cache_flash(
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
cloned_value_cache
[
block_idx
,
block_offset
,
:,
:]
=
value
[
i
]
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
assert
torch
.
all
close
(
result_key_cache
,
torch
.
testing
.
assert_
close
(
result_key_cache
,
cloned_key_cache
,
cloned_key_cache
,
atol
=
0.001
,
atol
=
0.001
,
rtol
=
0.1
)
rtol
=
0.1
)
assert
torch
.
all
close
(
result_value_cache
,
torch
.
testing
.
assert_
close
(
result_value_cache
,
cloned_value_cache
,
cloned_value_cache
,
atol
=
0.001
,
atol
=
0.001
,
rtol
=
0.1
)
rtol
=
0.1
)
else
:
else
:
assert
torch
.
all
close
(
key_cache
,
cloned_key_cache
)
torch
.
testing
.
assert_
close
(
key_cache
,
cloned_key_cache
)
assert
torch
.
all
close
(
value_cache
,
cloned_value_cache
)
torch
.
testing
.
assert_
close
(
value_cache
,
cloned_value_cache
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
...
@@ -372,9 +372,9 @@ def test_swap_blocks(
...
@@ -372,9 +372,9 @@ def test_swap_blocks(
block_mapping_tensor
)
block_mapping_tensor
)
for
src
,
dst
in
block_mapping
:
for
src
,
dst
in
block_mapping
:
assert
torch
.
all
close
(
src_key_caches_clone
[
src
].
cpu
(),
torch
.
testing
.
assert_
close
(
src_key_caches_clone
[
src
].
cpu
(),
dist_key_caches
[
0
][
dst
].
cpu
())
dist_key_caches
[
0
][
dst
].
cpu
())
assert
torch
.
all
close
(
src_value_caches_clone
[
src
].
cpu
(),
torch
.
testing
.
assert_
close
(
src_value_caches_clone
[
src
].
cpu
(),
dist_value_caches
[
0
][
dst
].
cpu
())
dist_value_caches
[
0
][
dst
].
cpu
())
...
@@ -411,4 +411,4 @@ def test_fp8_e4m3_conversion(
...
@@ -411,4 +411,4 @@ def test_fp8_e4m3_conversion(
converted_cache
=
torch
.
empty_like
(
cache
)
converted_cache
=
torch
.
empty_like
(
cache
)
ops
.
convert_fp8
(
converted_cache
,
cache_fp8
)
ops
.
convert_fp8
(
converted_cache
,
cache_fp8
)
assert
torch
.
all
close
(
cache
,
converted_cache
,
atol
=
0.001
,
rtol
=
0.1
)
torch
.
testing
.
assert_
close
(
cache
,
converted_cache
,
atol
=
0.001
,
rtol
=
0.1
)
tests/kernels/test_cutlass.py
View file @
50b8d08d
...
@@ -74,7 +74,7 @@ def cutlass_fp8_gemm_helper(m: int,
...
@@ -74,7 +74,7 @@ def cutlass_fp8_gemm_helper(m: int,
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
assert
torch
.
all
close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5e-2
)
torch
.
testing
.
assert_
close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5e-2
)
def
cutlass_int8_gemm_helper
(
m
:
int
,
def
cutlass_int8_gemm_helper
(
m
:
int
,
...
@@ -106,7 +106,7 @@ def cutlass_int8_gemm_helper(m: int,
...
@@ -106,7 +106,7 @@ def cutlass_int8_gemm_helper(m: int,
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
assert
torch
.
all
close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
torch
.
testing
.
assert_
close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
100
,
33
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
100
,
33
])
...
@@ -252,7 +252,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
...
@@ -252,7 +252,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
a_dq
=
scale_a
*
(
aq_i32
+
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
a_dq
=
scale_a
*
(
aq_i32
+
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
assert
torch
.
all
close
(
a_dq
,
scale_a
*
aq_f32
+
azp_a
)
torch
.
testing
.
assert_
close
(
a_dq
,
scale_a
*
aq_f32
+
azp_a
)
baseline_dq
=
torch
.
mm
(
a_dq
,
b_dq
).
to
(
out_dtype
)
baseline_dq
=
torch
.
mm
(
a_dq
,
b_dq
).
to
(
out_dtype
)
...
@@ -271,8 +271,8 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
...
@@ -271,8 +271,8 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
scale_b
,
scale_b
,
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
bias
=
azp_bias
[
0
,
:])
bias
=
azp_bias
[
0
,
:])
assert
torch
.
all
close
(
out
,
baseline_dq
,
rtol
=
1e-2
,
atol
=
1e0
)
torch
.
testing
.
assert_
close
(
out
,
baseline_dq
,
rtol
=
1e-2
,
atol
=
1e0
)
assert
torch
.
all
close
(
out
,
baseline_q
,
rtol
=
1e-2
,
atol
=
1e0
)
torch
.
testing
.
assert_
close
(
out
,
baseline_q
,
rtol
=
1e-2
,
atol
=
1e0
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
...
@@ -302,7 +302,10 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
...
@@ -302,7 +302,10 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
a_dq
=
scale_a
*
(
aq_i32
-
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
a_dq
=
scale_a
*
(
aq_i32
-
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
assert
torch
.
allclose
(
a_dq
,
scale_a
*
aq_f32
-
azp_a
,
rtol
=
1e-4
,
atol
=
1e-3
)
torch
.
testing
.
assert_close
(
a_dq
,
scale_a
*
aq_f32
-
azp_a
,
rtol
=
1e-4
,
atol
=
1e-3
)
if
use_bias
:
if
use_bias
:
bias
=
torch
.
rand
((
1
,
n
),
device
=
"cuda"
,
dtype
=
out_dtype
)
*
10
+
2.5
bias
=
torch
.
rand
((
1
,
n
),
device
=
"cuda"
,
dtype
=
out_dtype
)
*
10
+
2.5
...
@@ -335,8 +338,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
...
@@ -335,8 +338,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
rtol
=
1e-2
if
out_dtype
==
torch
.
bfloat16
else
1e-3
rtol
=
1e-2
if
out_dtype
==
torch
.
bfloat16
else
1e-3
atol
=
1e-3
atol
=
1e-3
assert
torch
.
all
close
(
out
,
baseline_dq
,
rtol
=
rtol
,
atol
=
atol
)
torch
.
testing
.
assert_
close
(
out
,
baseline_dq
,
rtol
=
rtol
,
atol
=
atol
)
assert
torch
.
all
close
(
out
,
baseline_q
,
rtol
=
rtol
,
atol
=
atol
)
torch
.
testing
.
assert_
close
(
out
,
baseline_q
,
rtol
=
rtol
,
atol
=
atol
)
# Test working with a subset of A and B
# Test working with a subset of A and B
...
@@ -363,7 +366,7 @@ def test_cutlass_subset():
...
@@ -363,7 +366,7 @@ def test_cutlass_subset():
scale_b
,
scale_b
,
out_dtype
=
torch
.
bfloat16
)
out_dtype
=
torch
.
bfloat16
)
assert
torch
.
all
close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
torch
.
testing
.
assert_
close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
# Test to make sure cuda graphs work
# Test to make sure cuda graphs work
...
@@ -411,4 +414,4 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
...
@@ -411,4 +414,4 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
torch
.
bfloat16
)
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
torch
.
bfloat16
)
assert
torch
.
all
close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
torch
.
testing
.
assert_
close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
tests/kernels/test_flash_attn.py
View file @
50b8d08d
...
@@ -126,7 +126,7 @@ def test_flash_attn_with_paged_kv(
...
@@ -126,7 +126,7 @@ def test_flash_attn_with_paged_kv(
scale
=
scale
,
scale
=
scale
,
soft_cap
=
soft_cap
,
soft_cap
=
soft_cap
,
)
)
assert
torch
.
all
close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_
close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
@@ -211,5 +211,5 @@ def test_varlen_with_paged_kv(
...
@@ -211,5 +211,5 @@ def test_varlen_with_paged_kv(
sliding_window
=
sliding_window
,
sliding_window
=
sliding_window
,
soft_cap
=
soft_cap
,
soft_cap
=
soft_cap
,
)
)
assert
torch
.
all
close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_
close
(
output
,
ref_output
,
atol
=
2e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
tests/kernels/test_flashinfer.py
View file @
50b8d08d
...
@@ -144,7 +144,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
...
@@ -144,7 +144,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
block_tables
=
block_tables
,
block_tables
=
block_tables
,
scale
=
scale
,
scale
=
scale
,
soft_cap
=
soft_cap
)
soft_cap
=
soft_cap
)
assert
torch
.
all
close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_
close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
...
@@ -244,5 +244,5 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
...
@@ -244,5 +244,5 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_tables
=
block_tables
,
block_tables
=
block_tables
,
scale
=
scale
,
scale
=
scale
,
soft_cap
=
soft_cap
)
soft_cap
=
soft_cap
)
assert
torch
.
all
close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
torch
.
testing
.
assert_
close
(
output
,
ref_output
,
atol
=
1e-2
,
rtol
=
1e-2
),
\
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
f
"
{
torch
.
max
(
torch
.
abs
(
output
-
ref_output
))
}
"
tests/kernels/test_fp8_quant.py
View file @
50b8d08d
...
@@ -37,8 +37,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
...
@@ -37,8 +37,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
scale_ub
=
scale_ub
,
scale_ub
=
scale_ub
,
use_per_token_if_dynamic
=
True
)
use_per_token_if_dynamic
=
True
)
assert
torch
.
all
close
(
ref_scales
,
ops_scales
)
torch
.
testing
.
assert_
close
(
ref_scales
,
ops_scales
)
assert
torch
.
all
close
(
ref_out
.
to
(
dtype
=
torch
.
float32
),
torch
.
testing
.
assert_
close
(
ref_out
.
to
(
dtype
=
torch
.
float32
),
ops_out
.
to
(
dtype
=
torch
.
float32
))
ops_out
.
to
(
dtype
=
torch
.
float32
))
...
@@ -57,8 +57,8 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
...
@@ -57,8 +57,8 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
ref_out
,
ref_scale
=
ref_dynamic_per_tensor_fp8_quant
(
x
)
ref_out
,
ref_scale
=
ref_dynamic_per_tensor_fp8_quant
(
x
)
ops_out
,
ops_scale
=
ops
.
scaled_fp8_quant
(
x
)
ops_out
,
ops_scale
=
ops
.
scaled_fp8_quant
(
x
)
assert
torch
.
all
close
(
ref_scale
,
ops_scale
)
torch
.
testing
.
assert_
close
(
ref_scale
,
ops_scale
)
assert
torch
.
all
close
(
ref_out
.
to
(
dtype
=
torch
.
float32
),
torch
.
testing
.
assert_
close
(
ref_out
.
to
(
dtype
=
torch
.
float32
),
ops_out
.
to
(
dtype
=
torch
.
float32
))
ops_out
.
to
(
dtype
=
torch
.
float32
))
...
@@ -84,4 +84,4 @@ def test_fp8_quant_large(seed: int) -> None:
...
@@ -84,4 +84,4 @@ def test_fp8_quant_large(seed: int) -> None:
ref_out
=
ref_out
.
to
(
dtype
=
dtype
)
ref_out
=
ref_out
.
to
(
dtype
=
dtype
)
ops_out
=
ops_out
.
to
(
dtype
=
dtype
)
ops_out
=
ops_out
.
to
(
dtype
=
dtype
)
assert
torch
.
all
close
(
ref_out
,
ops_out
)
torch
.
testing
.
assert_
close
(
ref_out
,
ops_out
)
tests/kernels/test_int8_quant.py
View file @
50b8d08d
...
@@ -29,9 +29,10 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
...
@@ -29,9 +29,10 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
# kernel
# kernel
ops_out
,
ops_scales
=
scaled_int8_quant
(
x
)
ops_out
,
ops_scales
=
scaled_int8_quant
(
x
)
assert
torch
.
allclose
(
ops_scales
,
ref_scales
)
torch
.
testing
.
assert_close
(
ops_scales
,
ref_scales
)
assert
torch
.
allclose
(
ops_out
,
ref_out
,
torch
.
testing
.
assert_close
(
atol
=
1
)
# big atol to account for rounding errors
ops_out
,
ref_out
,
atol
=
1
,
rtol
=
0.0
)
# big atol to account for rounding errors
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
NUM_TOKENS
)
...
@@ -54,5 +55,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
...
@@ -54,5 +55,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits
.
max
).
to
(
torch
.
int8
)
int8_traits
.
max
).
to
(
torch
.
int8
)
out2
,
_
=
scaled_int8_quant
(
x
,
scale
)
out2
,
_
=
scaled_int8_quant
(
x
,
scale
)
assert
torch
.
allclose
(
out1
,
out2
,
torch
.
testing
.
assert_close
(
atol
=
1
)
# big atol to account for rounding errors
out1
,
out2
,
atol
=
1
,
rtol
=
0.0
)
# big atol to account for rounding errors
tests/kernels/test_layernorm.py
View file @
50b8d08d
...
@@ -48,7 +48,7 @@ def test_rms_norm(
...
@@ -48,7 +48,7 @@ def test_rms_norm(
# numerical errors than other operators because they involve reductions.
# numerical errors than other operators because they involve reductions.
# Therefore, we use a larger tolerance.
# Therefore, we use a larger tolerance.
if
add_residual
:
if
add_residual
:
assert
torch
.
all
close
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_
close
(
out
[
0
],
ref_out
[
0
],
atol
=
1e-2
,
rtol
=
1e-2
)
assert
torch
.
all
close
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_
close
(
out
[
1
],
ref_out
[
1
],
atol
=
1e-2
,
rtol
=
1e-2
)
else
:
else
:
assert
torch
.
all
close
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
torch
.
testing
.
assert_
close
(
out
,
ref_out
,
atol
=
1e-2
,
rtol
=
1e-2
)
tests/kernels/test_marlin_gemm.py
View file @
50b8d08d
...
@@ -122,7 +122,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
...
@@ -122,7 +122,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
assert
torch
.
all
close
(
marlin_q_w_1
,
marlin_q_w_2
)
torch
.
testing
.
assert_
close
(
marlin_q_w_1
,
marlin_q_w_2
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
...
@@ -174,7 +174,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
...
@@ -174,7 +174,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
)
)
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
assert
torch
.
all
close
(
marlin_q_w_1
,
marlin_q_w_2
)
torch
.
testing
.
assert_
close
(
marlin_q_w_1
,
marlin_q_w_2
)
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"gptq_marlin"
),
...
...
tests/kernels/test_moe.py
View file @
50b8d08d
...
@@ -50,7 +50,7 @@ def test_fused_moe(
...
@@ -50,7 +50,7 @@ def test_fused_moe(
score
=
torch
.
randn
((
m
,
e
),
device
=
'cuda'
,
dtype
=
dtype
)
score
=
torch
.
randn
((
m
,
e
),
device
=
'cuda'
,
dtype
=
dtype
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
triton_output
=
fused_moe
(
a
,
w1
,
w2
,
score
,
topk
,
renormalize
=
False
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
torch_output
=
torch_moe
(
a
,
w1
,
w2
,
score
,
topk
)
assert
torch
.
all
close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
torch
.
testing
.
assert_
close
(
triton_output
,
torch_output
,
atol
=
1e-2
,
rtol
=
0
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
@
pytest
.
mark
.
parametrize
(
"dtype"
,
...
@@ -95,7 +95,7 @@ def test_mixtral_moe(dtype: torch.dtype):
...
@@ -95,7 +95,7 @@ def test_mixtral_moe(dtype: torch.dtype):
torch
.
bfloat16
:
1e-2
,
torch
.
bfloat16
:
1e-2
,
}
}
assert
torch
.
all
close
(
hf_states
.
flatten
(
0
,
1
),
torch
.
testing
.
assert_
close
(
hf_states
.
flatten
(
0
,
1
),
vllm_states
,
vllm_states
,
rtol
=
mixtral_moe_tol
[
dtype
],
rtol
=
mixtral_moe_tol
[
dtype
],
atol
=
mixtral_moe_tol
[
dtype
])
atol
=
mixtral_moe_tol
[
dtype
])
tests/kernels/test_pos_encoding.py
View file @
50b8d08d
...
@@ -67,11 +67,11 @@ def test_rotary_embedding(
...
@@ -67,11 +67,11 @@ def test_rotary_embedding(
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
)
ref_query
,
ref_key
=
rope
.
forward_native
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
)
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
)
# Compare the results.
# Compare the results.
assert
torch
.
all
close
(
out_query
,
torch
.
testing
.
assert_
close
(
out_query
,
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
rtol
=
get_default_rtol
(
out_query
))
assert
torch
.
all
close
(
out_key
,
torch
.
testing
.
assert_
close
(
out_key
,
ref_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
rtol
=
get_default_rtol
(
out_key
))
...
@@ -129,11 +129,11 @@ def test_batched_rotary_embedding(
...
@@ -129,11 +129,11 @@ def test_batched_rotary_embedding(
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
device
))
device
=
device
))
# Compare the results.
# Compare the results.
assert
torch
.
all
close
(
out_query
,
torch
.
testing
.
assert_
close
(
out_query
,
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
rtol
=
get_default_rtol
(
out_query
))
assert
torch
.
all
close
(
out_key
,
torch
.
testing
.
assert_
close
(
out_key
,
ref_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
rtol
=
get_default_rtol
(
out_key
))
...
@@ -200,11 +200,11 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -200,11 +200,11 @@ def test_batched_rotary_embedding_multi_lora(
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
out_query
,
out_key
=
rope
.
forward
(
positions
,
query
,
key
,
query_offsets
.
flatten
())
query_offsets
.
flatten
())
# Compare the results.
# Compare the results.
assert
torch
.
all
close
(
out_query
,
torch
.
testing
.
assert_
close
(
out_query
,
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
rtol
=
get_default_rtol
(
out_query
))
assert
torch
.
all
close
(
out_key
,
torch
.
testing
.
assert_
close
(
out_key
,
ref_key
,
ref_key
,
atol
=
get_default_atol
(
out_key
),
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
rtol
=
get_default_rtol
(
out_key
))
...
...
tests/kernels/test_sampler.py
View file @
50b8d08d
...
@@ -100,11 +100,11 @@ def test_sample_decoding_only(random_sampling, max_best_of,
...
@@ -100,11 +100,11 @@ def test_sample_decoding_only(random_sampling, max_best_of,
if
modify_greedy_probs
and
not
request_uses_random_sampling
:
if
modify_greedy_probs
and
not
request_uses_random_sampling
:
# If we are modifying greedy probs and the request is greedy,
# If we are modifying greedy probs and the request is greedy,
# we want to make sure the probs tensor is modified in place
# we want to make sure the probs tensor is modified in place
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
probs
[
i
][
sampled_tokens
[
i
]],
probs
[
i
][
sampled_tokens
[
i
]],
torch
.
full_like
(
probs
[
i
][
sampled_tokens
[
i
]],
1.0
))
torch
.
full_like
(
probs
[
i
][
sampled_tokens
[
i
]],
1.0
))
assert
torch
.
sum
(
probs
[
i
])
==
1.0
assert
torch
.
sum
(
probs
[
i
])
==
1.0
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
sampled_modified_probs
[
i
][
0
],
sampled_modified_probs
[
i
][
0
],
torch
.
full_like
(
sampled_modified_probs
[
i
][
0
],
1.0
))
torch
.
full_like
(
sampled_modified_probs
[
i
][
0
],
1.0
))
elif
request_uses_random_sampling
:
elif
request_uses_random_sampling
:
...
@@ -117,7 +117,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
...
@@ -117,7 +117,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
# If the request is greedy and we are not modifying greedy probs,
# If the request is greedy and we are not modifying greedy probs,
# we want to make sure sampled_modified_probs tensor is the same as
# we want to make sure sampled_modified_probs tensor is the same as
# the probs tensor.
# the probs tensor.
assert
torch
.
all
close
(
sampled_modified_probs
[
i
]
[
0
]
,
torch
.
testing
.
assert_
close
(
sampled_modified_probs
[
i
],
probs
[
i
][
sampled_tokens
[
i
]])
probs
[
i
][
sampled_tokens
[
i
]])
if
save_logprobs
:
if
save_logprobs
:
...
...
tests/kernels/utils.py
View file @
50b8d08d
...
@@ -924,5 +924,5 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
...
@@ -924,5 +924,5 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
* output_under_test: actually observed output value
* output_under_test: actually observed output value
'''
'''
ideal_output
=
test_params
.
packed_qkvo
.
ideal_output
ideal_output
=
test_params
.
packed_qkvo
.
ideal_output
assert
torch
.
all
close
(
ideal_output
,
torch
.
testing
.
assert_
close
(
ideal_output
,
output_under_test
.
view_as
(
ideal_output
))
output_under_test
.
view_as
(
ideal_output
))
tests/lora/test_layers.py
View file @
50b8d08d
...
@@ -247,7 +247,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
...
@@ -247,7 +247,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
all
close
(
lora_result
,
torch
.
testing
.
assert_
close
(
lora_result
,
expected_result
,
expected_result
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
@@ -274,7 +274,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
...
@@ -274,7 +274,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
expected_result
=
embedding
(
torch
.
cat
(
inputs
))
expected_result
=
embedding
(
torch
.
cat
(
inputs
))
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
all
close
(
lora_result
,
torch
.
testing
.
assert_
close
(
lora_result
,
expected_result
,
expected_result
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
@@ -384,7 +384,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
...
@@ -384,7 +384,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
all
close
(
lora_result
,
torch
.
testing
.
assert_
close
(
lora_result
,
expected_result
,
expected_result
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
@@ -411,7 +411,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
...
@@ -411,7 +411,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
expected_result
=
expanded_embedding
(
torch
.
cat
(
inputs
))
expected_result
=
expanded_embedding
(
torch
.
cat
(
inputs
))
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
all
close
(
lora_result
,
torch
.
testing
.
assert_
close
(
lora_result
,
expected_result
,
expected_result
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
@@ -541,7 +541,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
...
@@ -541,7 +541,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
embedding_bias
=
None
)
embedding_bias
=
None
)
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
all
close
(
lora_result
,
torch
.
testing
.
assert_
close
(
lora_result
,
expected_result
,
expected_result
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
@@ -614,7 +614,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
...
@@ -614,7 +614,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
all
close
(
lora_result
,
torch
.
testing
.
assert_
close
(
lora_result
,
expected_result
,
expected_result
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
@@ -642,7 +642,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
...
@@ -642,7 +642,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
expected_result
=
linear
(
torch
.
cat
(
inputs
))[
0
]
expected_result
=
linear
(
torch
.
cat
(
inputs
))[
0
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
all
close
(
lora_result
,
torch
.
testing
.
assert_
close
(
lora_result
,
expected_result
,
expected_result
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
@@ -728,7 +728,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
...
@@ -728,7 +728,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
all
close
(
lora_result
,
torch
.
testing
.
assert_
close
(
lora_result
,
expected_result
,
expected_result
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
@@ -756,7 +756,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
...
@@ -756,7 +756,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
expected_result
=
linear
(
torch
.
cat
(
inputs
))[
0
]
expected_result
=
linear
(
torch
.
cat
(
inputs
))[
0
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
all
close
(
lora_result
,
torch
.
testing
.
assert_
close
(
lora_result
,
expected_result
,
expected_result
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
@@ -868,7 +868,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
...
@@ -868,7 +868,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
all
close
(
lora_result
,
torch
.
testing
.
assert_
close
(
lora_result
,
expected_result
,
expected_result
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
@@ -900,7 +900,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
...
@@ -900,7 +900,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
expected_result
=
linear
(
torch
.
cat
(
inputs
))[
0
]
expected_result
=
linear
(
torch
.
cat
(
inputs
))[
0
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
all
close
(
lora_result
,
torch
.
testing
.
assert_
close
(
lora_result
,
expected_result
,
expected_result
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
)
atol
=
atol
)
...
...
tests/lora/test_lora_manager.py
View file @
50b8d08d
...
@@ -533,13 +533,13 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
...
@@ -533,13 +533,13 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
packed_lora
=
model_lora
.
get_lora
(
"gate_up_proj"
)
packed_lora
=
model_lora
.
get_lora
(
"gate_up_proj"
)
assert
packed_lora
and
isinstance
(
packed_lora
,
PackedLoRALayerWeights
)
assert
packed_lora
and
isinstance
(
packed_lora
,
PackedLoRALayerWeights
)
assert
torch
.
all
close
(
packed_lora
.
lora_a
[
0
],
torch
.
testing
.
assert_
close
(
packed_lora
.
lora_a
[
0
],
model_lora
.
get_lora
(
"gate_proj"
).
lora_a
)
model_lora
.
get_lora
(
"gate_proj"
).
lora_a
)
assert
torch
.
all
close
(
packed_lora
.
lora_b
[
0
],
torch
.
testing
.
assert_
close
(
packed_lora
.
lora_b
[
0
],
model_lora
.
get_lora
(
"gate_proj"
).
lora_b
)
model_lora
.
get_lora
(
"gate_proj"
).
lora_b
)
assert
torch
.
all
close
(
packed_lora
.
lora_a
[
1
],
torch
.
testing
.
assert_
close
(
packed_lora
.
lora_a
[
1
],
model_lora
.
get_lora
(
"up_proj"
).
lora_a
)
model_lora
.
get_lora
(
"up_proj"
).
lora_a
)
assert
torch
.
all
close
(
packed_lora
.
lora_b
[
1
],
torch
.
testing
.
assert_
close
(
packed_lora
.
lora_b
[
1
],
model_lora
.
get_lora
(
"up_proj"
).
lora_b
)
model_lora
.
get_lora
(
"up_proj"
).
lora_b
)
packed_lora1
=
model_lora1
.
get_lora
(
"gate_up_proj"
)
packed_lora1
=
model_lora1
.
get_lora
(
"gate_up_proj"
)
...
@@ -547,7 +547,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
...
@@ -547,7 +547,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
assert
packed_lora1
.
lora_a
[
0
]
is
None
assert
packed_lora1
.
lora_a
[
0
]
is
None
assert
packed_lora1
.
lora_b
[
0
]
is
None
assert
packed_lora1
.
lora_b
[
0
]
is
None
assert
torch
.
all
close
(
packed_lora1
.
lora_a
[
1
],
torch
.
testing
.
assert_
close
(
packed_lora1
.
lora_a
[
1
],
model_lora1
.
get_lora
(
"up_proj"
).
lora_a
)
model_lora1
.
get_lora
(
"up_proj"
).
lora_a
)
assert
torch
.
all
close
(
packed_lora1
.
lora_b
[
1
],
torch
.
testing
.
assert_
close
(
packed_lora1
.
lora_b
[
1
],
model_lora1
.
get_lora
(
"up_proj"
).
lora_b
)
model_lora1
.
get_lora
(
"up_proj"
).
lora_b
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment