Unverified Commit a2fa3b73 authored by Shangyan Zhou's avatar Shangyan Zhou Committed by GitHub
Browse files

Use TMA to optimize internode dispatch. (#276)



* Add TMA buffer allocation

* Use TMA for forwarders and NVL receivers

* Use lane 31 to operate TMA.

* Change rdma buffer layout.

* Use TMA to transfer scales also.

* Increase the NVL recv buffer size.

* Disable early stopping.

* Apply similar optimizations on receiver warps.

* Prevent warp divergence.

* Disable aggressive ptx by default.

* Revert using TMA to transfer scales.

* Format.

* Change the layout of dispatch NVL buffer.

* Move topk transformation to recv warps.

* Use TMA to transfer all data in foward warps

* Use TMA to store scales.

* Code lint

---------
Co-authored-by: default avatarChenggang Zhao <chenggangz@deepseek.com>
parent 7705f533
This diff is collapsed.
...@@ -55,7 +55,7 @@ if __name__ == '__main__': ...@@ -55,7 +55,7 @@ if __name__ == '__main__':
os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1' os.environ['DISABLE_AGGRESSIVE_PTX_INSTRS'] = '1'
# Disable aggressive PTX instructions # Disable aggressive PTX instructions
if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '0')): if int(os.getenv('DISABLE_AGGRESSIVE_PTX_INSTRS', '1')):
cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS') cxx_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS') nvcc_flags.append('-DDISABLE_AGGRESSIVE_PTX_INSTRS')
......
...@@ -234,7 +234,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): ...@@ -234,7 +234,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_sms = 24 num_sms = 24
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0) num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=args.test_ll_compatibility, buffer = deep_ep.Buffer(group, int(2e9), int(1e9), low_latency_mode=args.test_ll_compatibility,
num_qps_per_rank=num_qps_per_rank) num_qps_per_rank=num_qps_per_rank)
assert num_local_ranks == 8 and num_ranks > 8 assert num_local_ranks == 8 and num_ranks > 8
torch.manual_seed(rank) torch.manual_seed(rank)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment