Unverified commit 50b8d08d, authored by jon-chuang and committed by GitHub
Browse files

[Misc/Testing] Use `torch.testing.assert_close` (#7324)

parent e1655287
...@@ -127,16 +127,18 @@ def test_scaled_fp8_quant(dtype) -> None: ...@@ -127,16 +127,18 @@ def test_scaled_fp8_quant(dtype) -> None:
# Reference dynamic quantization # Reference dynamic quantization
y = quantize_ref(x, inv_scale) y = quantize_ref(x, inv_scale)
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype)) torch.testing.assert_close(ref_y,
per_tensor_dequantize(y, inv_scale, dtype))
# Static quantization # Static quantization
y, _ = ops.scaled_fp8_quant(x, inv_scale) y, _ = ops.scaled_fp8_quant(x, inv_scale)
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype)) torch.testing.assert_close(ref_y,
per_tensor_dequantize(y, inv_scale, dtype))
# Padding # Padding
y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17) y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
assert y.shape[0] == 17 assert y.shape[0] == 17
assert torch.allclose( torch.testing.assert_close(
ref_y, ref_y,
per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale, per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
dtype)) dtype))
...@@ -632,7 +632,7 @@ def test_sampler_top_k_top_p(seed: int, device: str): ...@@ -632,7 +632,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone()) hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float) hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5) torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0)) assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
......
...@@ -161,7 +161,7 @@ def assert_logprobs_dict_allclose( ...@@ -161,7 +161,7 @@ def assert_logprobs_dict_allclose(
single_step_actual_logprobs[token_id].logprob) single_step_actual_logprobs[token_id].logprob)
expected = torch.tensor( expected = torch.tensor(
single_step_expected_logprobs[token_id].logprob) single_step_expected_logprobs[token_id].logprob)
assert torch.allclose(actual, expected) torch.testing.assert_close(actual, expected)
def create_sampler_output_list( def create_sampler_output_list(
......
...@@ -90,5 +90,7 @@ def test_logits_processors(seed: int, device: str): ...@@ -90,5 +90,7 @@ def test_logits_processors(seed: int, device: str):
assert torch.isinf(logits_processor_output[:, 0]).all() assert torch.isinf(logits_processor_output[:, 0]).all()
fake_logits *= logits_processor.scale fake_logits *= logits_processor.scale
assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1], torch.testing.assert_close(logits_processor_output[:, 1],
1e-4) fake_logits[:, 1],
rtol=1e-4,
atol=0.0)
...@@ -77,7 +77,7 @@ def test_prepare_prompt(batch_size): ...@@ -77,7 +77,7 @@ def test_prepare_prompt(batch_size):
device = model_runner.device device = model_runner.device
assert attn_metadata.num_prefills > 0 assert attn_metadata.num_prefills > 0
assert attn_metadata.num_decode_tokens == 0 assert attn_metadata.num_decode_tokens == 0
assert torch.allclose( torch.testing.assert_close(
attn_metadata.seq_lens_tensor, attn_metadata.seq_lens_tensor,
torch.tensor(seq_lens, device=device, dtype=torch.int)) torch.tensor(seq_lens, device=device, dtype=torch.int))
assert attn_metadata.seq_lens == seq_lens assert attn_metadata.seq_lens == seq_lens
...@@ -90,7 +90,7 @@ def test_prepare_prompt(batch_size): ...@@ -90,7 +90,7 @@ def test_prepare_prompt(batch_size):
for seq_len in seq_lens: for seq_len in seq_lens:
start_idx += seq_len start_idx += seq_len
start_loc.append(start_idx) start_loc.append(start_idx)
assert torch.allclose( torch.testing.assert_close(
attn_metadata.query_start_loc, attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device)) torch.tensor(start_loc, dtype=torch.int32, device=device))
...@@ -102,10 +102,10 @@ def test_prepare_prompt(batch_size): ...@@ -102,10 +102,10 @@ def test_prepare_prompt(batch_size):
start_idx += seq_len start_idx += seq_len
seq_start_loc.append(start_idx) seq_start_loc.append(start_idx)
assert torch.allclose( torch.testing.assert_close(
attn_metadata.seq_start_loc, attn_metadata.seq_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device)) torch.tensor(start_loc, dtype=torch.int32, device=device))
assert torch.allclose( torch.testing.assert_close(
attn_metadata.context_lens_tensor, attn_metadata.context_lens_tensor,
torch.zeros(attn_metadata.context_lens_tensor.shape[0], torch.zeros(attn_metadata.context_lens_tensor.shape[0],
dtype=torch.int, dtype=torch.int,
...@@ -114,7 +114,7 @@ def test_prepare_prompt(batch_size): ...@@ -114,7 +114,7 @@ def test_prepare_prompt(batch_size):
expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))], expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
dtype=torch.int32, dtype=torch.int32,
device=model_runner.device) device=model_runner.device)
assert torch.allclose(attn_metadata.block_tables, expected) torch.testing.assert_close(attn_metadata.block_tables, expected)
# Cuda graph should not be used for prefill. # Cuda graph should not be used for prefill.
assert attn_metadata.use_cuda_graph is False assert attn_metadata.use_cuda_graph is False
...@@ -201,7 +201,7 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -201,7 +201,7 @@ def test_prepare_decode_cuda_graph(batch_size):
# decode has only 1 token for query. # decode has only 1 token for query.
start_idx += 1 start_idx += 1
start_loc.append(start_idx) start_loc.append(start_idx)
assert torch.allclose( torch.testing.assert_close(
attn_metadata.query_start_loc, attn_metadata.query_start_loc,
torch.tensor(start_loc, dtype=torch.int32, device=device)) torch.tensor(start_loc, dtype=torch.int32, device=device))
...@@ -210,15 +210,15 @@ def test_prepare_decode_cuda_graph(batch_size): ...@@ -210,15 +210,15 @@ def test_prepare_decode_cuda_graph(batch_size):
for seq_len in seq_lens: for seq_len in seq_lens:
start_idx += seq_len start_idx += seq_len
seq_start_loc.append(start_idx) seq_start_loc.append(start_idx)
assert torch.allclose( torch.testing.assert_close(
attn_metadata.seq_start_loc, attn_metadata.seq_start_loc,
torch.tensor(seq_start_loc, dtype=torch.int32, device=device)) torch.tensor(seq_start_loc, dtype=torch.int32, device=device))
assert torch.allclose( torch.testing.assert_close(
attn_metadata.context_lens_tensor, attn_metadata.context_lens_tensor,
torch.tensor(context_lens, dtype=torch.int, device=device)) torch.tensor(context_lens, dtype=torch.int, device=device))
assert attn_metadata.max_decode_seq_len == max(seq_lens) assert attn_metadata.max_decode_seq_len == max(seq_lens)
assert torch.allclose( torch.testing.assert_close(
attn_metadata.seq_lens_tensor[:len(seq_lens)], attn_metadata.seq_lens_tensor[:len(seq_lens)],
torch.tensor(seq_lens, dtype=torch.int, device=device)) torch.tensor(seq_lens, dtype=torch.int, device=device))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment