"vllm/vscode:/vscode.git/clone" did not exist on "e0919f331d12dc5dbdefd0775bb6f94dd2fab4e2"
Unverified Commit 50b8d08d authored by jon-chuang's avatar jon-chuang Committed by GitHub
Browse files

[Misc/Testing] Use `torch.testing.assert_close` (#7324)

parent e1655287
...@@ -34,7 +34,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int, ...@@ -34,7 +34,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0) expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank % tp_size] t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_reduce(t) t = tensor_model_parallel_all_reduce(t)
assert torch.allclose(t, expected) torch.testing.assert_close(t, expected)
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
...@@ -62,7 +62,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int, ...@@ -62,7 +62,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
expected = torch.cat(all_tensors, dim=all_gather_dimension) expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank % tp_size] t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_gather(t, all_gather_dimension) t = tensor_model_parallel_all_gather(t, all_gather_dimension)
assert torch.allclose(t, expected) torch.testing.assert_close(t, expected)
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
...@@ -96,12 +96,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, ...@@ -96,12 +96,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
else: else:
recv_dict = broadcast_tensor_dict(src=0) recv_dict = broadcast_tensor_dict(src=0)
assert len(recv_dict) == len(test_dict) assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"]) torch.testing.assert_close(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"]) torch.testing.assert_close(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"] assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"] assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"] assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"]) torch.testing.assert_close(recv_dict["f"], test_dict["f"])
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
...@@ -136,12 +136,12 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int, ...@@ -136,12 +136,12 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
if not get_pp_group().is_first_rank: if not get_pp_group().is_first_rank:
assert len(recv_dict) == len(test_dict) assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"]) torch.testing.assert_close(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"]) torch.testing.assert_close(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"] assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"] assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"] assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"]) torch.testing.assert_close(recv_dict["f"], test_dict["f"])
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
...@@ -163,7 +163,7 @@ def send_recv_test_worker(tp_size: int, pp_size: int, rank: int, ...@@ -163,7 +163,7 @@ def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
get_pp_group().send(test_tensor) get_pp_group().send(test_tensor)
if not get_pp_group().is_first_rank: if not get_pp_group().is_first_rank:
assert torch.allclose(test_tensor, recv_tensor) torch.testing.assert_close(test_tensor, recv_tensor)
@pytest.mark.skipif(torch.cuda.device_count() < 2, @pytest.mark.skipif(torch.cuda.device_count() < 2,
......
...@@ -72,8 +72,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): ...@@ -72,8 +72,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
out2 = tensor_model_parallel_all_reduce(inp2) out2 = tensor_model_parallel_all_reduce(inp2)
dist.all_reduce(inp2, group=group) dist.all_reduce(inp2, group=group)
graph.replay() graph.replay()
assert torch.allclose(out1, inp1) torch.testing.assert_close(out1, inp1)
assert torch.allclose(out2, inp2) torch.testing.assert_close(out2, inp2)
@ray.remote(num_gpus=1, max_calls=1) @ray.remote(num_gpus=1, max_calls=1)
...@@ -96,13 +96,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port): ...@@ -96,13 +96,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
out = inp out = inp
for _ in range(num_communication): for _ in range(num_communication):
out = fa.all_reduce_unreg(out) out = fa.all_reduce_unreg(out)
assert torch.allclose(out, inp * (tp_size**num_communication)) torch.testing.assert_close(out, inp * (tp_size**num_communication))
inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
out = inp out = inp
for _ in range(num_communication): for _ in range(num_communication):
out = fa.all_reduce_unreg(out) out = fa.all_reduce_unreg(out)
assert torch.allclose(out, inp * (tp_size**num_communication)) torch.testing.assert_close(out, inp * (tp_size**num_communication))
@pytest.mark.parametrize("tp_size", [2]) @pytest.mark.parametrize("tp_size", [2])
......
...@@ -69,4 +69,4 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \ ...@@ -69,4 +69,4 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
ref_iscale = one / ref_scale ref_iscale = one / ref_scale
ref_out = (as_float32_tensor(x) * ref_iscale).clamp( ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn) fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
return ref_out, ref_scale return ref_out, ref_scale.view((1, ))
...@@ -47,7 +47,7 @@ def test_act_and_mul( ...@@ -47,7 +47,7 @@ def test_act_and_mul(
ref_out = layer.forward_native(x) ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch # The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison. # implementations, so we can do exact comparison.
assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0) torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
@pytest.mark.parametrize("activation", [FastGELU, NewGELU]) @pytest.mark.parametrize("activation", [FastGELU, NewGELU])
...@@ -73,7 +73,7 @@ def test_activation( ...@@ -73,7 +73,7 @@ def test_activation(
layer = activation() layer = activation()
out = layer(x) out = layer(x)
ref_out = layer.forward_native(x) ref_out = layer.forward_native(x)
assert torch.allclose(out, torch.testing.assert_close(out,
ref_out, ref_out,
atol=get_default_atol(out), atol=get_default_atol(out),
rtol=get_default_rtol(out)) rtol=get_default_rtol(out))
...@@ -276,7 +276,7 @@ def test_paged_attention( ...@@ -276,7 +276,7 @@ def test_paged_attention(
atol, rtol = 1e-3, 1e-5 atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5 atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
def ref_multi_query_kv_attention( def ref_multi_query_kv_attention(
...@@ -379,4 +379,4 @@ def test_multi_query_kv_attention( ...@@ -379,4 +379,4 @@ def test_multi_query_kv_attention(
) )
atol = get_default_atol(output) if is_hip() else 1e-3 atol = get_default_atol(output) if is_hip() else 1e-3
rtol = get_default_rtol(output) if is_hip() else 1e-5 rtol = get_default_rtol(output) if is_hip() else 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
...@@ -327,7 +327,7 @@ def test_paged_attention( ...@@ -327,7 +327,7 @@ def test_paged_attention(
atol, rtol = 1e-3, 1e-5 atol, rtol = 1e-3, 1e-5
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
atol, rtol = 1e-2, 1e-5 atol, rtol = 1e-2, 1e-5
assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
def ref_multi_query_kv_attention( def ref_multi_query_kv_attention(
...@@ -441,4 +441,4 @@ def test_varlen_blocksparse_attention_prefill( ...@@ -441,4 +441,4 @@ def test_varlen_blocksparse_attention_prefill(
scale, scale,
dtype, dtype,
) )
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
...@@ -98,10 +98,10 @@ def test_copy_blocks( ...@@ -98,10 +98,10 @@ def test_copy_blocks(
# Compare the results. # Compare the results.
for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches): for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
assert torch.allclose(key_cache, cloned_key_cache) torch.testing.assert_close(key_cache, cloned_key_cache)
for value_cache, cloned_value_cache in zip(value_caches, for value_cache, cloned_value_cache in zip(value_caches,
cloned_value_caches): cloned_value_caches):
assert torch.allclose(value_cache, cloned_value_cache) torch.testing.assert_close(value_cache, cloned_value_cache)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
...@@ -184,17 +184,17 @@ def test_reshape_and_cache( ...@@ -184,17 +184,17 @@ def test_reshape_and_cache(
cloned_value_cache[block_idx, :, :, block_offset] = value[i] cloned_value_cache[block_idx, :, :, block_offset] = value[i]
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
assert torch.allclose(result_key_cache, torch.testing.assert_close(result_key_cache,
cloned_key_cache, cloned_key_cache,
atol=0.001, atol=0.001,
rtol=0.1) rtol=0.1)
assert torch.allclose(result_value_cache, torch.testing.assert_close(result_value_cache,
cloned_value_cache, cloned_value_cache,
atol=0.001, atol=0.001,
rtol=0.1) rtol=0.1)
else: else:
assert torch.allclose(key_cache, cloned_key_cache) torch.testing.assert_close(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache) torch.testing.assert_close(value_cache, cloned_value_cache)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
...@@ -290,17 +290,17 @@ def test_reshape_and_cache_flash( ...@@ -290,17 +290,17 @@ def test_reshape_and_cache_flash(
cloned_value_cache[block_idx, block_offset, :, :] = value[i] cloned_value_cache[block_idx, block_offset, :, :] = value[i]
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
assert torch.allclose(result_key_cache, torch.testing.assert_close(result_key_cache,
cloned_key_cache, cloned_key_cache,
atol=0.001, atol=0.001,
rtol=0.1) rtol=0.1)
assert torch.allclose(result_value_cache, torch.testing.assert_close(result_value_cache,
cloned_value_cache, cloned_value_cache,
atol=0.001, atol=0.001,
rtol=0.1) rtol=0.1)
else: else:
assert torch.allclose(key_cache, cloned_key_cache) torch.testing.assert_close(key_cache, cloned_key_cache)
assert torch.allclose(value_cache, cloned_value_cache) torch.testing.assert_close(value_cache, cloned_value_cache)
@pytest.mark.parametrize("direction", COPYING_DIRECTION) @pytest.mark.parametrize("direction", COPYING_DIRECTION)
...@@ -372,9 +372,9 @@ def test_swap_blocks( ...@@ -372,9 +372,9 @@ def test_swap_blocks(
block_mapping_tensor) block_mapping_tensor)
for src, dst in block_mapping: for src, dst in block_mapping:
assert torch.allclose(src_key_caches_clone[src].cpu(), torch.testing.assert_close(src_key_caches_clone[src].cpu(),
dist_key_caches[0][dst].cpu()) dist_key_caches[0][dst].cpu())
assert torch.allclose(src_value_caches_clone[src].cpu(), torch.testing.assert_close(src_value_caches_clone[src].cpu(),
dist_value_caches[0][dst].cpu()) dist_value_caches[0][dst].cpu())
...@@ -411,4 +411,4 @@ def test_fp8_e4m3_conversion( ...@@ -411,4 +411,4 @@ def test_fp8_e4m3_conversion(
converted_cache = torch.empty_like(cache) converted_cache = torch.empty_like(cache)
ops.convert_fp8(converted_cache, cache_fp8) ops.convert_fp8(converted_cache, cache_fp8)
assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1) torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
...@@ -74,7 +74,7 @@ def cutlass_fp8_gemm_helper(m: int, ...@@ -74,7 +74,7 @@ def cutlass_fp8_gemm_helper(m: int,
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2) torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)
def cutlass_int8_gemm_helper(m: int, def cutlass_int8_gemm_helper(m: int,
...@@ -106,7 +106,7 @@ def cutlass_int8_gemm_helper(m: int, ...@@ -106,7 +106,7 @@ def cutlass_int8_gemm_helper(m: int,
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33]) @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
...@@ -252,7 +252,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, ...@@ -252,7 +252,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32) a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
assert torch.allclose(a_dq, scale_a * aq_f32 + azp_a) torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype) baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
...@@ -271,8 +271,8 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int, ...@@ -271,8 +271,8 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
scale_b, scale_b,
out_dtype=out_dtype, out_dtype=out_dtype,
bias=azp_bias[0, :]) bias=azp_bias[0, :])
assert torch.allclose(out, baseline_dq, rtol=1e-2, atol=1e0) torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
assert torch.allclose(out, baseline_q, rtol=1e-2, atol=1e0) torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
@pytest.mark.parametrize("m", [32, 64, 128]) @pytest.mark.parametrize("m", [32, 64, 128])
...@@ -302,7 +302,10 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, ...@@ -302,7 +302,10 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32) a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
assert torch.allclose(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3) torch.testing.assert_close(a_dq,
scale_a * aq_f32 - azp_a,
rtol=1e-4,
atol=1e-3)
if use_bias: if use_bias:
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5 bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
...@@ -335,8 +338,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype, ...@@ -335,8 +338,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05% # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3 rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
atol = 1e-3 atol = 1e-3
assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol) torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol) torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
# Test working with a subset of A and B # Test working with a subset of A and B
...@@ -363,7 +366,7 @@ def test_cutlass_subset(): ...@@ -363,7 +366,7 @@ def test_cutlass_subset():
scale_b, scale_b,
out_dtype=torch.bfloat16) out_dtype=torch.bfloat16)
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
# Test to make sure cuda graphs work # Test to make sure cuda graphs work
...@@ -411,4 +414,4 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool): ...@@ -411,4 +414,4 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
baseline = torch.mm(scale_a * a.to(dtype=torch.float32), baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16) scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0) torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
...@@ -126,7 +126,7 @@ def test_flash_attn_with_paged_kv( ...@@ -126,7 +126,7 @@ def test_flash_attn_with_paged_kv(
scale=scale, scale=scale,
soft_cap=soft_cap, soft_cap=soft_cap,
) )
assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -211,5 +211,5 @@ def test_varlen_with_paged_kv( ...@@ -211,5 +211,5 @@ def test_varlen_with_paged_kv(
sliding_window=sliding_window, sliding_window=sliding_window,
soft_cap=soft_cap, soft_cap=soft_cap,
) )
assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -144,7 +144,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], ...@@ -144,7 +144,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap) soft_cap=soft_cap)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -244,5 +244,5 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], ...@@ -244,5 +244,5 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
block_tables=block_tables, block_tables=block_tables,
scale=scale, scale=scale,
soft_cap=soft_cap) soft_cap=soft_cap)
assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \ torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}" f"{torch.max(torch.abs(output - ref_output))}"
...@@ -37,8 +37,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int, ...@@ -37,8 +37,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
scale_ub=scale_ub, scale_ub=scale_ub,
use_per_token_if_dynamic=True) use_per_token_if_dynamic=True)
assert torch.allclose(ref_scales, ops_scales) torch.testing.assert_close(ref_scales, ops_scales)
assert torch.allclose(ref_out.to(dtype=torch.float32), torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32)) ops_out.to(dtype=torch.float32))
...@@ -57,8 +57,8 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int, ...@@ -57,8 +57,8 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x) ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
ops_out, ops_scale = ops.scaled_fp8_quant(x) ops_out, ops_scale = ops.scaled_fp8_quant(x)
assert torch.allclose(ref_scale, ops_scale) torch.testing.assert_close(ref_scale, ops_scale)
assert torch.allclose(ref_out.to(dtype=torch.float32), torch.testing.assert_close(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32)) ops_out.to(dtype=torch.float32))
...@@ -84,4 +84,4 @@ def test_fp8_quant_large(seed: int) -> None: ...@@ -84,4 +84,4 @@ def test_fp8_quant_large(seed: int) -> None:
ref_out = ref_out.to(dtype=dtype) ref_out = ref_out.to(dtype=dtype)
ops_out = ops_out.to(dtype=dtype) ops_out = ops_out.to(dtype=dtype)
assert torch.allclose(ref_out, ops_out) torch.testing.assert_close(ref_out, ops_out)
...@@ -29,9 +29,10 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int, ...@@ -29,9 +29,10 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
# kernel # kernel
ops_out, ops_scales = scaled_int8_quant(x) ops_out, ops_scales = scaled_int8_quant(x)
assert torch.allclose(ops_scales, ref_scales) torch.testing.assert_close(ops_scales, ref_scales)
assert torch.allclose(ops_out, ref_out, torch.testing.assert_close(
atol=1) # big atol to account for rounding errors ops_out, ref_out, atol=1,
rtol=0.0) # big atol to account for rounding errors
@pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
...@@ -54,5 +55,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int, ...@@ -54,5 +55,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
int8_traits.max).to(torch.int8) int8_traits.max).to(torch.int8)
out2, _ = scaled_int8_quant(x, scale) out2, _ = scaled_int8_quant(x, scale)
assert torch.allclose(out1, out2, torch.testing.assert_close(
atol=1) # big atol to account for rounding errors out1, out2, atol=1,
rtol=0.0) # big atol to account for rounding errors
...@@ -48,7 +48,7 @@ def test_rms_norm( ...@@ -48,7 +48,7 @@ def test_rms_norm(
# numerical errors than other operators because they involve reductions. # numerical errors than other operators because they involve reductions.
# Therefore, we use a larger tolerance. # Therefore, we use a larger tolerance.
if add_residual: if add_residual:
assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2) torch.testing.assert_close(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2) torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
else: else:
assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2) torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
...@@ -122,7 +122,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, ...@@ -122,7 +122,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
assert torch.allclose(marlin_q_w_1, marlin_q_w_2) torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
...@@ -174,7 +174,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, ...@@ -174,7 +174,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
assert torch.allclose(marlin_q_w_1, marlin_q_w_2) torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
......
...@@ -50,7 +50,7 @@ def test_fused_moe( ...@@ -50,7 +50,7 @@ def test_fused_moe(
score = torch.randn((m, e), device='cuda', dtype=dtype) score = torch.randn((m, e), device='cuda', dtype=dtype)
triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
torch_output = torch_moe(a, w1, w2, score, topk) torch_output = torch_moe(a, w1, w2, score, topk)
assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)
@pytest.mark.parametrize("dtype", @pytest.mark.parametrize("dtype",
...@@ -95,7 +95,7 @@ def test_mixtral_moe(dtype: torch.dtype): ...@@ -95,7 +95,7 @@ def test_mixtral_moe(dtype: torch.dtype):
torch.bfloat16: 1e-2, torch.bfloat16: 1e-2,
} }
assert torch.allclose(hf_states.flatten(0, 1), torch.testing.assert_close(hf_states.flatten(0, 1),
vllm_states, vllm_states,
rtol=mixtral_moe_tol[dtype], rtol=mixtral_moe_tol[dtype],
atol=mixtral_moe_tol[dtype]) atol=mixtral_moe_tol[dtype])
...@@ -67,11 +67,11 @@ def test_rotary_embedding( ...@@ -67,11 +67,11 @@ def test_rotary_embedding(
ref_query, ref_key = rope.forward_native(positions, query, key) ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions, query, key) out_query, out_key = rope.forward(positions, query, key)
# Compare the results. # Compare the results.
assert torch.allclose(out_query, torch.testing.assert_close(out_query,
ref_query, ref_query,
atol=get_default_atol(out_query), atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query)) rtol=get_default_rtol(out_query))
assert torch.allclose(out_key, torch.testing.assert_close(out_key,
ref_key, ref_key,
atol=get_default_atol(out_key), atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key)) rtol=get_default_rtol(out_key))
...@@ -129,11 +129,11 @@ def test_batched_rotary_embedding( ...@@ -129,11 +129,11 @@ def test_batched_rotary_embedding(
dtype=torch.long, dtype=torch.long,
device=device)) device=device))
# Compare the results. # Compare the results.
assert torch.allclose(out_query, torch.testing.assert_close(out_query,
ref_query, ref_query,
atol=get_default_atol(out_query), atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query)) rtol=get_default_rtol(out_query))
assert torch.allclose(out_key, torch.testing.assert_close(out_key,
ref_key, ref_key,
atol=get_default_atol(out_key), atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key)) rtol=get_default_rtol(out_key))
...@@ -200,11 +200,11 @@ def test_batched_rotary_embedding_multi_lora( ...@@ -200,11 +200,11 @@ def test_batched_rotary_embedding_multi_lora(
out_query, out_key = rope.forward(positions, query, key, out_query, out_key = rope.forward(positions, query, key,
query_offsets.flatten()) query_offsets.flatten())
# Compare the results. # Compare the results.
assert torch.allclose(out_query, torch.testing.assert_close(out_query,
ref_query, ref_query,
atol=get_default_atol(out_query), atol=get_default_atol(out_query),
rtol=get_default_rtol(out_query)) rtol=get_default_rtol(out_query))
assert torch.allclose(out_key, torch.testing.assert_close(out_key,
ref_key, ref_key,
atol=get_default_atol(out_key), atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key)) rtol=get_default_rtol(out_key))
......
...@@ -100,11 +100,11 @@ def test_sample_decoding_only(random_sampling, max_best_of, ...@@ -100,11 +100,11 @@ def test_sample_decoding_only(random_sampling, max_best_of,
if modify_greedy_probs and not request_uses_random_sampling: if modify_greedy_probs and not request_uses_random_sampling:
# If we are modifying greedy probs and the request is greedy, # If we are modifying greedy probs and the request is greedy,
# we want to make sure the probs tensor is modified in place # we want to make sure the probs tensor is modified in place
assert torch.allclose( torch.testing.assert_close(
probs[i][sampled_tokens[i]], probs[i][sampled_tokens[i]],
torch.full_like(probs[i][sampled_tokens[i]], 1.0)) torch.full_like(probs[i][sampled_tokens[i]], 1.0))
assert torch.sum(probs[i]) == 1.0 assert torch.sum(probs[i]) == 1.0
assert torch.allclose( torch.testing.assert_close(
sampled_modified_probs[i][0], sampled_modified_probs[i][0],
torch.full_like(sampled_modified_probs[i][0], 1.0)) torch.full_like(sampled_modified_probs[i][0], 1.0))
elif request_uses_random_sampling: elif request_uses_random_sampling:
...@@ -117,7 +117,7 @@ def test_sample_decoding_only(random_sampling, max_best_of, ...@@ -117,7 +117,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
# If the request is greedy and we are not modifying greedy probs, # If the request is greedy and we are not modifying greedy probs,
# we want to make sure sampled_modified_probs tensor is the same as # we want to make sure sampled_modified_probs tensor is the same as
# the probs tensor. # the probs tensor.
assert torch.allclose(sampled_modified_probs[i][0], torch.testing.assert_close(sampled_modified_probs[i],
probs[i][sampled_tokens[i]]) probs[i][sampled_tokens[i]])
if save_logprobs: if save_logprobs:
......
...@@ -924,5 +924,5 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters, ...@@ -924,5 +924,5 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
* output_under_test: actually observed output value * output_under_test: actually observed output value
''' '''
ideal_output = test_params.packed_qkvo.ideal_output ideal_output = test_params.packed_qkvo.ideal_output
assert torch.allclose(ideal_output, torch.testing.assert_close(ideal_output,
output_under_test.view_as(ideal_output)) output_under_test.view_as(ideal_output))
...@@ -247,7 +247,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: ...@@ -247,7 +247,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, torch.testing.assert_close(lora_result,
expected_result, expected_result,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
...@@ -274,7 +274,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None: ...@@ -274,7 +274,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
expected_result = embedding(torch.cat(inputs)) expected_result = embedding(torch.cat(inputs))
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, torch.testing.assert_close(lora_result,
expected_result, expected_result,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
...@@ -384,7 +384,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -384,7 +384,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, torch.testing.assert_close(lora_result,
expected_result, expected_result,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
...@@ -411,7 +411,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device, ...@@ -411,7 +411,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
expected_result = expanded_embedding(torch.cat(inputs)) expected_result = expanded_embedding(torch.cat(inputs))
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, torch.testing.assert_close(lora_result,
expected_result, expected_result,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
...@@ -541,7 +541,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size, ...@@ -541,7 +541,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
embedding_bias=None) embedding_bias=None)
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, torch.testing.assert_close(lora_result,
expected_result, expected_result,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
...@@ -614,7 +614,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None: ...@@ -614,7 +614,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, torch.testing.assert_close(lora_result,
expected_result, expected_result,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
...@@ -642,7 +642,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None: ...@@ -642,7 +642,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
expected_result = linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0]
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, torch.testing.assert_close(lora_result,
expected_result, expected_result,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
...@@ -728,7 +728,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -728,7 +728,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, torch.testing.assert_close(lora_result,
expected_result, expected_result,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
...@@ -756,7 +756,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, ...@@ -756,7 +756,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
expected_result = linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0]
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, torch.testing.assert_close(lora_result,
expected_result, expected_result,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
...@@ -868,7 +868,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -868,7 +868,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
expected_result = torch.cat(expected_results) expected_result = torch.cat(expected_results)
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, torch.testing.assert_close(lora_result,
expected_result, expected_result,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
...@@ -900,7 +900,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, ...@@ -900,7 +900,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
expected_result = linear(torch.cat(inputs))[0] expected_result = linear(torch.cat(inputs))[0]
rtol, atol = TOLERANCES[lora_result.dtype] rtol, atol = TOLERANCES[lora_result.dtype]
assert torch.allclose(lora_result, torch.testing.assert_close(lora_result,
expected_result, expected_result,
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
......
...@@ -533,13 +533,13 @@ def test_packed_loras(dist_init, dummy_model_gate_up): ...@@ -533,13 +533,13 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
packed_lora = model_lora.get_lora("gate_up_proj") packed_lora = model_lora.get_lora("gate_up_proj")
assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)
assert torch.allclose(packed_lora.lora_a[0], torch.testing.assert_close(packed_lora.lora_a[0],
model_lora.get_lora("gate_proj").lora_a) model_lora.get_lora("gate_proj").lora_a)
assert torch.allclose(packed_lora.lora_b[0], torch.testing.assert_close(packed_lora.lora_b[0],
model_lora.get_lora("gate_proj").lora_b) model_lora.get_lora("gate_proj").lora_b)
assert torch.allclose(packed_lora.lora_a[1], torch.testing.assert_close(packed_lora.lora_a[1],
model_lora.get_lora("up_proj").lora_a) model_lora.get_lora("up_proj").lora_a)
assert torch.allclose(packed_lora.lora_b[1], torch.testing.assert_close(packed_lora.lora_b[1],
model_lora.get_lora("up_proj").lora_b) model_lora.get_lora("up_proj").lora_b)
packed_lora1 = model_lora1.get_lora("gate_up_proj") packed_lora1 = model_lora1.get_lora("gate_up_proj")
...@@ -547,7 +547,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up): ...@@ -547,7 +547,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
assert packed_lora1.lora_a[0] is None assert packed_lora1.lora_a[0] is None
assert packed_lora1.lora_b[0] is None assert packed_lora1.lora_b[0] is None
assert torch.allclose(packed_lora1.lora_a[1], torch.testing.assert_close(packed_lora1.lora_a[1],
model_lora1.get_lora("up_proj").lora_a) model_lora1.get_lora("up_proj").lora_a)
assert torch.allclose(packed_lora1.lora_b[1], torch.testing.assert_close(packed_lora1.lora_b[1],
model_lora1.get_lora("up_proj").lora_b) model_lora1.get_lora("up_proj").lora_b)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment