"docs/vscode:/vscode.git/clone" did not exist on "c3372e87bed990510e4ae0b39f151a34dea24f8b"
Unverified commit dc97cc9e, authored by hx and committed by GitHub

[PyTorch] Optimize the performance of permute fusion kernels (#1927)



* optimize permute
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Hongxiao Bai <hongxiaob@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
parent 37da2d3b
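For context before the diff: the fused path being optimized is exercised in the tests via te_permute_with_probs. A minimal usage sketch, assuming the helper names and arguments exactly as they appear in the test diff below (the public import path and input preparation are not shown here):

# All names mirror the test code in this diff; the tensors are assumed to be
# prepared by the surrounding test (shapes and devices are not spelled out here).
te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
    te_permute_fwd_input,            # token activations to permute
    te_probs,                        # routing probabilities fused into the permute
    routing_map,                     # token-to-expert routing map
    num_out_tokens=num_out_tokens,   # target number of permuted tokens
)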
@@ -326,33 +326,37 @@ def _test_permutation_index_map(
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
torch.testing.assert_close(
pytorch_permute_output.float(),
te_permute_output_,
msg=f"Mismatch in te_permute fwd",
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_fwd_input.grad.float(),
te_unpermute_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if with_probs:
if not BENCHMARK:
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols
pytorch_permute_output.float(),
te_permute_output_,
msg=f"Mismatch in te_permute fwd",
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_fwd_input.grad.float(),
te_unpermute_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if with_probs:
torch.testing.assert_close(
probs.grad.float(),
te_probs.grad.float(),
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if not pytorch_permute_fwd_input.numel():
print("Empty pytorch_permute_fwd_input activation test passed.")
@@ -538,34 +542,38 @@ def _test_permutation_mask_map(
te_unpermute_output_ = te_unpermute_output.float()
te_unpermute_fwd_input_grad = te_unpermute_fwd_input.grad.float()
torch.testing.assert_close(
pytorch_permute_output.float(),
te_permute_output_,
msg=f"Mismatch in te_permute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_fwd_input.grad.float(),
te_unpermute_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if with_probs:
if not BENCHMARK:
torch.testing.assert_close(
pytorch_permute_output.float(),
te_permute_output_,
msg=f"Mismatch in te_permute fwd",
**tols,
)
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in te_unpermute bwd", **tols
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in te_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_unpermute_fwd_input.grad.float(),
te_unpermute_fwd_input_grad,
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if with_probs:
torch.testing.assert_close(
probs.grad.float(),
te_probs.grad.float(),
msg=f"Mismatch in te_unpermute bwd",
**tols,
)
if not pytorch_permute_fwd_input.numel():
print("Empty pytorch_permute_fwd_input activation test passed.")
@@ -827,18 +835,19 @@ def _test_moe_chunk_sort(
te_output_ = te_output.float()
te_fwd_input_grad = te_fwd_input.grad.float()
torch.testing.assert_close(
pytorch_output.float(),
te_output_,
msg=f"Mismatch in te_permute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_fwd_input.grad.float(),
te_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
if not BENCHMARK:
torch.testing.assert_close(
pytorch_output.float(),
te_output_,
msg=f"Mismatch in te_permute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_fwd_input.grad.float(),
te_fwd_input_grad,
msg=f"Mismatch in te_permute bwd",
**tols,
)
if not pytorch_fwd_input.numel():
print("Empty pytorch_fwd_input activation test passed.")
@@ -887,6 +896,7 @@ def _test_permutation_mask_map_alongside_probs(
topK,
num_out_tokens,
tp_size,
BENCHMARK=False,
):
if topK > num_expert:
pytest.skip("topK should be smaller than the number of experts.")
@@ -1016,21 +1026,73 @@ def _test_permutation_mask_map_alongside_probs(
te_permute_fwd_input_grad = te_permute_fwd_input.grad.float()
te_unpermute_output_ = te_unpermute_output.float()
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in fused_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in fused_permute bwd",
**tols,
)
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols
)
if not BENCHMARK:
torch.testing.assert_close(
pytorch_unpermute_output.float(),
te_unpermute_output_,
msg=f"Mismatch in fused_unpermute fwd",
**tols,
)
torch.testing.assert_close(
pytorch_permute_fwd_input.grad.float(),
te_permute_fwd_input_grad,
msg=f"Mismatch in fused_permute bwd",
**tols,
)
torch.testing.assert_close(
probs.grad.float(), te_probs.grad.float(), msg=f"Mismatch in prob grad", **tols
)
if BENCHMARK:
t1 = perf_test_cuda_kernel(
lambda: te_permute_with_probs(
te_permute_fwd_input, te_probs, routing_map, num_out_tokens=num_out_tokens
)
)
print(f"permute\t\tfwd: TE: {t1:.3f} ms")
te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
te_permute_fwd_input,
te_probs,
routing_map,
num_out_tokens=num_out_tokens,
)
te_permute_bwd_input = torch.rand((num_out_tokens, hidden_size), dtype=dtype).cuda()
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
te_permute_output,
te_permute_bwd_input,
forward_input=[te_permute_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"permute\t\tbwd: TE: {t2:.3f} ms")
chunk_sort_fwd_input = te_permute_output.detach()
chunk_sort_fwd_input.requires_grad_(True)
chunk_sort_fwd_probs = te_permuted_probs.detach()
chunk_sort_fwd_probs.requires_grad_(True)
t1 = perf_test_cuda_kernel(
lambda: te_sort_chunks_by_index_with_probs(
chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
)
)
print(f"chunk sort\t\tfwd: TE: {t1:.3f} ms")
chunk_sort_output, _ = te_sort_chunks_by_index_with_probs(
chunk_sort_fwd_input, chunk_sort_fwd_probs, split_sizes_cuda, sorted_idxs_cuda
)
t2 = perf_test_cuda_kernel(
lambda: backward_wrapper(
chunk_sort_output,
te_permute_bwd_input,
forward_input=[chunk_sort_fwd_input],
retain_graph=True,
accumulate_grad=False,
)
)
print(f"chunk sort\t\tbwd: TE: {t2:.3f} ms")
def perf_test_cuda_kernel(cuda_kernel_fn):
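The body of perf_test_cuda_kernel is elided by the diff. A minimal sketch of such a timing helper, assuming the usual CUDA-event warmup-and-average pattern (the iteration counts and exact structure are assumptions, not the repository's implementation):

import torch

def perf_test_cuda_kernel_sketch(cuda_kernel_fn):
    # Warm up so one-time costs (allocator, autotuning) are excluded from timing.
    for _ in range(10):
        cuda_kernel_fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(100):
        cuda_kernel_fn()
    end.record()
    torch.cuda.synchronize()
    # elapsed_time() returns milliseconds; the callers above print "{t:.3f} ms".
    return start.elapsed_time(end) / 100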
@@ -1063,7 +1125,7 @@ if is_bf16_compatible():
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@@ -1092,7 +1154,7 @@ def test_permutation_index_map(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@@ -1138,7 +1200,7 @@ def test_permutation_mask_map_empty_input(te_dtype):
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@@ -1193,7 +1255,7 @@ fp8_recipes = [
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.parametrize("te_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
@pytest.mark.parametrize("num_tokens", [2048])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [1, 2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@@ -1225,7 +1287,7 @@ def test_permutation_mask_map_fp8(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_index_map_topk1_no_probs(
te_dtype,
@@ -1252,7 +1314,7 @@ def test_permutation_index_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
def test_permutation_mask_map_topk1_no_probs(
te_dtype,
@@ -1279,7 +1341,7 @@ def test_permutation_mask_map_topk1_no_probs(
@pytest.mark.parametrize("te_dtype", _te_dtypes)
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [8, 16])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("tp_size", [1, 2, 8])
@pytest.mark.parametrize("hidden_size", [4096])
def test_chunk_permutation(
@@ -1372,5 +1434,108 @@ def test_permutation_single_case():
)
def benchmark_single_case(
te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
):
torch.cuda.nvtx.range_push(
f"{num_tokens}-{num_expert}-{hidden_size}-{topK}-{ep_size}-{tp_size}"
)
torch.cuda.nvtx.range_push("permutation_index_map_with_probs")
_test_permutation_index_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=True,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_mask_map_with_probs")
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=True,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_mask_map_without_probs")
_test_permutation_mask_map(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
with_probs=False,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_push("permutation_mask_map_alongside_probs")
_test_permutation_mask_map_alongside_probs(
te_dtype=te_dtype,
num_tokens=num_tokens,
num_expert=num_expert,
hidden_size=hidden_size,
topK=topK,
num_out_tokens=num_out_tokens,
tp_size=tp_size,
BENCHMARK=True,
)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()
def benchmark_multiple_cases():
print("GPU:", torch.cuda.get_device_name(0))
# te_dtype = tex.DType.kFloat32
# te_dtype = tex.DType.kFloat16
te_dtype = tex.DType.kBFloat16
ep_size = 64
tp_size = 2
num_tokens = 4096
num_expert = 256
hidden_size = 7168
topK = 8
num_out_tokens = num_tokens * topK
benchmark_single_case(
te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
)
ep_size = 8
tp_size = 1
num_tokens = 8192 * 2
num_expert = 128
hidden_size = 4096
topK = 6
num_out_tokens = num_tokens * topK
benchmark_single_case(
te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
)
ep_size = 64
tp_size = 2
num_tokens = 16384
num_expert = 4
hidden_size = 7168
topK = 1
num_out_tokens = num_tokens * topK
benchmark_single_case(
te_dtype, num_tokens, num_expert, hidden_size, topK, num_out_tokens, ep_size, tp_size
)
if __name__ == "__main__":
test_permutation_single_case()
benchmark_multiple_cases()
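Since every benchmarked case is wrapped in torch.cuda.nvtx ranges, the script is most informative under a profiler. A hedged sketch of scoping the capture region (this uses the standard torch.cuda.profiler API and an Nsight Systems capture-range option; it is a suggestion, not part of this PR):

import torch

if __name__ == "__main__":
    # With `nsys profile --capture-range=cudaProfilerApi ... python <this test file>`,
    # only the region between start() and stop() is recorded, so the NVTX ranges
    # pushed by benchmark_single_case are easy to locate in the timeline.
    torch.cuda.profiler.start()
    benchmark_multiple_cases()
    torch.cuda.profiler.stop()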
@@ -349,7 +349,7 @@ class _moe_unpermute_mask_map(torch.autograd.Function):
if restore_shape is None:
restore_shape = inp.shape
num_tokens, hidden_size = restore_shape
num_experts = row_id_map.size(0)
num_experts = (row_id_map.size(1) - 1) // 2
with_probs = merging_probs is not None
if with_probs:
@@ -651,14 +651,20 @@ class _moe_chunk_sort(torch.autograd.Function):
fp8_scale_inv = inp._scale_inv
fake_dtype = inp.dtype
inp = inp._data
output, row_id_map, permuted_probs = triton_permutation.sort_chunks_by_idx(
inp,
row_id_map = triton_permutation.make_chunk_sort_map(
split_sizes,
sorted_idxs,
num_tokens,
num_splits,
)
output, permuted_probs = triton_permutation.sort_chunks_by_map(
inp,
row_id_map,
probs,
num_tokens,
hidden_size,
num_splits,
is_forward=True,
)
if fp8:
output = Float8Tensor(
@@ -700,6 +706,7 @@ class _moe_chunk_sort(torch.autograd.Function):
permuted_probs_grad,
ctx.num_tokens,
ctx.hidden_size,
is_forward=False,
)
if fp8:
act_grad = Float8Tensor(
......
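The change above splits the old fused sort_chunks_by_idx launch into two Triton calls: make_chunk_sort_map builds a row map from the chunk split sizes once, and sort_chunks_by_map then moves activations (and optional probs) through that map, with the backward pass applying the same map with is_forward=False. A condensed sketch of the forward-side pattern, using only the calls visible in this diff (FP8 bookkeeping and autograd context handling omitted):

# Build the per-token destination map from the chunk split sizes and target order.
row_id_map = triton_permutation.make_chunk_sort_map(
    split_sizes, sorted_idxs, num_tokens, num_splits
)
# Apply the map to the activations (and optional probs) in the forward direction.
output, permuted_probs = triton_permutation.sort_chunks_by_map(
    inp, row_id_map, probs, num_tokens, hidden_size, is_forward=True
)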