Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
50b8d08d
Unverified
Commit
50b8d08d
authored
Aug 15, 2024
by
jon-chuang
Committed by
GitHub
Aug 16, 2024
Browse files
[Misc/Testing] Use `torch.testing.assert_close` (#7324)
parent
e1655287
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
20 additions
and
16 deletions
+20
-16
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+5
-3
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+1
-1
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+1
-1
tests/test_logits_processor.py
tests/test_logits_processor.py
+4
-2
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+9
-9
No files found.
tests/quantization/test_fp8.py
View file @
50b8d08d
...
@@ -127,16 +127,18 @@ def test_scaled_fp8_quant(dtype) -> None:
...
@@ -127,16 +127,18 @@ def test_scaled_fp8_quant(dtype) -> None:
# Reference dynamic quantizaton
# Reference dynamic quantizaton
y
=
quantize_ref
(
x
,
inv_scale
)
y
=
quantize_ref
(
x
,
inv_scale
)
assert
torch
.
allclose
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
torch
.
testing
.
assert_close
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
# Static quantization
# Static quantization
y
,
_
=
ops
.
scaled_fp8_quant
(
x
,
inv_scale
)
y
,
_
=
ops
.
scaled_fp8_quant
(
x
,
inv_scale
)
assert
torch
.
allclose
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
torch
.
testing
.
assert_close
(
ref_y
,
per_tensor_dequantize
(
y
,
inv_scale
,
dtype
))
# Padding
# Padding
y
,
_
=
ops
.
scaled_fp8_quant
(
x
,
inv_scale
,
num_token_padding
=
17
)
y
,
_
=
ops
.
scaled_fp8_quant
(
x
,
inv_scale
,
num_token_padding
=
17
)
assert
y
.
shape
[
0
]
==
17
assert
y
.
shape
[
0
]
==
17
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
ref_y
,
ref_y
,
per_tensor_dequantize
(
torch
.
narrow
(
y
,
0
,
0
,
x
.
shape
[
0
]),
inv_scale
,
per_tensor_dequantize
(
torch
.
narrow
(
y
,
0
,
0
,
x
.
shape
[
0
]),
inv_scale
,
dtype
))
dtype
))
tests/samplers/test_sampler.py
View file @
50b8d08d
...
@@ -632,7 +632,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -632,7 +632,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
hf_probs
=
warpers
(
torch
.
zeros_like
(
fake_logits
),
fake_logits
.
clone
())
hf_probs
=
warpers
(
torch
.
zeros_like
(
fake_logits
),
fake_logits
.
clone
())
hf_probs
=
torch
.
softmax
(
hf_probs
,
dim
=-
1
,
dtype
=
torch
.
float
)
hf_probs
=
torch
.
softmax
(
hf_probs
,
dim
=-
1
,
dtype
=
torch
.
float
)
assert
torch
.
all
close
(
hf_probs
,
sample_probs
,
atol
=
1e-5
)
torch
.
testing
.
assert_
close
(
hf_probs
,
sample_probs
,
rtol
=
0.0
,
atol
=
1e-5
)
assert
torch
.
equal
(
hf_probs
.
eq
(
0
),
sample_probs
.
eq
(
0
))
assert
torch
.
equal
(
hf_probs
.
eq
(
0
),
sample_probs
.
eq
(
0
))
...
...
tests/spec_decode/utils.py
View file @
50b8d08d
...
@@ -161,7 +161,7 @@ def assert_logprobs_dict_allclose(
...
@@ -161,7 +161,7 @@ def assert_logprobs_dict_allclose(
single_step_actual_logprobs
[
token_id
].
logprob
)
single_step_actual_logprobs
[
token_id
].
logprob
)
expected
=
torch
.
tensor
(
expected
=
torch
.
tensor
(
single_step_expected_logprobs
[
token_id
].
logprob
)
single_step_expected_logprobs
[
token_id
].
logprob
)
assert
torch
.
all
close
(
actual
,
expected
)
torch
.
testing
.
assert_
close
(
actual
,
expected
)
def
create_sampler_output_list
(
def
create_sampler_output_list
(
...
...
tests/test_logits_processor.py
View file @
50b8d08d
...
@@ -90,5 +90,7 @@ def test_logits_processors(seed: int, device: str):
...
@@ -90,5 +90,7 @@ def test_logits_processors(seed: int, device: str):
assert
torch
.
isinf
(
logits_processor_output
[:,
0
]).
all
()
assert
torch
.
isinf
(
logits_processor_output
[:,
0
]).
all
()
fake_logits
*=
logits_processor
.
scale
fake_logits
*=
logits_processor
.
scale
assert
torch
.
allclose
(
logits_processor_output
[:,
1
],
fake_logits
[:,
1
],
torch
.
testing
.
assert_close
(
logits_processor_output
[:,
1
],
1e-4
)
fake_logits
[:,
1
],
rtol
=
1e-4
,
atol
=
0.0
)
tests/worker/test_model_runner.py
View file @
50b8d08d
...
@@ -77,7 +77,7 @@ def test_prepare_prompt(batch_size):
...
@@ -77,7 +77,7 @@ def test_prepare_prompt(batch_size):
device
=
model_runner
.
device
device
=
model_runner
.
device
assert
attn_metadata
.
num_prefills
>
0
assert
attn_metadata
.
num_prefills
>
0
assert
attn_metadata
.
num_decode_tokens
==
0
assert
attn_metadata
.
num_decode_tokens
==
0
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
seq_lens_tensor
,
attn_metadata
.
seq_lens_tensor
,
torch
.
tensor
(
seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
torch
.
tensor
(
seq_lens
,
device
=
device
,
dtype
=
torch
.
int
))
assert
attn_metadata
.
seq_lens
==
seq_lens
assert
attn_metadata
.
seq_lens
==
seq_lens
...
@@ -90,7 +90,7 @@ def test_prepare_prompt(batch_size):
...
@@ -90,7 +90,7 @@ def test_prepare_prompt(batch_size):
for
seq_len
in
seq_lens
:
for
seq_len
in
seq_lens
:
start_idx
+=
seq_len
start_idx
+=
seq_len
start_loc
.
append
(
start_idx
)
start_loc
.
append
(
start_idx
)
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
query_start_loc
,
attn_metadata
.
query_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
...
@@ -102,10 +102,10 @@ def test_prepare_prompt(batch_size):
...
@@ -102,10 +102,10 @@ def test_prepare_prompt(batch_size):
start_idx
+=
seq_len
start_idx
+=
seq_len
seq_start_loc
.
append
(
start_idx
)
seq_start_loc
.
append
(
start_idx
)
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
seq_start_loc
,
attn_metadata
.
seq_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
context_lens_tensor
,
attn_metadata
.
context_lens_tensor
,
torch
.
zeros
(
attn_metadata
.
context_lens_tensor
.
shape
[
0
],
torch
.
zeros
(
attn_metadata
.
context_lens_tensor
.
shape
[
0
],
dtype
=
torch
.
int
,
dtype
=
torch
.
int
,
...
@@ -114,7 +114,7 @@ def test_prepare_prompt(batch_size):
...
@@ -114,7 +114,7 @@ def test_prepare_prompt(batch_size):
expected
=
torch
.
tensor
([[]
for
_
in
range
(
len
(
seq_group_metadata_list
))],
expected
=
torch
.
tensor
([[]
for
_
in
range
(
len
(
seq_group_metadata_list
))],
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
model_runner
.
device
)
device
=
model_runner
.
device
)
assert
torch
.
all
close
(
attn_metadata
.
block_tables
,
expected
)
torch
.
testing
.
assert_
close
(
attn_metadata
.
block_tables
,
expected
)
# Cuda graph should not be used for prerill.
# Cuda graph should not be used for prerill.
assert
attn_metadata
.
use_cuda_graph
is
False
assert
attn_metadata
.
use_cuda_graph
is
False
...
@@ -201,7 +201,7 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -201,7 +201,7 @@ def test_prepare_decode_cuda_graph(batch_size):
# decode has only 1 token for query.
# decode has only 1 token for query.
start_idx
+=
1
start_idx
+=
1
start_loc
.
append
(
start_idx
)
start_loc
.
append
(
start_idx
)
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
query_start_loc
,
attn_metadata
.
query_start_loc
,
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
torch
.
tensor
(
start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
...
@@ -210,15 +210,15 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -210,15 +210,15 @@ def test_prepare_decode_cuda_graph(batch_size):
for
seq_len
in
seq_lens
:
for
seq_len
in
seq_lens
:
start_idx
+=
seq_len
start_idx
+=
seq_len
seq_start_loc
.
append
(
start_idx
)
seq_start_loc
.
append
(
start_idx
)
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
seq_start_loc
,
attn_metadata
.
seq_start_loc
,
torch
.
tensor
(
seq_start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
torch
.
tensor
(
seq_start_loc
,
dtype
=
torch
.
int32
,
device
=
device
))
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
context_lens_tensor
,
attn_metadata
.
context_lens_tensor
,
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
device
))
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
device
))
assert
attn_metadata
.
max_decode_seq_len
==
max
(
seq_lens
)
assert
attn_metadata
.
max_decode_seq_len
==
max
(
seq_lens
)
assert
torch
.
all
close
(
torch
.
testing
.
assert_
close
(
attn_metadata
.
seq_lens_tensor
[:
len
(
seq_lens
)],
attn_metadata
.
seq_lens_tensor
[:
len
(
seq_lens
)],
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
))
torch
.
tensor
(
seq_lens
,
dtype
=
torch
.
int
,
device
=
device
))
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment