"serialization/tests/TestSerializationNode.cpp" did not exist on "85da5e0f9018719e95350c76b483a2160c84d9d3"
Commit e18f726a authored by lijian6's avatar lijian6
Browse files

1. Fix ll mode 256 experts err.


2. Add internode ll mode.
3. Add test internode ll mode.
Signed-off-by: lijian6's avatarlijian <lijian6@sugon.com>
parent bff923d5
...@@ -12,7 +12,8 @@ export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx ...@@ -12,7 +12,8 @@ export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240 # export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export ROCSHMEM_HEAP_SIZE=2880100992 export ROCSHMEM_HEAP_SIZE=10737418240
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=0 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py --test-ll-compatibility
...@@ -12,7 +12,8 @@ export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx ...@@ -12,7 +12,8 @@ export UCX_NET_DEVICES=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx
export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9 export ROCSHMEM_ALLOWED_IBV_DEVICES=mlx5_2,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
# export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240 # export ROCSHMEM_HEAP_SIZE=536870912 805306368 10737418240
export ROCSHMEM_HEAP_SIZE=2880100992 export ROCSHMEM_HEAP_SIZE=10737418240
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py
# torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py # torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_low_latency_new.py
torchrun --nproc-per-node=1 --nnodes=2 --node-rank=1 --master-addr="10.16.1.37" --master-port=1234 tests/test_internode.py --test-ll-compatibility
...@@ -93,10 +93,9 @@ __forceinline__ __device__ void ...@@ -93,10 +93,9 @@ __forceinline__ __device__ void
nvshmem_barrier_with_same_gpu_idx(const rocshmem::rocshmem_team_t &rdma_team) { nvshmem_barrier_with_same_gpu_idx(const rocshmem::rocshmem_team_t &rdma_team) {
// NOTE: shmem_device_barrier_all() might be an issue as // NOTE: shmem_device_barrier_all() might be an issue as
// it doesn't follow OpenSHMEM specification on ROCm // it doesn't follow OpenSHMEM specification on ROCm
// kLowLatencyMode kLowLatencyMode
// ? void(rocshmem::rocshmem_ctx_barrier(rocshmem::ROCSHMEM_CTX_DEFAULT, rdma_team)) ? void(rocshmem::rocshmem_ctx_barrier(rocshmem::ROCSHMEM_CTX_DEFAULT, rdma_team))
// : rocshmem::rocshmem_barrier_all(); : rocshmem::rocshmem_barrier_all();
rocshmem::rocshmem_barrier_all();
} }
template <bool kLowLatencyMode, int kNumRDMARanks> template <bool kLowLatencyMode, int kNumRDMARanks>
......
...@@ -486,9 +486,7 @@ combine(void* combined_x, ...@@ -486,9 +486,7 @@ combine(void* combined_x,
// Message package // Message package
EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden"); EP_STATIC_ASSERT(kHidden % FP8_QUANTIZATION_NUM_PER_CHANNEL == 0, "Invalid hidden");
constexpr int kNumDivisions = kHidden / FP8_QUANTIZATION_NUM_PER_CHANNEL; constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16);
constexpr int kNumMetaBytes = kNumDivisions * sizeof(float);
constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(hip_bfloat16) + kNumMetaBytes;
EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization");
// 16 is the max possible number of warps in AMD GPUs // 16 is the max possible number of warps in AMD GPUs
......
...@@ -260,18 +260,21 @@ def get_hidden_bytes(args: argparse.Namespace) -> int: ...@@ -260,18 +260,21 @@ def get_hidden_bytes(args: argparse.Namespace) -> int:
def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
num_nodes = int(os.getenv('WORLD_SIZE', 1)) num_nodes = int(os.getenv('WORLD_SIZE', 1))
rank, num_ranks, group = init_dist(local_rank, num_local_ranks) rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
num_rdma_bytes_ll = 0
if args.test_ll_compatibility: if args.test_ll_compatibility:
ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9
num_rdma_bytes_ll = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts)
num_sms = 30 num_sms = 30
num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0) num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0)
hidden_bytes = get_hidden_bytes(args) hidden_bytes = get_hidden_bytes(args)
num_nvl_bytes, num_rdma_bytes = 0, 0 num_nvl_bytes, num_rdma_bytes, num_rdma_bytes_norm = 0, 0, 0
for config in (deep_ep.Buffer.get_dispatch_config(group.size()), deep_ep.Buffer.get_combine_config(group.size())): for config in (deep_ep.Buffer.get_dispatch_config(group.size()), deep_ep.Buffer.get_combine_config(group.size())):
num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes) num_nvl_bytes = max(config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes)
num_rdma_bytes = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes) num_rdma_bytes_norm = max(config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes)
num_rdma_bytes = max(num_rdma_bytes_norm, num_rdma_bytes_ll)
buffer = deep_ep.Buffer(group, num_nvl_bytes, num_rdma_bytes, low_latency_mode=args.test_ll_compatibility, buffer = deep_ep.Buffer(group, num_nvl_bytes, num_rdma_bytes, low_latency_mode=args.test_ll_compatibility,
num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True) num_qps_per_rank=num_qps_per_rank, explicitly_destroy=True)
assert num_local_ranks == 8 and num_ranks > 8 assert num_local_ranks == 8 and num_ranks > 8
......
...@@ -163,6 +163,10 @@ def test_loop(local_rank: int, num_local_ranks: int): ...@@ -163,6 +163,10 @@ def test_loop(local_rank: int, num_local_ranks: int):
for i in range(20): for i in range(20):
assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}' assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}'
buffer.destroy()
dist.barrier()
dist.destroy_process_group()
if __name__ == '__main__': if __name__ == '__main__':
print("main start...") print("main start...")
# TODO: you may modify NUMA binding for less CPU overhead # TODO: you may modify NUMA binding for less CPU overhead
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment